Merge remote-tracking branch 'upstream/debug'

This commit is contained in:
tcmofashi
2025-03-03 08:42:51 +08:00
25 changed files with 633 additions and 384 deletions

24
.env.prod Normal file
View File

@@ -0,0 +1,24 @@
HOST=127.0.0.1
PORT=8080
COMMAND_START=["/"]
# 插件配置
PLUGINS=["src2.plugins.chat"]
# 默认配置
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 默认空值
MONGODB_PASSWORD = "" # 默认空值
MONGODB_AUTH_SOURCE = "" # 默认空值
#key and url
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1

1
.gitignore vendored
View File

@@ -14,6 +14,7 @@ reasoning_content.bat
reasoning_window.bat
queue_update.txt
memory_graph.gml
.env.dev
# Byte-compiled / optimized / DLL files

View File

@@ -28,7 +28,7 @@
> ⚠️ **警告**请自行了解qqbot的风险麦麦有时候一天被腾讯肘七八次
> ⚠️ **警告**由于麦麦一直在迭代所以可能存在一些bug请自行测试包括胡言乱语
关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
## 开发计划TODOLIST
@@ -41,16 +41,13 @@
- config自动生成和检测
- log别用print
- 给发送消息写专门的类
- 改进表情包发送逻辑l
<div align="center">
<img src="docs/qq.png" width="300" />
</div>
## 📚 详细文档
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!)
#### Linux 使用 Docker Compose 部署
获取项目根目录中的```docker-compose.yml```文件,运行以下命令

60
bot.py
View File

@@ -1,14 +1,62 @@
import os
import nonebot
from nonebot.adapters.onebot.v11 import Adapter
from dotenv import load_dotenv
from loguru import logger
'''彩蛋'''
from colorama import init, Fore
init()
text = "多年以后面对行刑队张三将会回想起他2023年在会议上讨论人工智能的那个下午"
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
rainbow_text = ""
for i, char in enumerate(text):
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
print(rainbow_text)
'''彩蛋'''
# 首先加载基础环境变量
if os.path.exists(".env"):
load_dotenv(".env")
logger.success("成功加载基础环境变量配置")
else:
logger.error("基础环境变量配置文件 .env 不存在")
exit(1)
# 根据 ENVIRONMENT 加载对应的环境配置
env = os.getenv("ENVIRONMENT")
env_file = f".env.{env}"
if env_file == ".env.dev" and os.path.exists(env_file):
logger.success("加载开发环境变量配置")
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
elif env_file == ".env.prod" and os.path.exists(env_file):
logger.success("加载环境变量配置")
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
else:
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
exit(1)
# 初始化 NoneBot
nonebot.init(
# napcat 默认使用 8080 端口
websocket_port=8080,
# 设置日志级别
# 从环境变量中读取配置
websocket_port=os.getenv("PORT", 8080),
host=os.getenv("HOST", "127.0.0.1"),
log_level="INFO",
# 设置超级用户
superusers={"你的QQ号"}
# 添加自定义配置
mongodb_host=os.getenv("MONGODB_HOST", "127.0.0.1"),
mongodb_port=os.getenv("MONGODB_PORT", 27017),
database_name=os.getenv("DATABASE_NAME", "MegBot"),
mongodb_username=os.getenv("MONGODB_USERNAME", ""),
mongodb_password=os.getenv("MONGODB_PASSWORD", ""),
mongodb_auth_source=os.getenv("MONGODB_AUTH_SOURCE", ""),
# API相关配置
chat_any_where_key=os.getenv("CHAT_ANY_WHERE_KEY", ""),
siliconflow_key=os.getenv("SILICONFLOW_KEY", ""),
chat_any_where_base_url=os.getenv("CHAT_ANY_WHERE_BASE_URL", "https://api.chatanywhere.tech/v1"),
siliconflow_base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
deep_seek_key=os.getenv("DEEP_SEEK_KEY", ""),
deep_seek_base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
# 插件配置
plugins=os.getenv("PLUGINS", ["src2.plugins.chat"])
)
# 注册适配器

45
config/bot_config_toml Normal file
View File

@@ -0,0 +1,45 @@
[bot]
qq = 123456 #填入你的机器人QQ
nickname = "麦麦" #你希望bot被称呼的名字
[message]
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃
emoji_chance = 0.2 # 麦麦使用表情包的概率
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow" # 选择大模型API可选值为siliconflow,deepseek建议使用siliconflow因为识图api目前只支持siliconflow的deepseek-vl2模型
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
[memory]
build_memory_interval = 300 # 记忆构建间隔
[others]
enable_advance_output = true # 开启后输出更多日志,false关闭true开启
[groups]
talk_allowed = [
123456,12345678
] #可以回复消息的群
talk_frequency_down = [
123456,12345678
] #降低回复频率的群
ban_user_id = [
123456,12345678
] #禁止回复消息的QQ号

View File

@@ -2,4 +2,5 @@ call conda activate niuniu
cd .
REM 执行nb run命令
nb run
nb run
pause

View File

@@ -331,9 +331,12 @@ class ReasoningGUI:
def main():
"""主函数"""
Database.initialize(
"127.0.0.1",
27017,
"MegBot"
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
app = ReasoningGUI()

View File

@@ -11,13 +11,18 @@ from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager
# 获取驱动器
driver = get_driver()
config = driver.config
Database.initialize(
global_config.MONGODB_HOST,
global_config.MONGODB_PORT,
global_config.DATABASE_NAME
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
)
print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -34,7 +39,7 @@ emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
# 创建机器人实例
chat_bot = ChatBot(global_config)
chat_bot = ChatBot()
# 注册消息处理器
group_msg = on_message()
# 创建定时任务

View File

@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
from ..memory_system.memory import memory_graph
class ChatBot:
def __init__(self, config: BotConfig):
self.config = config
def __init__(self):
self.storage = MessageStorage()
self.gpt = LLMResponseGenerator(config)
self.gpt = LLMResponseGenerator()
self.bot = None # bot 实例引用
self._started = False
@@ -39,11 +38,11 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
if event.group_id not in self.config.talk_allowed_groups:
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例
if event.user_id in self.config.ban_user_id:
if event.user_id in global_config.ban_user_id:
return
# 打印原始消息内容
@@ -121,7 +120,7 @@ class ChatBot:
event.group_id,
topic[0] if topic else None,
is_mentioned,
self.config,
global_config,
event.user_id,
message.is_emoji,
interested_rate
@@ -147,10 +146,14 @@ class ChatBot:
thinking_message.interupt=True
# 如果生成了回复,发送并记录
'''
生成回复后的内容
'''
if response:
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
accu_typing_time = 0
for msg in response:
print(f"当前消息: {msg}")
@@ -161,7 +164,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
user_id=self.config.BOT_QQ,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
@@ -178,7 +181,7 @@ class ChatBot:
bot_response_time = tinking_time_point
if random() < self.config.emoji_chance:
if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
if emoji_path:
emoji_cq = CQCode.create_emoji_cq(emoji_path)
@@ -190,7 +193,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
user_id=self.config.BOT_QQ,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,

View File

@@ -7,22 +7,13 @@ import configparser
import tomli
import sys
from loguru import logger
from dotenv import load_dotenv
from nonebot import get_driver
@dataclass
class BotConfig:
"""机器人配置类"""
# 基础配置
MONGODB_HOST: str = "mongodb"
MONGODB_PORT: int = 27017
DATABASE_NAME: str = "MegBot"
MONGODB_USERNAME: Optional[str] = None # 默认空值
MONGODB_PASSWORD: Optional[str] = None # 默认空值
MONGODB_AUTH_SOURCE: Optional[str] = None # 默认空值
"""机器人配置类"""
BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None
@@ -75,17 +66,7 @@ class BotConfig:
if os.path.exists(config_path):
with open(config_path, "rb") as f:
toml_dict = tomli.load(f)
# 数据库配置
if "database" in toml_dict:
db_config = toml_dict["database"]
config.MONGODB_HOST = db_config.get("host", config.MONGODB_HOST)
config.MONGODB_PORT = db_config.get("port", config.MONGODB_PORT)
config.DATABASE_NAME = db_config.get("name", config.DATABASE_NAME)
config.MONGODB_USERNAME = db_config.get("username", config.MONGODB_USERNAME) or None # 空字符串转为 None
config.MONGODB_PASSWORD = db_config.get("password", config.MONGODB_PASSWORD) or None # 空字符串转为 None
config.MONGODB_AUTH_SOURCE = db_config.get("auth_source", config.MONGODB_AUTH_SOURCE) or None # 空字符串转为 None
if "emoji" in toml_dict:
emoji_config = toml_dict["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
@@ -146,20 +127,10 @@ class BotConfig:
# 获取配置文件路径
bot_config_path = BotConfig.get_default_config_path()
config_dir = os.path.dirname(bot_config_path)
env_path = os.path.join(config_dir, '.env')
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
global_config = BotConfig.load_config(config_path=bot_config_path)
# 加载环境变量
logger.info(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
logger.success("成功加载环境变量配置")
else:
logger.error(f"环境变量配置文件不存在: {env_path}")
@dataclass
class LLMConfig:
"""机器人配置类"""
@@ -170,10 +141,11 @@ class LLMConfig:
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:

View File

@@ -7,7 +7,7 @@ from PIL import Image
import os
from random import random
from nonebot.adapters.onebot.v11 import Bot
from .config import global_config, llm_config
from .config import global_config
import time
import asyncio
from .utils_image import storage_image,storage_emoji
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
#包含CQ码类
import urllib3
from urllib3.util import create_urllib3_context
from nonebot import get_driver
driver = get_driver()
config = driver.config
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
ctx = create_urllib3_context()
@@ -179,7 +183,7 @@ class CQCode:
"""调用AI接口获取表情包描述"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
"Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -206,7 +210,7 @@ class CQCode:
}
response = requests.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload,
timeout=30
@@ -224,7 +228,7 @@ class CQCode:
"""调用AI接口获取普通图片描述"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
"Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -251,7 +255,7 @@ class CQCode:
}
response = requests.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload,
timeout=30

View File

@@ -10,10 +10,14 @@ import hashlib
from datetime import datetime
import base64
import shutil
from .config import global_config, llm_config
import asyncio
import time
from nonebot import get_driver
driver = get_driver()
config = driver.config
class EmojiManager:
_instance = None
@@ -93,7 +97,7 @@ class EmojiManager:
# 准备请求数据
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
"Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -115,7 +119,7 @@ class EmojiManager:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload
) as response:
@@ -249,7 +253,7 @@ class EmojiManager:
async with aiohttp.ClientSession() as session:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
"Authorization": f"Bearer {config.siliconflow_key}"
}
payload = {
@@ -276,7 +280,7 @@ class EmojiManager:
}
async with session.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
f"{config.siliconflow_base_url}chat/completions",
headers=headers,
json=payload
) as response:

View File

@@ -1,40 +1,34 @@
from typing import Dict, Any, List, Optional, Union, Tuple
from openai import OpenAI
import asyncio
import requests
from functools import partial
from .message import Message
from .config import BotConfig, global_config
from .config import global_config
from ...common.database import Database
import random
import time
import os
import numpy as np
from dotenv import load_dotenv
from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .prompt_builder import prompt_builder
from .config import llm_config, global_config
from .config import global_config
from .utils import process_llm_response
from nonebot import get_driver
driver = get_driver()
config = driver.config
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
load_dotenv(os.path.join(root_dir, '.env'))
class LLMResponseGenerator:
def __init__(self, config: BotConfig):
self.config = config
if self.config.API_USING == "siliconflow":
def __init__(self):
if global_config.API_USING == "siliconflow":
self.client = OpenAI(
api_key=llm_config.SILICONFLOW_API_KEY,
base_url=llm_config.SILICONFLOW_BASE_URL
api_key=config.siliconflow_key,
base_url=config.siliconflow_base_url
)
elif self.config.API_USING == "deepseek":
elif global_config.API_USING == "deepseek":
self.client = OpenAI(
api_key=llm_config.DEEP_SEEK_API_KEY,
base_url=llm_config.DEEP_SEEK_BASE_URL
api_key=config.deep_seek_key,
base_url=config.deep_seek_base_url
)
self.db = Database.get_instance()
@@ -58,6 +52,7 @@ class LLMResponseGenerator:
else:
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
if self.current_model_type == 'r1':
model_response = await self._generate_r1_response(message)
@@ -96,8 +91,9 @@ class LLMResponseGenerator:
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
else:
relationship_value = 0.0
# 构建prompt
''' 构建prompt '''
prompt,prompt_check = prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
@@ -105,6 +101,7 @@ class LLMResponseGenerator:
group_id=message.group_id
)
# 设置默认参数
default_params = {
"model": model_name,
@@ -121,11 +118,28 @@ class LLMResponseGenerator:
"max_tokens": 2048,
"temperature": 0.7
}
default_params_check = {
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"messages": [{"role": "user", "content": prompt_check}],
"stream": False,
"max_tokens": 1024,
"temperature": 0.7
}
default_params_check = {
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"messages": [{"role": "user", "content": prompt_check}],
"stream": False,
"max_tokens": 1024,
"temperature": 0.7
}
# 更新参数
if model_params:
default_params.update(model_params)
def create_completion():
return self.client.chat.completions.create(**default_params)
@@ -135,6 +149,7 @@ class LLMResponseGenerator:
loop = asyncio.get_event_loop()
# 读空气模块
air = 0
reasoning_content_check=''
content_check=''
if global_config.enable_kuuki_read:
@@ -148,21 +163,26 @@ class LLMResponseGenerator:
content_check = response_check.choices[0].message.content
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
if 'yes' not in content_check.lower():
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': model_name,
'reasoning_check': reasoning_content_check,
'response_check': content_check,
'reasoning': "",
'response': "",
'prompt': prompt,
'prompt_check': prompt_check,
'model_params': default_params
})
return None
air = 1
#稀释读空气的判定
if air == 1 and random.random() < 0.3:
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': model_name,
'reasoning_check': reasoning_content_check,
'response_check': content_check,
'reasoning': "",
'response': "",
'prompt': prompt,
'prompt_check': prompt_check,
'model_params': default_params
})
return None
@@ -206,7 +226,7 @@ class LLMResponseGenerator:
async def _generate_r1_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-R1 模型生成回复"""
if self.config.API_USING == "deepseek":
if global_config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-reasoner",
@@ -221,7 +241,7 @@ class LLMResponseGenerator:
async def _generate_v3_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-V3 模型生成回复"""
if self.config.API_USING == "deepseek":
if global_config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-chat",
@@ -274,7 +294,7 @@ class LLMResponseGenerator:
messages = [{"role": "user", "content": prompt}]
loop = asyncio.get_event_loop()
if self.config.API_USING == "deepseek":
if global_config.API_USING == "deepseek":
model = "deepseek-chat"
else:
model = "Pro/deepseek-ai/DeepSeek-V3"
@@ -311,4 +331,4 @@ class LLMResponseGenerator:
return processed_response, emotion_tags
# 创建全局实例
llm_response = LLMResponseGenerator(global_config)
llm_response = LLMResponseGenerator()

View File

@@ -1,6 +1,5 @@
import time
import random
from dotenv import load_dotenv
from ..schedule.schedule_generator import bot_schedule
import os
from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text
@@ -10,11 +9,6 @@ from .topic_identifier import topic_identifier
from ..memory_system.memory import memory_graph
from random import choice
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
load_dotenv(os.path.join(root_dir, '.env'))
class PromptBuilder:
def __init__(self):
@@ -72,12 +66,15 @@ class PromptBuilder:
overlapping_second_layer.update(overlap)
# 合并所有需要的记忆
if all_first_layer_items:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
if overlapping_second_layer:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
# if all_first_layer_items:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
# if overlapping_second_layer:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
all_memories = all_first_layer_items + list(overlapping_second_layer)
# 使用集合去重
all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
if all_memories:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
if all_memories: # 只在列表非空时选择随机项
random_item = choice(all_memories)
@@ -186,15 +183,19 @@ class PromptBuilder:
# prompt += f"{activate_prompt}\n"
prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n"
prompt += f"{extra_info}\n"
'''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。不要太受现在做的事情影响,因为摸鱼是很正常的。"
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息,无论如何一定要回复,如果是刚刚理会过的人发送消息,且还在与那个人对话中的话一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < 4/6: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]},{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}曾经是一个学习地质的女大学生现在学习心理学和脑科学你会刷贴吧你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < 1: # 第二种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[1]},{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"

View File

@@ -1,14 +1,17 @@
from typing import Optional, Dict, List
from openai import OpenAI
from .message import Message
from .config import global_config, llm_config
import jieba
from nonebot import get_driver
driver = get_driver()
config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=llm_config.SILICONFLOW_API_KEY,
base_url=llm_config.SILICONFLOW_BASE_URL
api_key=config.siliconflow_key,
base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:

View File

@@ -4,11 +4,15 @@ from typing import List
from .message import Message
import requests
import numpy as np
from .config import llm_config, global_config
from .config import global_config
import re
from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
@@ -64,7 +68,7 @@ def get_embedding(text):
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}

View File

@@ -7,6 +7,10 @@ from ...common.database import Database
import zlib # 用于 CRC32
import base64
from .config import global_config
from nonebot import get_driver
driver = get_driver()
config = driver.config
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# 连接数据库
db = Database(
host=global_config.MONGODB_HOST,
port=global_config.MONGODB_PORT,
db_name=global_config.DATABASE_NAME,
username=global_config.MONGODB_USERNAME,
password=global_config.MONGODB_PASSWORD,
auth_source=global_config.MONGODB_AUTH_SOURCE
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片

View File

@@ -3,6 +3,10 @@ import sys
import numpy as np
import requests
import time
from nonebot import get_driver
driver = get_driver()
config = driver.config
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -13,9 +17,12 @@ from src.plugins.chat.config import llm_config
# 直接配置数据库连接信息
Database.initialize(
"127.0.0.1", # MongoDB 主机
27017, # MongoDB 端口
"MegBot" # 数据库名称
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
class KnowledgeLibrary:

View File

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os
import sys
import jieba
from llm_module import LLMModel
@@ -157,9 +158,12 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
"127.0.0.1",
27017,
"MegBot"
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
memory_graph = Memory_graph()
@@ -168,10 +172,12 @@ def main():
memory_graph.load_graph_from_db()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
# visualize_graph(memory_graph, color_by_memory=False)
visualize_graph_lite(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# visualize_graph(memory_graph, color_by_memory=True)
visualize_graph_lite(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
@@ -262,7 +268,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 2 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()

View File

@@ -1,19 +1,19 @@
import os
import requests
from dotenv import load_dotenv
from typing import Tuple, Union
import time
from nonebot import get_driver
# 加载环境变量
load_dotenv()
driver = get_driver()
config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""

View File

@@ -1,30 +1,20 @@
import os
import requests
from dotenv import load_dotenv
from typing import Tuple, Union
import time
from ..chat.config import BotConfig
from nonebot import get_driver
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
env_path = os.path.join(root_dir, 'config', '.env')
# 加载环境变量
print(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
print("成功加载环境变量配置")
else:
print(f"环境变量配置文件不存在: {env_path}")
driver = get_driver()
config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

View File

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os
import jieba
from .llm_module import LLMModel
import networkx as nx
@@ -197,8 +198,6 @@ class Hippocampus:
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
for i, input_text in enumerate(memory_sample, 1):
#加载进度可视化
progress = (i / len(memory_sample)) * 100
@@ -206,26 +205,25 @@ class Hippocampus:
filled_length = int(bar_length * i // len(memory_sample))
bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
# 延时防止访问超频
# time.sleep(60)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
if '[' in topic or topic=='':
continue
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
self.memory_graph.save_graph_to_db()
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
# 延时防止访问超频
# time.sleep(5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
else:
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
@@ -263,13 +261,19 @@ def topic_what(text, topic):
return prompt
from nonebot import get_driver
driver = get_driver()
config = driver.config
start_time = time.time()
Database.initialize(
global_config.MONGODB_HOST,
global_config.MONGODB_PORT,
global_config.DATABASE_NAME
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
#创建记忆图
memory_graph = Memory_graph()

View File

@@ -9,12 +9,42 @@ import datetime
import random
import time
import os
from dotenv import load_dotenv
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
return chat_text
return ''
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -103,7 +133,8 @@ class Memory_graph:
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
@@ -192,166 +223,80 @@ class Memory_graph:
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
# Database.initialize(
# global_config.MONGODB_HOST,
# global_config.MONGODB_PORT,
# global_config.DATABASE_NAME
# )
# memory_graph = Memory_graph()
# llm_model = LLMModel()
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
# memory_graph.load_graph_from_db()
def main():
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
env_path = os.path.join(root_dir, 'config', '.env')
# 加载环境变量
print(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
print("成功加载环境变量配置")
else:
print(f"环境变量配置文件不存在: {env_path}")
# 初始化数据库
Database.initialize(
"127.0.0.1",
27017,
"MegBot"
)
memory_graph = Memory_graph()
# 创建LLM模型实例
llm_model = LLMModel()
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
# 使用当前时间戳进行测试
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
chat_size =25
for _ in range(30): # 循环10次
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
chat_text.append(chat_) # 拼接所有text
# time.sleep(1)
for i, input_text in enumerate(chat_text, 1):
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
progress = (i / len(chat_text)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(chat_text))
bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 循环10次
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 循环10次
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 循环10次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
def build_memory(self,chat_size=12):
#最近消息获取频率
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(input_text)
first_memory = set()
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
# time.sleep(5)
#将记忆加入到图谱中
for topic, memory in first_memory:
# continue
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
memory_graph.connect_dot(split_topic, other_split_topic)
# memory_graph.store_memory()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
memory_graph.save_graph_to_db()
# memory_graph.load_graph_from_db()
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
#加载进度可视化
for i, input_text in enumerate(memory_sample, 1):
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# print(f"第{i}条消息: {input_text}")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
else:
print(f"空消息 跳过")
while True:
query = input("请输入问题:")
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = llm_model.generate_response(topic_prompt)
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
print(topic_num)
topic_prompt = find_topic(input_text, topic_num)
topic_response = llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def segment_text(text):
seg_text = list(jieba.cut(text))
@@ -372,18 +317,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 1 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(G.nodes()) # 获取图中实际的节点列表
nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = G.nodes[node].get('memory_items', [])
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
@@ -401,9 +365,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = G.degree(node)
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
@@ -414,8 +378,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, k=1, iterations=50)
nx.draw(G, pos,
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
@@ -427,6 +391,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
start_time = time.time()
# 创建记忆图
memory_graph = Memory_graph()
# 加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
hippocampus.build_memory(chat_size=25)
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# 交互式查询
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
while True:
query = input("请输入问题:")
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
if __name__ == "__main__":
main()

View File

@@ -2,29 +2,21 @@ import datetime
import os
from typing import List, Dict
from .schedule_llm_module import LLMModel
from dotenv import load_dotenv
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
from nonebot import get_driver
driver = get_driver()
config = driver.config
# import sys
# sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径
# from src.plugins.schedule.schedule_llm_module import LLMModel
# from src.common.database import Database # 使用正确的导入语法
# 获取当前文件的绝对路径
#TODO: 这个好几个地方用需要封装
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
load_dotenv(os.path.join(root_dir, '.env'))
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
class ScheduleGenerator:

View File

@@ -1,24 +1,24 @@
import os
import requests
from dotenv import load_dotenv
from typing import Tuple, Union
from nonebot import get_driver
# 加载环境变量
load_dotenv()
driver = get_driver()
config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
if api_using == "deepseek":
self.api_key = os.getenv("DEEPSEEK_API_KEY")
self.base_url = os.getenv("DEEPSEEK_BASE_URL")
self.api_key = config.deep_seek_key
self.base_url = config.deep_seek_base_url
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
self.model_name = model_name
else:
self.model_name = "deepseek-reasoner"
else:
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
self.model_name = model_name
self.params = kwargs