better: Hippocampus 2.0 upgrade - 30% progress

SengokuCola
2025-03-27 22:14:23 +08:00
parent 2812b0df3c
commit b474da3875
30 changed files with 433 additions and 2410 deletions

View File

@@ -9,7 +9,7 @@ from ..moods.moods import MoodManager  # import the mood manager
 from ..schedule.schedule_generator import bot_schedule
 from ..utils.statistic import LLMStatistics
 from .bot import chat_bot
-from .config import global_config
+from ..config.config import global_config
 from .emoji_manager import emoji_manager
 from .relationship_manager import relationship_manager
 from ..willing.willing_manager import willing_manager

View File

@@ -14,7 +14,7 @@ from nonebot.adapters.onebot.v11 import (
 from ..memory_system.Hippocampus import HippocampusManager
 from ..moods.moods import MoodManager  # import the mood manager
-from .config import global_config
+from ..config.config import global_config
 from .emoji_manager import emoji_manager  # import the emoji manager
 from .llm_generator import ResponseGenerator
 from .message import MessageSending, MessageRecv, MessageThinking, MessageSet

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
 from nonebot import get_driver
 from ..models.utils_model import LLM_request
-from .config import global_config
+from ..config.config import global_config
 from .mapper import emojimapper
 from .message_base import Seg
 from .utils_user import get_user_nickname, get_groupname

View File

@@ -12,7 +12,7 @@ import io
 from nonebot import get_driver
 from ...common.database import db
-from ..chat.config import global_config
+from ..config.config import global_config
 from ..chat.utils import get_embedding
 from ..chat.utils_image import ImageManager, image_path_to_base64
 from ..models.utils_model import LLM_request

View File

@@ -6,7 +6,7 @@ from nonebot import get_driver
 from ...common.database import db
 from ..models.utils_model import LLM_request
-from .config import global_config
+from ..config.config import global_config
 from .message import MessageRecv, MessageThinking, Message
 from .prompt_builder import prompt_builder
 from .utils import process_llm_response

View File

@@ -9,7 +9,7 @@ from .message_cq import MessageSendCQ
 from .message import MessageSending, MessageThinking, MessageSet
 from .storage import MessageStorage
-from .config import global_config
+from ..config.config import global_config
 from .utils import truncate_message, calculate_typing_time
 from src.common.logger import LogConfig, SENDER_STYLE_CONFIG

View File

@@ -6,7 +6,7 @@ from ...common.database import db
 from ..memory_system.Hippocampus import HippocampusManager
 from ..moods.moods import MoodManager
 from ..schedule.schedule_generator import bot_schedule
-from .config import global_config
+from ..config.config import global_config
 from .utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
 from .chat_stream import chat_manager
 from .relationship_manager import relationship_manager
@@ -82,7 +82,8 @@ class PromptBuilder:
         relevant_memories = await HippocampusManager.get_instance().get_memory_from_text(
             text=message_txt, num=3, max_depth=2, fast_retrieval=True
         )
-        memory_str = "\n".join(memory for topic, memories, _ in relevant_memories for memory in memories)
+        # memory_str = "\n".join(memory for topic, memories, _ in relevant_memories for memory in memories)
+        memory_str = ""
         print(f"memory_str: {memory_str}")
         if relevant_memories:

View File

@@ -3,7 +3,7 @@ from typing import List, Optional
 from nonebot import get_driver
 from ..models.utils_model import LLM_request
-from .config import global_config
+from ..config.config import global_config
 from src.common.logger import get_module_logger, LogConfig, TOPIC_STYLE_CONFIG
 # define the log config

View File

@@ -12,7 +12,7 @@ from src.common.logger import get_module_logger
 from ..models.utils_model import LLM_request
 from ..utils.typo_generator import ChineseTypoGenerator
-from .config import global_config
+from ..config.config import global_config
 from .message import MessageRecv, Message
 from .message_base import UserInfo
 from .chat_stream import ChatStream
@@ -62,60 +62,6 @@ async def get_embedding(text, request_type="embedding"):
     return await llm.get_embedding(text)
-
-
-def calculate_information_content(text):
-    """Calculate the information content (entropy) of a text"""
-    char_count = Counter(text)
-    total_chars = len(text)
-    entropy = 0
-    for count in char_count.values():
-        probability = count / total_chars
-        entropy -= probability * math.log2(probability)
-    return entropy
-
-
-def get_closest_chat_from_db(length: int, timestamp: str):
-    # print(f"Fetching chat records closest to the given timestamp, length: {length}, timestamp: {timestamp}")
-    # print(f"Current time: {timestamp}, converted: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
-    chat_records = []
-    closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
-    # print(f"Closest record: {closest_record}")
-    if closest_record:
-        closest_time = closest_record["time"]
-        chat_id = closest_record["chat_id"]  # grab the chat_id
-        # fetch `length` messages after that timestamp with the same chat_id
-        chat_records = list(
-            db.messages.find(
-                {
-                    "time": {"$gt": closest_time},
-                    "chat_id": chat_id,  # filter by chat_id
-                }
-            )
-            .sort("time", 1)
-            .limit(length)
-        )
-        # print(f"Fetched records: {chat_records}")
-        length = len(chat_records)
-        # print(f"Number of fetched records: {length}")
-        # convert the record format
-        formatted_records = []
-        for record in chat_records:
-            # compatibility shim for older data
-            formatted_records.append(
-                {
-                    "_id": record["_id"],
-                    "time": record["time"],
-                    "chat_id": record["chat_id"],
-                    "detailed_plain_text": record.get("detailed_plain_text", ""),  # text content
-                    "memorized_times": record.get("memorized_times", 0),  # memorization count
-                }
-            )
-        return formatted_records
-    return []
-
-
 async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
     """Fetch a group's most recent message records from the database

View File

@@ -9,7 +9,7 @@ import io
 from nonebot import get_driver
 from ...common.database import db
-from ..chat.config import global_config
+from ..config.config import global_config
 from ..models.utils_model import LLM_request
 from src.common.logger import get_module_logger

View File

@@ -1,4 +1,4 @@
-from .config import global_config
+from ..config.config import global_config
 from .relationship_manager import relationship_manager

View File

@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+
+
+class EnvConfig:
+    _instance = None
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super(EnvConfig, cls).__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        self._initialized = True
+        self.ROOT_DIR = Path(__file__).parent.parent.parent.parent
+        self.load_env()
+
+    def load_env(self):
+        env_file = self.ROOT_DIR / '.env'
+        if env_file.exists():
+            load_dotenv(env_file)
+
+        # load the matching env file based on the ENVIRONMENT variable
+        env_type = os.getenv('ENVIRONMENT', 'prod')
+        if env_type == 'dev':
+            env_file = self.ROOT_DIR / '.env.dev'
+        elif env_type == 'prod':
+            env_file = self.ROOT_DIR / '.env.prod'
+
+        if env_file.exists():
+            load_dotenv(env_file, override=True)
+
+    def get(self, key, default=None):
+        return os.getenv(key, default)
+
+    def get_all(self):
+        return dict(os.environ)
+
+    def __getattr__(self, name):
+        return self.get(name)
+
+
+# create the global instance
+env_config = EnvConfig()
+
+
+# export an environment variable getter
+def get_env(key, default=None):
+    return os.getenv(key, default)
+
+
+# export all environment variables
+def get_all_env():
+    return dict(os.environ)
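
A minimal usage sketch of the EnvConfig singleton added above. The module path and the DATABASE_URI key are assumptions for illustration; any key present in the loaded .env files behaves the same way:

from src.plugins.config.env_config import env_config, get_env  # hypothetical module path

value = env_config.get("ENVIRONMENT", "prod")  # explicit getter with a default
uri = env_config.DATABASE_URI                  # __getattr__ falls through to os.getenv; DATABASE_URI is a hypothetical key
same = get_env("ENVIRONMENT", "prod")          # module-level helper, the same os.getenv underneath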

View File

@@ -6,19 +6,77 @@ import time
 import re
 import jieba
 import networkx as nx
+import numpy as np
-# from nonebot import get_driver
+from collections import Counter
 from ...common.database import db
-# from ..chat.config import global_config
+from ...plugins.models.utils_model import LLM_request
-from ..chat.utils import (
-    calculate_information_content,
-    cosine_similarity,
-    get_closest_chat_from_db,
-)
-from ..models.utils_model import LLM_request
 from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
 from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler  # distribution generator
-from .config import MemoryConfig
+from .memory_config import MemoryConfig
+
+
+def get_closest_chat_from_db(length: int, timestamp: str):
+    # print(f"Fetching chat records closest to the given timestamp, length: {length}, timestamp: {timestamp}")
+    # print(f"Current time: {timestamp}, converted: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")
+    chat_records = []
+    closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
+    # print(f"Closest record: {closest_record}")
+    if closest_record:
+        closest_time = closest_record["time"]
+        chat_id = closest_record["chat_id"]  # grab the chat_id
+        # fetch `length` messages after that timestamp with the same chat_id
+        chat_records = list(
+            db.messages.find(
+                {
+                    "time": {"$gt": closest_time},
+                    "chat_id": chat_id,  # filter by chat_id
+                }
+            )
+            .sort("time", 1)
+            .limit(length)
+        )
+        # print(f"Fetched records: {chat_records}")
+        length = len(chat_records)
+        # print(f"Number of fetched records: {length}")
+        # convert the record format
+        formatted_records = []
+        for record in chat_records:
+            # compatibility shim for older data
+            formatted_records.append(
+                {
+                    "_id": record["_id"],
+                    "time": record["time"],
+                    "chat_id": record["chat_id"],
+                    "detailed_plain_text": record.get("detailed_plain_text", ""),  # text content
+                    "memorized_times": record.get("memorized_times", 0),  # memorization count
+                }
+            )
+        return formatted_records
+    return []
+
+
+def calculate_information_content(text):
+    """Calculate the information content (entropy) of a text"""
+    char_count = Counter(text)
+    total_chars = len(text)
+    entropy = 0
+    for count in char_count.values():
+        probability = count / total_chars
+        entropy -= probability * math.log2(probability)
+    return entropy
+
+
+def cosine_similarity(v1, v2):
+    """Compute cosine similarity"""
+    dot_product = np.dot(v1, v2)
+    norm1 = np.linalg.norm(v1)
+    norm2 = np.linalg.norm(v2)
+    if norm1 == 0 or norm2 == 0:
+        return 0
+    return dot_product / (norm1 * norm2)
+
+
 # define the log config
 memory_config = LogConfig(
@@ -393,6 +451,59 @@ class EntorhinalCortex:
         if need_update:
             logger.success("[Database] Filled in the missing time fields")

+    async def resync_memory_to_db(self):
+        """Clear the database and re-sync all memory data"""
+        start_time = time.time()
+        logger.info("[Database] Starting a full re-sync of all memory data...")
+
+        # clear the database
+        clear_start = time.time()
+        db.graph_data.nodes.delete_many({})
+        db.graph_data.edges.delete_many({})
+        clear_end = time.time()
+        logger.info(f"[Database] Clearing the database took: {clear_end - clear_start:.2f}")
+
+        # grab all nodes and edges
+        memory_nodes = list(self.memory_graph.G.nodes(data=True))
+        memory_edges = list(self.memory_graph.G.edges(data=True))
+
+        # rewrite the nodes
+        node_start = time.time()
+        for concept, data in memory_nodes:
+            memory_items = data.get("memory_items", [])
+            if not isinstance(memory_items, list):
+                memory_items = [memory_items] if memory_items else []
+            node_data = {
+                "concept": concept,
+                "memory_items": memory_items,
+                "hash": self.hippocampus.calculate_node_hash(concept, memory_items),
+                "created_time": data.get("created_time", datetime.datetime.now().timestamp()),
+                "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
+            }
+            db.graph_data.nodes.insert_one(node_data)
+        node_end = time.time()
+        logger.info(f"[Database] Writing {len(memory_nodes)} nodes took: {node_end - node_start:.2f}")
+
+        # rewrite the edges
+        edge_start = time.time()
+        for source, target, data in memory_edges:
+            edge_data = {
+                "source": source,
+                "target": target,
+                "strength": data.get("strength", 1),
+                "hash": self.hippocampus.calculate_edge_hash(source, target),
+                "created_time": data.get("created_time", datetime.datetime.now().timestamp()),
+                "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()),
+            }
+            db.graph_data.edges.insert_one(edge_data)
+        edge_end = time.time()
+        logger.info(f"[Database] Writing {len(memory_edges)} edges took: {edge_end - edge_start:.2f}")
+
+        end_time = time.time()
+        logger.success(f"[Database] Re-sync finished, total time: {end_time - start_time:.2f}")
+        logger.success(f"[Database] Synced {len(memory_nodes)} nodes and {len(memory_edges)} edges")
+
 # responsible for consolidating, forgetting, and merging memories
 class ParahippocampalGyrus:
     def __init__(self, hippocampus):
@@ -582,7 +693,8 @@ class ParahippocampalGyrus:
"秒---------------------" "秒---------------------"
) )
async def operation_forget_topic(self, percentage=0.1): async def operation_forget_topic(self, percentage=0.005):
start_time = time.time()
logger.info("[遗忘] 开始检查数据库...") logger.info("[遗忘] 开始检查数据库...")
all_nodes = list(self.memory_graph.G.nodes()) all_nodes = list(self.memory_graph.G.nodes())
@@ -598,12 +710,20 @@ class ParahippocampalGyrus:
         nodes_to_check = random.sample(all_nodes, check_nodes_count)
         edges_to_check = random.sample(all_edges, check_edges_count)

-        edge_changes = {"weakened": 0, "removed": 0}
-        node_changes = {"reduced": 0, "removed": 0}
+        # use lists to record the change details
+        edge_changes = {
+            "weakened": [],  # edges that were weakened
+            "removed": []    # edges that were removed
+        }
+        node_changes = {
+            "reduced": [],  # nodes whose memories were reduced
+            "removed": []   # nodes that were removed
+        }
         current_time = datetime.datetime.now().timestamp()

         logger.info("[Forget] Starting the connection check...")
+        edge_check_start = time.time()
         for source, target in edges_to_check:
             edge_data = self.memory_graph.G[source][target]
             last_modified = edge_data.get("last_modified")
@@ -614,15 +734,16 @@ class ParahippocampalGyrus:
                 if new_strength <= 0:
                     self.memory_graph.G.remove_edge(source, target)
-                    edge_changes["removed"] += 1
-                    logger.info(f"[Forget] Connection removed: {source} -> {target}")
+                    edge_changes["removed"].append(f"{source} -> {target}")
                 else:
                     edge_data["strength"] = new_strength
                     edge_data["last_modified"] = current_time
-                    edge_changes["weakened"] += 1
-                    logger.info(f"[Forget] Connection weakened: {source} -> {target} (strength: {current_strength} -> {new_strength})")
+                    edge_changes["weakened"].append(f"{source}-{target} (strength: {current_strength} -> {new_strength})")
+        edge_check_end = time.time()
+        logger.info(f"[Forget] Connection check took: {edge_check_end - edge_check_start:.2f}")

         logger.info("[Forget] Starting the node check...")
+        node_check_start = time.time()
         for node in nodes_to_check:
             node_data = self.memory_graph.G.nodes[node]
             last_modified = node_data.get("last_modified", current_time)
@@ -640,21 +761,40 @@ class ParahippocampalGyrus:
                 if memory_items:
                     self.memory_graph.G.nodes[node]["memory_items"] = memory_items
                     self.memory_graph.G.nodes[node]["last_modified"] = current_time
-                    node_changes["reduced"] += 1
-                    logger.info(f"[Forget] Memory reduced: {node} (count: {current_count} -> {len(memory_items)})")
+                    node_changes["reduced"].append(f"{node} (count: {current_count} -> {len(memory_items)})")
                 else:
                     self.memory_graph.G.remove_node(node)
-                    node_changes["removed"] += 1
-                    logger.info(f"[Forget] Node removed: {node}")
+                    node_changes["removed"].append(node)
+        node_check_end = time.time()
+        logger.info(f"[Forget] Node check took: {node_check_end - node_check_start:.2f}")

-        if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
-            await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
-            logger.info("[Forget] Statistics:")
-            logger.info(f"[Forget] Connection changes: {edge_changes['weakened']} weakened, {edge_changes['removed']} removed")
-            logger.info(f"[Forget] Node changes: {node_changes['reduced']} memories reduced, {node_changes['removed']} removed")
+        if any(edge_changes.values()) or any(node_changes.values()):
+            sync_start = time.time()
+            await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
+            sync_end = time.time()
+            logger.info(f"[Forget] Database sync took: {sync_end - sync_start:.2f}")
+
+            # summarize every change in one place
+            logger.info("[Forget] Forget-operation statistics:")
+            if edge_changes["weakened"]:
+                logger.info(f"[Forget] Weakened connections ({len(edge_changes['weakened'])}): {', '.join(edge_changes['weakened'])}")
+            if edge_changes["removed"]:
+                logger.info(f"[Forget] Removed connections ({len(edge_changes['removed'])}): {', '.join(edge_changes['removed'])}")
+            if node_changes["reduced"]:
+                logger.info(f"[Forget] Nodes with reduced memories ({len(node_changes['reduced'])}): {', '.join(node_changes['reduced'])}")
+            if node_changes["removed"]:
+                logger.info(f"[Forget] Removed nodes ({len(node_changes['removed'])}): {', '.join(node_changes['removed'])}")
         else:
             logger.info("[Forget] No nodes or connections met the forgetting criteria this round")

+        end_time = time.time()
+        logger.info(f"[Forget] Total time: {end_time - start_time:.2f}")

 # hippocampus
 class Hippocampus:
     def __init__(self):
@@ -696,7 +836,7 @@ class Hippocampus:
         prompt = (
             f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
             f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
-            f"如果找不出主题或者没有明显主题,返回<none>。"
+            f"如果确定找不出主题或者没有明显主题,返回<none>。"
         )
         return prompt
@@ -763,7 +903,7 @@ class Hippocampus:
         memories.sort(key=lambda x: x[2], reverse=True)
         return memories

-    async def get_memory_from_text(self, text: str, num: int = 5, max_depth: int = 2,
+    async def get_memory_from_text(self, text: str, num: int = 5, max_depth: int = 3,
                                    fast_retrieval: bool = False) -> list:
         """Extract keywords from the text and fetch the related memories.
@@ -795,7 +935,8 @@ class Hippocampus:
             keywords = keywords[:5]
         else:
             # use the LLM to extract keywords
-            topic_num = min(5, max(1, int(len(text) * 0.1)))  # scale the keyword count with the text length
+            topic_num = min(5, max(1, int(len(text) * 0.2)))  # scale the keyword count with the text length
+            print(f"number of keywords to extract: {topic_num}")
             topics_response = await self.llm_topic_judge.generate_response(
                 self.find_topic_llm(text, topic_num)
             )
@@ -811,12 +952,85 @@ class Hippocampus:
                 if keyword.strip()
             ]
+            logger.info(f"extracted keywords: {', '.join(keywords)}")

         # fetch memories for each keyword
         all_memories = []
+        keyword_connections = []  # connections found between keywords
+
+        # check for connections between the keywords
+        for i in range(len(keywords)):
+            for j in range(i + 1, len(keywords)):
+                keyword1, keyword2 = keywords[i], keywords[j]
+
+                # make sure both nodes exist in the graph
+                if keyword1 not in self.memory_graph.G or keyword2 not in self.memory_graph.G:
+                    logger.debug(f"keywords {keyword1} or {keyword2} are not in the memory graph")
+                    continue
+
+                # check for a direct connection
+                if self.memory_graph.G.has_edge(keyword1, keyword2):
+                    keyword_connections.append((keyword1, keyword2, 1))
+                    logger.info(f"found a direct connection: {keyword1} <-> {keyword2} (length: 1)")
+                    continue
+
+                # check for an indirect connection (through other nodes)
+                for depth in range(2, max_depth + 1):
+                    # use networkx's shortest_path_length to test for a path within the given length
+                    try:
+                        path_length = nx.shortest_path_length(self.memory_graph.G, keyword1, keyword2)
+                        if path_length <= depth:
+                            keyword_connections.append((keyword1, keyword2, path_length))
+                            logger.info(f"found an indirect connection: {keyword1} <-> {keyword2} (length: {path_length})")
+                            # log the connecting path
+                            path = nx.shortest_path(self.memory_graph.G, keyword1, keyword2)
+                            logger.info(f"connecting path: {' -> '.join(path)}")
+                            break
+                    except nx.NetworkXNoPath:
+                        continue
+
+        if not keyword_connections:
+            logger.info("no connections found between any keywords")
+
+        # track the keyword connections that were already processed
+        processed_connections = set()
+
         # fetch memories for each keyword
         for keyword in keywords:
+            if keyword in self.memory_graph.G:  # only handle keywords that exist in the graph
                 memories = self.get_memory_from_keyword(keyword, max_depth)
                 all_memories.extend(memories)
+
+        # handle the memories tied to keyword connections
+        for keyword1, keyword2, path_length in keyword_connections:
+            if (keyword1, keyword2) in processed_connections or (keyword2, keyword1) in processed_connections:
+                continue
+            processed_connections.add((keyword1, keyword2))
+
+            # walk every node on the connecting path
+            try:
+                path = nx.shortest_path(self.memory_graph.G, keyword1, keyword2)
+                for node in path:
+                    if node not in keywords:  # only handle the non-keyword nodes on the path
+                        node_data = self.memory_graph.G.nodes[node]
+                        memory_items = node_data.get("memory_items", [])
+                        if not isinstance(memory_items, list):
+                            memory_items = [memory_items] if memory_items else []
+
+                        # score the node against the input text
+                        node_words = set(jieba.cut(node))
+                        text_words = set(jieba.cut(text))
+                        all_words = node_words | text_words
+                        v1 = [1 if word in node_words else 0 for word in all_words]
+                        v2 = [1 if word in text_words else 0 for word in all_words]
+                        similarity = cosine_similarity(v1, v2)
+
+                        if similarity >= 0.3:  # similarity threshold
+                            all_memories.append((node, memory_items, similarity))
+            except nx.NetworkXNoPath:
+                continue

         # deduplicate (by topic)
         seen_topics = set()
         unique_memories = []
@@ -871,6 +1085,16 @@ class HippocampusManager:
logger.success(f"记忆构建分布: {config.memory_build_distribution}") logger.success(f"记忆构建分布: {config.memory_build_distribution}")
logger.success("--------------------------------") logger.success("--------------------------------")
# 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes())
edge_count = len(memory_graph.edges())
logger.success("--------------------------------")
logger.success("记忆图统计信息:")
logger.success(f"记忆节点数量: {node_count}")
logger.success(f"记忆连接数量: {edge_count}")
logger.success("--------------------------------")
return self._hippocampus return self._hippocampus
async def build_memory(self): async def build_memory(self):
@@ -879,7 +1103,7 @@ class HippocampusManager:
             raise RuntimeError("HippocampusManager is not initialized; call initialize first")
         return await self._hippocampus.parahippocampal_gyrus.operation_build_memory()

-    async def forget_memory(self, percentage: float = 0.1):
+    async def forget_memory(self, percentage: float = 0.005):
         """Public interface for forgetting memories"""
         if not self._initialized:
             raise RuntimeError("HippocampusManager is not initialized; call initialize first")
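
The retrieval upgrade above also admits memories from nodes that lie on a path between two extracted keywords, scoring each path node by a binary bag-of-words cosine similarity between the node name and the query text. A standalone sketch of that scoring step, assuming jieba and numpy are installed:

import jieba
import numpy as np

def bow_cosine(a: str, b: str) -> float:
    # binary bag-of-words over jieba tokens, mirroring the inline
    # computation in get_memory_from_text
    wa, wb = set(jieba.cut(a)), set(jieba.cut(b))
    vocab = list(wa | wb)
    v1 = np.array([1 if w in wa else 0 for w in vocab])
    v2 = np.array([1 if w in wb else 0 for w in vocab])
    n1, n2 = np.linalg.norm(v1), np.linalg.norm(v2)
    if n1 == 0 or n2 == 0:
        return 0.0
    return float(np.dot(v1, v2) / (n1 * n2))

# a path node is only kept when its similarity to the query text is >= 0.3
print(bow_cosine("千石可乐", "千石可乐在群里聊天"))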

View File

@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+import asyncio
+import time
+import sys
+import os
+
+# add the project root to the system path
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
+
+from src.plugins.memory_system.Hippocampus import HippocampusManager
+from src.plugins.config.config import global_config
+
+
+async def test_memory_system():
+    """Exercise the main features of the memory system"""
+    try:
+        # initialize the memory system
+        print("Initializing the memory system...")
+        hippocampus_manager = HippocampusManager.get_instance()
+        hippocampus_manager.initialize(global_config=global_config)
+        print("Memory system initialized")
+
+        # test memory building
+        # print("Testing memory building...")
+        # await hippocampus_manager.build_memory()
+        # print("Memory building finished")
+
+        # test memory retrieval
+        test_text = "千石可乐在群里聊天"
+        test_text = '''[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
+[03-24 10:39:37] 麦麦(ta的id:2814567326): [回复:变量] 变量就像今天计划总变
+[03-24 10:39:44] 状态异常(ta的id:535554838): 要把本地文件改成弹出来的路径吗
+[03-24 10:40:35] 状态异常(ta的id:535554838): [图片这张图片显示的是Windows系统的环境变量设置界面。界面左侧列出了多个环境变量的值包括Intel Dev Redist、Windows、Windows PowerShell、OpenSSH、NVIDIA Corporation的目录等。右侧有新建、编辑、浏览、删除、上移、下移和编辑文本等操作按钮。图片下方有一个错误提示框显示"Windows找不到文件'mongodb\\bin\\mongod.exe'。请确定文件名是否正确后,再试一次。"这意味着用户试图运行MongoDB的mongod.exe程序时系统找不到该文件。这可能是因为MongoDB的安装路径未正确添加到系统环境变量中或者文件路径有误。
+图片的含义可能是用户正在尝试设置MongoDB的环境变量以便在命令行或其他程序中使用MongoDB。如果用户正确设置了环境变量那么他们应该能够通过命令行或其他方式启动MongoDB服务。]
+[03-24 10:41:08] 一根猫(ta的id:108886006): [回复 麦麦 的消息: [回复某人消息] 改系统变量或者删库重配 ] [@麦麦] 我中途修改人格,需要重配吗
+[03-24 10:41:54] 麦麦(ta的id:2814567326): [回复:[回复 麦麦 的消息: [回复某人消息] 改系统变量或者删库重配 ] [@麦麦] 我中途修改人格,需要重配吗] 看情况
+[03-24 10:41:54] 麦麦(ta的id:2814567326): 难
+[03-24 10:41:54] 麦麦(ta的id:2814567326): 小改变量就行,大动骨安排重配像游戏副本南度改太大会崩
+[03-24 10:45:33] 霖泷(ta的id:1967075066): 话说现在思考高达一分钟
+[03-24 10:45:38] 霖泷(ta的id:1967075066): 是不是哪里出问题了
+[03-24 10:45:39] 艾卡(ta的id:1786525298): [表情包:这张表情包展示了一个动漫角色,她有着紫色的头发和大大的眼睛,表情显得有些困惑或不解。她的头上有一个问号,进一步强调了她的疑惑。整体情感表达的是困惑或不解。]
+[03-24 10:46:12] (ta的id:3229291803): [表情包:这张表情包显示了一只手正在做"点赞"的动作,通常表示赞同、喜欢或支持。这个表情包所表达的情感是积极的、赞同的或支持的。]
+[03-24 10:46:37] 星野風禾(ta的id:2890165435): 还能思考高达
+[03-24 10:46:39] 星野風禾(ta的id:2890165435): 什么知识库
+[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们'''
+        test_text = '''千石可乐niko分不清AI的陪伴和人类的陪伴,是这样吗?'''
+        print(f"Testing memory retrieval, test text: {test_text}\n")
+
+        memories = await hippocampus_manager.get_memory_from_text(
+            text=test_text,
+            num=3,
+            max_depth=3,
+            fast_retrieval=False
+        )
+
+        print("Retrieved memories:")
+        for topic, memory_items, similarity in memories:
+            print(f"topic: {topic}")
+            print(f"similarity: {similarity:.2f}")
+            for memory in memory_items:
+                print(f"- {memory}")
+
+        # test memory forgetting
+        # forget_start_time = time.time()
+        # # print("Testing memory forgetting...")
+        # await hippocampus_manager.forget_memory(percentage=0.005)
+        # # print("Memory forgetting finished")
+        # forget_end_time = time.time()
+        # print(f"Memory forgetting took: {forget_end_time - forget_start_time:.2f} seconds")
+
+        # fetch all nodes
+        # nodes = hippocampus_manager.get_all_node_names()
+        # print(f"Current node count in the memory system: {len(nodes)}")
+        # print("Node list:")
+        # for node in nodes:
+        #     print(f"- {node}")
+
+    except Exception as e:
+        print(f"Error during the test: {e}")
+        raise
+
+
+async def main():
+    """Entry point"""
+    try:
+        start_time = time.time()
+        await test_memory_system()
+        end_time = time.time()
+        print(f"Test finished, total time: {end_time - start_time:.2f}")
+    except Exception as e:
+        print(f"Program error: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
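
get_memory_from_text returns (topic, memory_items, similarity) tuples, as the loop above shows. For reference, a sketch of how those tuples could be flattened back into a prompt block, which is what the now commented-out memory_str join in prompt_builder used to do; the function name is illustrative:

def format_memories(memories: list) -> str:
    # memories: [(topic, [memory, ...], similarity), ...] as returned
    # by HippocampusManager.get_memory_from_text
    lines = []
    for _topic, memory_items, _similarity in memories:
        lines.extend(memory_items)
    return "\n".join(lines)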

View File

@@ -1,298 +0,0 @@
-# -*- coding: utf-8 -*-
-import os
-import sys
-import time
-import jieba
-import matplotlib.pyplot as plt
-import networkx as nx
-from dotenv import load_dotenv
-from loguru import logger
-
-# from src.common.logger import get_module_logger
-# logger = get_module_logger("draw_memory")
-
-# add the project root to the Python path
-root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
-sys.path.append(root_path)
-print(root_path)
-
-from src.common.database import db  # noqa: E402
-
-# load the .env.dev file
-env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), ".env.dev")
-load_dotenv(env_path)
-
-
-class Memory_graph:
-    def __init__(self):
-        self.G = nx.Graph()  # use networkx's graph structure
-
-    def connect_dot(self, concept1, concept2):
-        self.G.add_edge(concept1, concept2)
-
-    def add_dot(self, concept, memory):
-        if concept in self.G:
-            # if the node already exists, append the new memory to the existing list
-            if "memory_items" in self.G.nodes[concept]:
-                if not isinstance(self.G.nodes[concept]["memory_items"], list):
-                    # if it is not currently a list, convert it into one
-                    self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
-                self.G.nodes[concept]["memory_items"].append(memory)
-            else:
-                self.G.nodes[concept]["memory_items"] = [memory]
-        else:
-            # for a new node, create a fresh memory list
-            self.G.add_node(concept, memory_items=[memory])
-
-    def get_dot(self, concept):
-        # check whether the node exists in the graph
-        if concept in self.G:
-            # fetch the node data from the graph
-            node_data = self.G.nodes[concept]
-            # print(node_data)
-            # build a new Memory_dot object
-            return concept, node_data
-        return None
-
-    def get_related_item(self, topic, depth=1):
-        if topic not in self.G:
-            return [], []
-        first_layer_items = []
-        second_layer_items = []
-
-        # fetch the neighboring nodes
-        neighbors = list(self.G.neighbors(topic))
-        # print(f"first layer: {topic}")
-
-        # fetch the current node's memory items
-        node_data = self.get_dot(topic)
-        if node_data:
-            concept, data = node_data
-            if "memory_items" in data:
-                memory_items = data["memory_items"]
-                if isinstance(memory_items, list):
-                    first_layer_items.extend(memory_items)
-                else:
-                    first_layer_items.append(memory_items)
-
-        # only fetch second-layer memories when depth >= 2
-        if depth >= 2:
-            # fetch the neighbors' memory items
-            for neighbor in neighbors:
-                # print(f"second layer: {neighbor}")
-                node_data = self.get_dot(neighbor)
-                if node_data:
-                    concept, data = node_data
-                    if "memory_items" in data:
-                        memory_items = data["memory_items"]
-                        if isinstance(memory_items, list):
-                            second_layer_items.extend(memory_items)
-                        else:
-                            second_layer_items.append(memory_items)
-
-        return first_layer_items, second_layer_items
-
-    def store_memory(self):
-        for node in self.G.nodes():
-            dot_data = {"concept": node}
-            db.store_memory_dots.insert_one(dot_data)
-
-    @property
-    def dots(self):
-        # return the Memory_dot objects for every node
-        return [self.get_dot(node) for node in self.G.nodes()]
-
-    def get_random_chat_from_db(self, length: int, timestamp: str):
-        # fetch the chat records closest to the timestamp from the database
-        chat_text = ""
-        closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])  # debug output
-        logger.info(
-            f"closest message time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}"
-        )
-        if closest_record:
-            closest_time = closest_record["time"]
-            group_id = closest_record["group_id"]  # grab the group_id
-            # fetch `length` messages after that timestamp with the same group_id
-            chat_record = list(
-                db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
-            )
-            for record in chat_record:
-                time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(record["time"])))
-                try:
-                    displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
-                except (KeyError, TypeError):
-                    # handle missing keys or type errors
-                    displayname = record.get("user_nickname", "") or "用户" + str(record.get("user_id", "未知"))
-                chat_text += f"[{time_str}] {displayname}: {record['processed_plain_text']}\n"  # append sender and time info
-            return chat_text
-        return []  # return an empty list when no record is found
-
-    def save_graph_to_db(self):
-        # clear the existing graph data
-        db.graph_data.delete_many({})
-        # save the nodes
-        for node in self.G.nodes(data=True):
-            node_data = {
-                "concept": node[0],
-                "memory_items": node[1].get("memory_items", []),  # defaults to an empty list
-            }
-            db.graph_data.nodes.insert_one(node_data)
-        # save the edges
-        for edge in self.G.edges():
-            edge_data = {"source": edge[0], "target": edge[1]}
-            db.graph_data.edges.insert_one(edge_data)
-
-    def load_graph_from_db(self):
-        # clear the current graph
-        self.G.clear()
-        # load the nodes
-        nodes = db.graph_data.nodes.find()
-        for node in nodes:
-            memory_items = node.get("memory_items", [])
-            if not isinstance(memory_items, list):
-                memory_items = [memory_items] if memory_items else []
-            self.G.add_node(node["concept"], memory_items=memory_items)
-        # load the edges
-        edges = db.graph_data.edges.find()
-        for edge in edges:
-            self.G.add_edge(edge["source"], edge["target"])
-
-
-def main():
-    memory_graph = Memory_graph()
-    memory_graph.load_graph_from_db()
-
-    # show the optimized graph only once
-    visualize_graph_lite(memory_graph)
-
-    while True:
-        query = input("请输入新的查询概念(输入'退出'以结束):")
-        if query.lower() == "退出":
-            break
-        first_layer_items, second_layer_items = memory_graph.get_related_item(query)
-        if first_layer_items or second_layer_items:
-            logger.debug("first-layer memories:")
-            for item in first_layer_items:
-                logger.debug(item)
-            logger.debug("second-layer memories:")
-            for item in second_layer_items:
-                logger.debug(item)
-        else:
-            logger.debug("no related memories found.")
-
-
-def segment_text(text):
-    seg_text = list(jieba.cut(text))
-    return seg_text
-
-
-def find_topic(text, topic_num):
-    prompt = (
-        f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。"
-        f"只需要列举{topic_num}个话题就好,不要告诉我其他内容。"
-    )
-    return prompt
-
-
-def topic_what(text, topic):
-    prompt = (
-        f"这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。"
-        f"只输出这句话就好"
-    )
-    return prompt
-
-
-def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
-    # set a Chinese font
-    plt.rcParams["font.sans-serif"] = ["SimHei"]  # render Chinese labels correctly
-    plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly
-
-    G = memory_graph.G
-
-    # create a new graph for visualization
-    H = G.copy()
-
-    # remove nodes with a single memory and nodes with too few connections
-    nodes_to_remove = []
-    for node in H.nodes():
-        memory_items = H.nodes[node].get("memory_items", [])
-        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
-        degree = H.degree(node)
-        if memory_count < 3 or degree < 2:  # changed to "less than 2" instead of "less than or equal to 2"
-            nodes_to_remove.append(node)
-    H.remove_nodes_from(nodes_to_remove)
-
-    # bail out when no node survives the filter
-    if len(H.nodes()) == 0:
-        logger.debug("no nodes meet the filter criteria")
-        return
-
-    # save the graph locally
-    # nx.write_gml(H, "memory_graph.gml")  # save in GML format
-
-    # compute node sizes and colors
-    node_colors = []
-    node_sizes = []
-    nodes = list(H.nodes())
-
-    # track the max memory count and max degree for normalization
-    max_memories = 1
-    max_degree = 1
-    for node in nodes:
-        memory_items = H.nodes[node].get("memory_items", [])
-        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
-        degree = H.degree(node)
-        max_memories = max(max_memories, memory_count)
-        max_degree = max(max_degree, degree)
-
-    # compute each node's size and color
-    for node in nodes:
-        # node size (based on memory count)
-        memory_items = H.nodes[node].get("memory_items", [])
-        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
-        # use an exponential curve to make differences more visible
-        ratio = memory_count / max_memories
-        size = 500 + 5000 * (ratio)  # a 1.5-power curve would soften the differences
-        node_sizes.append(size)
-
-        # node color (based on degree)
-        degree = H.degree(node)
-        # the red component grows with the degree
-        r = (degree / max_degree) ** 0.3
-        red = min(1.0, r)
-        # the blue component grows as the degree shrinks
-        blue = max(0.0, 1 - red)
-        # blue = 1
-        color = (red, 0.1, blue)
-        node_colors.append(color)
-
-    # draw the figure
-    plt.figure(figsize=(12, 8))
-    pos = nx.spring_layout(H, k=1, iterations=50)  # a larger k spreads the nodes out more
-    nx.draw(
-        H,
-        pos,
-        with_labels=True,
-        node_color=node_colors,
-        node_size=node_sizes,
-        font_size=10,
-        font_family="SimHei",
-        font_weight="bold",
-        edge_color="gray",
-        width=0.5,
-        alpha=0.9,
-    )
-    title = "memory graph visualization - node size shows memory count, color shows degree"
-    plt.title(title, fontsize=16, fontfamily="SimHei")
-    plt.show()
-
-
-if __name__ == "__main__":
-    main()

File diff suppressed because it is too large

View File

@@ -1,992 +0,0 @@
# -*- coding: utf-8 -*-
import datetime
import math
import os
import random
import sys
import time
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
sys.path.insert(0, sys.path[0]+"/../")
from src.common.logger import get_module_logger
import jieba
# from chat.config import global_config
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import db # noqa E402
from src.plugins.memory_system.offline_llm import LLMModel # noqa E402
# 获取当前文件的目录
current_dir = Path(__file__).resolve().parent
# 获取项目根目录(上三层目录)
project_root = current_dir.parent.parent.parent
# env.dev文件路径
env_path = project_root / ".env.dev"
logger = get_module_logger("mem_manual_bd")
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
list: 消息记录字典列表,每个字典包含消息内容和时间信息
"""
chat_records = []
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
if closest_record and closest_record.get("memorized", 0) < 4:
closest_time = closest_record["time"]
group_id = closest_record["group_id"]
# 获取该时间戳之后的length条消息且groupid相同
records = list(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort("time", 1).limit(length)
)
# 更新每条消息的memorized属性
for record in records:
current_memorized = record.get("memorized", 0)
if current_memorized > 3:
print("消息已读取3次跳过")
return ""
# 更新memorized值
db.messages.update_one({"_id": record["_id"]}, {"$set": {"memorized": current_memorized + 1}})
# 添加到记录列表中
chat_records.append(
{"text": record["detailed_plain_text"], "time": record["time"], "group_id": record["group_id"]}
)
return chat_records
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]["strength"] = self.G[concept1][concept2].get("strength", 1) + 1
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if "memory_items" in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
else:
self.G.nodes[concept]["memory_items"] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
def get_memory_sample(self, chat_size=20, time_frequency=None):
"""获取记忆样本
Returns:
list: 消息记录列表,每个元素是一个消息记录字典列表
"""
if time_frequency is None:
time_frequency = {"near": 2, "mid": 4, "far": 3}
current_timestamp = datetime.datetime.now().timestamp()
chat_samples = []
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get("near")):
random_time = current_timestamp - random.randint(1, 3600 * 4)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get("mid")):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get("far")):
random_time = current_timestamp - random.randint(3600 * 24, 3600 * 24 * 7)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
return chat_samples
def calculate_topic_num(self, text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content) / 2)
print(
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
f"topic_num: {topic_num}"
)
return topic_num
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩消息记录为记忆
Args:
messages: 消息记录字典列表每个字典包含text和time字段
compress_rate: 压缩率
Returns:
set: (话题, 记忆) 元组集合
"""
if not messages:
return set()
# 合并消息文本,同时保留时间信息
input_text = ""
time_info = ""
# 计算最早和最晚时间
earliest_time = min(msg["time"] for msg in messages)
latest_time = max(msg["time"] for msg in messages)
earliest_dt = datetime.datetime.fromtimestamp(earliest_time)
latest_dt = datetime.datetime.fromtimestamp(latest_time)
# 如果是同一年
if earliest_dt.year == latest_dt.year:
earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%m-%d %H:%M:%S")
time_info += f"是在{earliest_dt.year}年,{earliest_str}{latest_str} 的对话:\n"
else:
earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S")
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages:
input_text += f"{msg['text']}\n"
print(input_text)
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
# 过滤topics
filter_keywords = ["表情包", "图片", "回复", "聊天记录"]
topics = [
topic.strip()
for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip()
]
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
# print(f"原始话题: {topics}")
print(f"过滤后话题: {filtered_topics}")
# 创建所有话题的请求任务
tasks = []
for topic in filtered_topics:
topic_what_prompt = self.topic_what(input_text, topic, time_info)
# 创建异步任务
task = self.llm_model_small.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# 等待所有任务完成
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
async def operation_build_memory(self, chat_size=12):
# 最近消息获取频率
time_frequency = {"near": 3, "mid": 8, "far": 5}
memory_samples = self.get_memory_sample(chat_size, time_frequency)
all_topics = [] # 用于存储所有话题
for i, messages in enumerate(memory_samples, 1):
# 加载进度可视化
all_topics = []
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
bar = "" * filled_length + "-" * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 生成压缩后记忆
compress_rate = 0.1
compressed_memory = await self.memory_compress(messages, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic)
# 连接相关话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_from_db(self):
"""
从数据库同步数据到内存中的图结构
将清空当前内存中的图,并从数据库重新加载所有节点和边
"""
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = db.graph_data.nodes.find()
for node in nodes:
concept = node["concept"]
memory_items = node.get("memory_items", [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = db.graph_data.edges.find()
for edge in edges:
source = edge["source"]
target = edge["target"]
strength = edge.get("strength", 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
logger.success("从数据库同步记忆图谱完成")
def calculate_node_hash(self, concept, memory_items):
"""
计算节点的特征值
"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 将记忆项排序以确保相同内容生成相同的哈希值
sorted_items = sorted(memory_items)
# 组合概念和记忆项生成特征值
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""
计算边的特征值
"""
# 对源节点和目标节点排序以确保相同的边生成相同的哈希值
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def sync_memory_to_db(self):
"""
检查并同步内存中的图结构与数据库
使用特征值(哈希值)快速判断是否需要更新
"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node["concept"]: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
# logger.info(f"添加新节点: {concept}")
node_data = {"concept": concept, "memory_items": memory_items, "hash": memory_hash}
db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get("hash", None)
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
# logger.info(f"更新节点内容: {concept}")
db.graph_data.nodes.update_one(
{"concept": concept}, {"$set": {"memory_items": memory_items, "hash": memory_hash}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node["concept"] not in memory_concepts:
# logger.info(f"删除多余节点: {db_node['concept']}")
db.graph_data.nodes.delete_one({"concept": db_node["concept"]})
# 处理边的信息
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge["source"], edge["target"])
db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "num": edge.get("num", 1)}
# 检查并更新边
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
if edge_key not in db_edge_dict:
# 添加新边
logger.info(f"添加新边: {source} - {target}")
edge_data = {"source": source, "target": target, "num": 1, "hash": edge_hash}
db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
logger.info(f"更新边: {source} - {target}")
db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": {"hash": edge_hash}})
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
logger.info(f"删除多余边: {source} - {target}")
db.graph_data.edges.delete_one({"source": source, "target": target})
logger.success("完成记忆图谱与数据库的差异同步")
def find_topic_llm(self, text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
)
return prompt
def topic_what(self, text, topic, time_info):
# 获取当前时间
prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
)
return prompt
def remove_node_from_db(self, topic):
"""
从数据库中删除指定节点及其相关的边
Args:
topic: 要删除的节点概念
"""
# 删除节点
db.graph_data.nodes.delete_one({"concept": topic})
# 删除所有涉及该节点的边
db.graph_data.edges.delete_many({"$or": [{"source": topic}, {"target": topic}]})
def forget_topic(self, topic):
"""
随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点
只在内存中的图上操作,不直接与数据库交互
Args:
topic: 要删除记忆的话题
Returns:
removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None
"""
if topic not in self.memory_graph.G:
return None
# 获取话题节点数据
node_data = self.memory_graph.G.nodes[topic]
# 如果节点存在memory_items
if "memory_items" in node_data:
memory_items = node_data["memory_items"]
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果有记忆项可以删除
if memory_items:
# 随机选择一个记忆项删除
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 更新节点的记忆项
if memory_items:
self.memory_graph.G.nodes[topic]["memory_items"] = memory_items
else:
# 如果没有记忆项了,删除整个节点
self.memory_graph.G.remove_node(topic)
return removed_item
return None
async def operation_forget_topic(self, percentage=0.1):
"""
随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get("strength", 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
logger.info("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(selected_memories, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]["memory_items"] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
async def _identify_topics(self, text: str) -> list:
"""从文本中识别可能的主题"""
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(text, 5))
topics = [
topic.strip()
for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip()
]
return topics
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
"""查找与给定主题相似的记忆主题"""
all_memory_topics = list(self.memory_graph.G.nodes())
all_similar_topics = []
for topic in topics:
if debug_info:
pass
topic_vector = text_to_vector(topic)
for memory_topic in all_memory_topics:
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= similarity_threshold:
all_similar_topics.append((memory_topic, similarity))
return all_similar_topics
def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list:
"""获取相似度最高的主题"""
seen_topics = set()
top_topics = []
for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True):
if topic not in seen_topics and len(top_topics) < max_topics:
seen_topics.add(topic)
top_topics.append((topic, score))
return top_topics
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度"""
logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
identified_topics = await self._identify_topics(text)
if not identified_topics:
return 0
all_similar_topics = self._find_similar_topics(
identified_topics, similarity_threshold=similarity_threshold, debug_info="记忆激活"
)
if not all_similar_topics:
return 0
top_topics = self._get_top_topics(all_similar_topics, max_topics)
if len(top_topics) == 1:
topic, score = top_topics[0]
memory_items = self.memory_graph.G.nodes[topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
activation = int(score * 50 * penalty)
print(
f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, "
f"激活值: {activation}"
)
return activation
matched_topics = set()
topic_similarities = {}
for memory_topic, _similarity in top_topics:
memory_items = self.memory_graph.G.nodes[memory_topic].get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
for input_topic in identified_topics:
topic_vector = text_to_vector(input_topic)
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
sim = cosine_similarity(v1, v2)
if sim >= similarity_threshold:
matched_topics.add(input_topic)
adjusted_sim = sim * penalty
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
print(
f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> "
f"{memory_topic}」(内容数: {content_count}, "
f"相似度: {adjusted_sim:.3f})"
)
topic_match = len(matched_topics) / len(identified_topics)
average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0
activation = int((topic_match + average_similarities) / 2 * 100)
print(
f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, "
f"激活值: {activation}"
)
return activation
async def get_relevant_memories(
self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5
) -> list:
"""根据输入文本获取相关的记忆内容"""
identified_topics = await self._identify_topics(text)
all_similar_topics = self._find_similar_topics(
identified_topics, similarity_threshold=similarity_threshold, debug_info="记忆检索"
)
relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
relevant_memories = []
for topic, score in relevant_topics:
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer:
if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num // 2)
for memory in first_layer:
relevant_memories.append({"topic": topic, "similarity": score, "content": memory})
relevant_memories.sort(key=lambda x: x["similarity"], reverse=True)
if len(relevant_memories) > max_memory_num:
relevant_memories = random.sample(relevant_memories, max_memory_num)
return relevant_memories
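
    # Usage sketch (illustrative query; run inside an event loop after syncing from the database):
    #     memories = await hippocampus.get_relevant_memories("今天天气怎么样")
    #     for m in memories:
    #         print(m["topic"], m["similarity"], m["content"])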


def segment_text(text):
    """Tokenize text with jieba."""
    seg_text = list(jieba.cut(text))
    return seg_text


def text_to_vector(text):
    """Convert text into a term-frequency (bag-of-words) vector."""
words = segment_text(text)
vector = {}
for word in words:
vector[word] = vector.get(word, 0) + 1
return vector


def cosine_similarity(v1, v2):
    """Compute the cosine similarity of two vectors."""
dot_product = sum(a * b for a, b in zip(v1, v2))
norm1 = math.sqrt(sum(a * a for a in v1))
norm2 = math.sqrt(sum(b * b for b in v2))
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
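

# Minimal sketch of the bag-of-words similarity pipeline above (exact token
# splits depend on jieba's dictionary and version, so the value is illustrative):
#     v_a, v_b = text_to_vector("今天天气很好"), text_to_vector("今天天气不错")
#     words = sorted(set(v_a) | set(v_b))
#     sim = cosine_similarity([v_a.get(w, 0) for w in words],
#                             [v_b.get(w, 0) for w in words])
#     # identical texts give 1.0; texts sharing no tokens give 0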


def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
    # Set a Chinese-capable font
    plt.rcParams["font.sans-serif"] = ["SimHei"]  # render Chinese labels correctly
    plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly
G = memory_graph.G
    # Create a new graph for visualization
H = G.copy()
    # Filter out nodes with fewer than 2 memory items
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
if memory_count < 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
    # If no nodes qualify, return early
if len(H.nodes()) == 0:
print("没有找到内容数量大于等于2的节点")
return
    # Compute node sizes and colors
node_colors = []
node_sizes = []
nodes = list(H.nodes())
    # Track the maximum memory count to normalize node sizes
max_memories = 1
for node in nodes:
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
max_memories = max(max_memories, memory_count)
    # Compute each node's size and color
    for node in nodes:
        # Node size is based on the number of memory items
memory_items = H.nodes[node].get("memory_items", [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        # Quadratic scaling makes size differences more visible
        ratio = memory_count / max_memories
        size = 400 + 2000 * (ratio**2)  # enlarged node size
node_sizes.append(size)
        # Node color is based on degree (connection count)
degree = H.degree(node)
if degree >= 30:
            node_colors.append((1.0, 0, 0))  # bright red (#FF0000)
else:
            # Map degree values 1-30 onto the 0-1 range
            color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
            # Blue-to-red gradient
red = min(0.9, color_ratio)
blue = max(0.0, 1.0 - color_ratio)
node_colors.append((red, 0, blue))
    # Draw the graph
    plt.figure(figsize=(16, 12))  # moderate figure size
    pos = nx.spring_layout(
        H,
        k=1,  # node repulsion strength
        iterations=100,  # extra iterations for a stabler layout
        scale=1.5,  # layout scale
        weight="strength",  # use the edges' strength attribute as weight
    )
    nx.draw(
        H,
        pos,
        with_labels=True,
        node_color=node_colors,
        node_size=node_sizes,
        font_size=12,  # keep the enlarged font size
        font_family="SimHei",
        font_weight="bold",
        edge_color="gray",
        width=1.5,  # uniform edge width
    )
title = """记忆图谱可视化仅显示内容≥2的节点
节点大小表示记忆数量
节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度
连接强度越大的节点距离越近"""
plt.title(title, fontsize=16, fontfamily="SimHei")
plt.show()
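
# Usage sketch (assumes a GUI-capable matplotlib backend and the SimHei font installed):
#     visualize_graph_lite(memory_graph)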


async def main():
    start_time = time.time()
    test_params = {
        "do_build_memory": False,
        "do_forget_topic": False,
        "do_visualize_graph": True,
        "do_query": False,
        "do_merge_memory": False,
    }
    # Create the memory graph
    memory_graph = Memory_graph()
    # Create the hippocampus
    hippocampus = Hippocampus(memory_graph)
    # Sync data from the database
    hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
    # Build memories
    if test_params["do_build_memory"]:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(
f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m"
)
    if test_params["do_forget_topic"]:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
    if test_params["do_merge_memory"]:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
    if test_params["do_visualize_graph"]:
        # Show the optimized graph
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
    if test_params["do_query"]:
        # Interactive query loop
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == "退出":
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
if first_layer:
print("\n直接相关的记忆:")
for item in first_layer:
print(f"- {item}")
if second_layer:
print("\n间接相关的记忆:")
for item in second_layer:
print(f"- {item}")
else:
print("未找到相关记忆。")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
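
# Running this test driver directly (module path assumed from the package layout):
#     python -m src.plugins.memory_system.Hippocampus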

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
logger = get_module_logger("offline_llm")

-class LLMModel:
+class LLM_request_off:
    def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
        self.model_name = model_name
        self.params = kwargs

View File

@@ -6,15 +6,13 @@ from typing import Tuple, Union
import aiohttp
from src.common.logger import get_module_logger
-from nonebot import get_driver
import base64
from PIL import Image
import io

from ...common.database import db
-from ..chat.config import global_config
+from ..config.config import global_config
+from ..config.config_env import env_config

-driver = get_driver()
-config = driver.config
logger = get_module_logger("model_utils")
@@ -34,8 +32,9 @@ class LLM_request:
    def __init__(self, model, **kwargs):
        # Convert the upper-case config keys to lower-case and fetch the actual values from the config
        try:
-           self.api_key = getattr(config, model["key"])
-           self.base_url = getattr(config, model["base_url"])
+           self.api_key = getattr(env_config, model["key"])
+           self.base_url = getattr(env_config, model["base_url"])
+           # print(self.api_key, self.base_url)
        except AttributeError as e:
            logger.error(f"原始 model dict 信息:{model}")
            logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")

View File

@@ -3,7 +3,7 @@ import threading
import time
from dataclasses import dataclass

-from ..chat.config import global_config
+from ..config.config import global_config
from src.common.logger import get_module_logger, LogConfig, MOOD_STYLE_CONFIG

mood_config = LogConfig(

View File

@@ -6,7 +6,7 @@ import os
import json
import threading
from src.common.logger import get_module_logger
-from src.plugins.chat.config import global_config
+from src.plugins.config.config import global_config

logger = get_module_logger("remote")

View File

@@ -10,7 +10,7 @@ sys.path.append(root_path)
from src.common.database import db  # noqa: E402
from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfig  # noqa: E402
from src.plugins.models.utils_model import LLM_request  # noqa: E402
-from src.plugins.chat.config import global_config  # noqa: E402
+from src.plugins.config.config import global_config  # noqa: E402

schedule_config = LogConfig(

View File

@@ -1,7 +1,7 @@
import asyncio
from typing import Dict
from ..chat.chat_stream import ChatStream
-from ..chat.config import global_config
+from ..config.config import global_config

class WillingManager:

View File

@@ -3,7 +3,7 @@ import random
import time
from typing import Dict
from src.common.logger import get_module_logger
-from ..chat.config import global_config
+from ..config.config import global_config
from ..chat.chat_stream import ChatStream

logger = get_module_logger("mode_dynamic")

View File

@@ -1,7 +1,7 @@
from typing import Optional
from src.common.logger import get_module_logger
-from ..chat.config import global_config
+from ..config.config import global_config
from .mode_classical import WillingManager as ClassicalWillingManager
from .mode_dynamic import WillingManager as DynamicWillingManager
from .mode_custom import WillingManager as CustomWillingManager

View File

@@ -2,7 +2,7 @@ from .outer_world import outer_world
import asyncio
from src.plugins.moods.moods import MoodManager
from src.plugins.models.utils_model import LLM_request
-from src.plugins.chat.config import global_config, BotConfig
+from src.plugins.config.config import global_config, BotConfig
import re
import time
from src.plugins.schedule.schedule_generator import bot_schedule

View File

@@ -1,7 +1,7 @@
from .current_mind import SubHeartflow
from src.plugins.moods.moods import MoodManager
from src.plugins.models.utils_model import LLM_request
-from src.plugins.chat.config import global_config, BotConfig
+from src.plugins.config.config import global_config, BotConfig
from src.plugins.schedule.schedule_generator import bot_schedule
import asyncio
from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONFIG  # noqa: E402

View File

@@ -2,7 +2,7 @@
import asyncio
from datetime import datetime
from src.plugins.models.utils_model import LLM_request
-from src.plugins.chat.config import global_config
+from src.plugins.config.config import global_config
from src.common.database import db

# Stores the rough content of a chat segment