From 9057f972f754e65ac58217b3216fbb937b44f4e8 Mon Sep 17 00:00:00 2001 From: tcmofashi <107829254+tcmofashi@users.noreply.github.com> Date: Wed, 5 Mar 2025 09:26:37 +0800 Subject: [PATCH 1/4] =?UTF-8?q?Fix:=20=E5=AE=8C=E7=BE=8E=E7=9A=84=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E5=8E=8B=E7=BC=A9=20@sourcery-ai=20(#54)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: logger三合一 * fix: emoji压缩功能正常使用 * fix: 提高压缩率 * fix: 0.8MB --- src/plugins/chat/bot.py | 8 +-- src/plugins/chat/config.py | 10 ++-- src/plugins/chat/emoji_manager.py | 96 +++++++++++++++++++++++++------ 3 files changed, 87 insertions(+), 27 deletions(-) diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index caf27c3f3..6b0e76db5 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -16,6 +16,7 @@ from .relationship_manager import relationship_manager from .willing_manager import willing_manager # 导入意愿管理器 from .utils import is_mentioned_bot_in_txt, calculate_typing_time from ..memory_system.memory import memory_graph +from loguru import logger class ChatBot: def __init__(self): @@ -61,8 +62,8 @@ class ChatBot: # 过滤词 for word in global_config.ban_words: if word in message.detailed_plain_text: - print(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}") - print(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered") + logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}") + logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered") return current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) @@ -77,8 +78,7 @@ class ChatBot: # topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text) # topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text) # topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text) - print(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") - + logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") all_num = 0 interested_num = 0 diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index a74a668a1..96c83dfe0 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -1,8 +1,6 @@ from dataclasses import dataclass, field from typing import Dict, Any, Optional, Set import os -from nonebot.log import logger, default_format -import logging import configparser import tomli import sys @@ -85,9 +83,9 @@ class BotConfig: personality_config=toml_dict['personality'] personality=personality_config.get('prompt_personality') if len(personality) >= 2: - print(f"载入自定义人格:{personality}") + logger.info(f"载入自定义人格:{personality}") config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY) - print(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}") + logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}") config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN) if "emoji" in toml_dict: @@ -141,10 +139,10 @@ class BotConfig: topic_config=toml_dict['topic'] if 'topic_extract' in topic_config: config.topic_extract=topic_config.get('topic_extract',config.topic_extract) - print(f"载入自定义主题提取为{config.topic_extract}") + logger.info(f"载入自定义主题提取为{config.topic_extract}") if config.topic_extract=='llm' and 'llm_topic' in topic_config: config.llm_topic_extract=topic_config['llm_topic'] - print(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}") + logger.info(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}") # 消息配置 if "message" in toml_dict: diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index cf0adff2e..2311b2459 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -12,6 +12,8 @@ import base64 import shutil import asyncio import time +from PIL import Image +import io from nonebot import get_driver from ..chat.config import global_config @@ -240,41 +242,102 @@ class EmojiManager: print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral") return "skip" # 默认标签 + async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]: + """压缩图片并返回base64编码 + Args: + image_path: 图片文件路径 + target_size: 目标文件大小(字节),默认0.8MB + Returns: + Optional[str]: 成功返回base64编码的图片数据,失败返回None + """ + try: + file_size = os.path.getsize(image_path) + if file_size <= target_size: + # 如果文件已经小于目标大小,直接读取并返回base64 + with open(image_path, 'rb') as f: + return base64.b64encode(f.read()).decode('utf-8') + + # 打开图片 + with Image.open(image_path) as img: + # 获取原始尺寸 + original_width, original_height = img.size + + # 计算缩放比例 + scale = min(1.0, (target_size / file_size) ** 0.5) + + # 计算新的尺寸 + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # 创建内存缓冲区 + output_buffer = io.BytesIO() + + # 如果是GIF,处理所有帧 + if getattr(img, "is_animated", False): + frames = [] + for frame_idx in range(img.n_frames): + img.seek(frame_idx) + new_frame = img.copy() + new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) + frames.append(new_frame) + + # 保存到缓冲区 + frames[0].save( + output_buffer, + format='GIF', + save_all=True, + append_images=frames[1:], + optimize=True, + duration=img.info.get('duration', 100), + loop=img.info.get('loop', 0) + ) + else: + # 处理静态图片 + resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # 保存到缓冲区,保持原始格式 + if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): + resized_img.save(output_buffer, format='PNG', optimize=True) + else: + resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) + + # 获取压缩后的数据并转换为base64 + compressed_data = output_buffer.getvalue() + print(f"\033[1;32m[成功]\033[0m 压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})") + + return base64.b64encode(compressed_data).decode('utf-8') + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}") + return None + async def scan_new_emojis(self): """扫描新的表情包""" try: emoji_dir = "data/emoji" os.makedirs(emoji_dir, exist_ok=True) - # 获取所有jpg文件 - files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')] + # 获取所有支持的图片文件 + files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))] for filename in files_to_process: image_path = os.path.join(emoji_dir, filename) - # 检查文件大小 - file_size = os.path.getsize(image_path) - if file_size > 5 * 1024 * 1024: # 5MB - print(f"\033[1;33m[警告]\033[0m 表情包文件过大 ({file_size/1024/1024:.2f}MB),删除: {filename}") - os.remove(image_path) - continue - # 检查是否已经注册过 existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) if existing_emoji: continue - - # 读取图片数据 - with open(image_path, 'rb') as f: - image_data = f.read() - # 将图片转换为base64 - image_base64 = base64.b64encode(image_data).decode('utf-8') + # 压缩图片并获取base64编码 + image_base64 = await self._compress_image(image_path) + if image_base64 is None: + os.remove(image_path) + continue # 获取表情包的情感标签 tag = await self._get_emoji_tag(image_base64) if not tag == "skip": - # 准备数据库记录 + # 准备数据库记录 emoji_record = { 'filename': filename, 'path': image_path, @@ -288,7 +351,6 @@ class EmojiManager: print(f"标签: {tag}") else: print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}") - except Exception as e: print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}") From 8e48e7201912f6706f295806adf50d05d738678c Mon Sep 17 00:00:00 2001 From: KawaiiYusora <40208202+SaigyoujiYusora@users.noreply.github.com> Date: Wed, 5 Mar 2025 09:27:52 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E4=B9=8B=E6=9F=A5=E7=BC=BA=E8=A1=A5?= =?UTF-8?q?=E6=BC=8F=20(#53)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/models/utils_model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 4741d2596..3ba873d74 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -60,8 +60,15 @@ class LLM_request: result = await response.json() if "choices" in result and len(result["choices"]) > 0: - content = result["choices"][0]["message"]["content"] - reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") + message = result["choices"][0]["message"] + content = message.get("content", "") + think_match = None + reasoning_content = message.get("reasoning_content", "") + if not reasoning_content: + think_match = re.search(r'(.*?)', content, re.DOTALL) + if think_match: + reasoning_content = think_match.group(1).strip() + content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() return content, reasoning_content return "没有返回结果", "" From 543504858d9cb20bd5d466ea2218b21710c7f2e9 Mon Sep 17 00:00:00 2001 From: NepPure Date: Wed, 5 Mar 2025 10:00:44 +0800 Subject: [PATCH 3/4] CI: Support debug branch and improve Docker tagging/caching (#55) * ci docker * ci * ci --- .github/workflows/docker-image.yml | 34 +++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 669fb8a1e..2a5f497fd 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -3,10 +3,11 @@ name: Docker Build and Push on: push: branches: - - main # 推送到main分支时触发 + - main + - debug # 新增 debug 分支触发 tags: - - 'v*' # 推送v开头的tag时触发(例如v1.0.0) - workflow_dispatch: # 允许手动触发 + - 'v*' + workflow_dispatch: jobs: build-and-push: @@ -24,15 +25,24 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Determine Image Tags + id: tags + run: | + if [[ "${{ github.ref }}" == refs/tags/* ]]; then + echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT + elif [ "${{ github.ref }}" == "refs/heads/main" ]; then + echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT + elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then + echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT + fi + - name: Build and Push Docker Image uses: docker/build-push-action@v5 with: - context: . # Docker构建上下文路径 - file: ./Dockerfile # Dockerfile路径 - platforms: linux/amd64,linux/arm64 # 支持arm架构 - tags: | - ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }} - ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest - push: true - cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest - cache-to: type=inline + context: . + file: ./Dockerfile + platforms: linux/amd64,linux/arm64 + tags: ${{ steps.tags.outputs.tags }} + push: true + cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache + cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max \ No newline at end of file From 3fec29d0456c58db5cc5617c64df4b5e735567b7 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 5 Mar 2025 16:48:53 +0800 Subject: [PATCH 4/4] =?UTF-8?q?v0.5.2=20=E8=AE=B0=E5=BF=86=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +- config/bot_config_template.toml | 3 +- src/plugins/chat/__init__.py | 14 +- src/plugins/chat/config.py | 4 +- src/plugins/chat/del.message_send_control.py | 251 ------ src/plugins/chat/del.message_stream.py | 271 ------ src/plugins/chat/del.message_visualizer.py | 138 --- src/plugins/chat/willing_manager.py | 8 +- src/plugins/memory_system/draw_memory.py | 148 +--- src/plugins/memory_system/memory.py | 374 ++++++-- src/plugins/memory_system/memory_make.py | 463 ---------- .../memory_system/memory_manual_build.py | 805 ++++++++++++++++++ ...m_module_memory_make.py => offline_llm.py} | 92 +- 13 files changed, 1235 insertions(+), 1346 deletions(-) delete mode 100644 src/plugins/chat/del.message_send_control.py delete mode 100644 src/plugins/chat/del.message_stream.py delete mode 100644 src/plugins/chat/del.message_visualizer.py delete mode 100644 src/plugins/memory_system/memory_make.py create mode 100644 src/plugins/memory_system/memory_manual_build.py rename src/plugins/memory_system/{llm_module_memory_make.py => offline_llm.py} (50%) diff --git a/README.md b/README.md index c09b33c49..7bfa465ae 100644 --- a/README.md +++ b/README.md @@ -42,22 +42,22 @@ ## 🎯 功能介绍 ### 💬 聊天功能 -- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言,目前有bug,所以现在只会检测主题,不会进行存储 +- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言 - 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置 -- 使用硅基流动的api进行回复生成,可随机使用R1,V3,R1-distill等模型,未来将加入官网api支持 +- 支持多模型,多厂商自定义配置 - 动态的prompt构建器,更拟人 - 支持图片,转发消息,回复消息的识别 - 错别字和多条回复功能:麦麦可以随机生成错别字,会多条发送回复以及对消息进行reply ### 😊 表情包功能 -- 支持根据发言内容发送对应情绪的表情包:未完善,可以用 -- 会自动偷群友的表情包(未完善,暂时禁用)目前有bug +- 支持根据发言内容发送对应情绪的表情包 +- 会自动偷群友的表情包 ### 📅 日程功能 - 麦麦会自动生成一天的日程,实现更拟人的回复 ### 🧠 记忆功能 -- 对聊天记录进行概括存储,在需要时调用,没写完 +- 对聊天记录进行概括存储,在需要时调用,待完善 ### 📚 知识库功能 - 基于embedding模型的知识库,手动放入txt会自动识别,写完了,暂时禁用 diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index 5ad837f6d..28ffb0ce3 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -11,7 +11,7 @@ prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的 [message] min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息 -max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃 +max_context_size = 15 # 麦麦获得的上文数量 emoji_chance = 0.2 # 麦麦使用表情包的概率 ban_words = [ # "403","张三" @@ -31,6 +31,7 @@ model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概 [memory] build_memory_interval = 300 # 记忆构建间隔 单位秒 +forget_memory_interval = 300 # 记忆遗忘间隔 单位秒 [others] enable_advance_output = true # 是否启用高级输出 diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index ac04866a5..66824d986 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -98,7 +98,19 @@ async def monitor_relationships(): async def build_memory_task(): """每30秒执行一次记忆构建""" print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...") - await hippocampus.build_memory(chat_size=30) + await hippocampus.operation_build_memory(chat_size=30) print("\033[1;32m[记忆构建]\033[0m 记忆构建完成") +@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") +async def forget_memory_task(): + """每30秒执行一次记忆构建""" + print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") + await hippocampus.operation_forget_topic(percentage=0.1) + print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") +@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="build_memory") +async def build_memory_task(): + """每30秒执行一次记忆构建""" + print("\033[1;32m[记忆整合]\033[0m 开始整合") + await hippocampus.operation_merge_memory(percentage=0.1) + print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 96c83dfe0..298683054 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -27,6 +27,7 @@ class BotConfig: ban_user_id = set() build_memory_interval: int = 60 # 记忆构建间隔(秒) + forget_memory_interval: int = 300 # 记忆遗忘间隔(秒) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) @@ -155,6 +156,7 @@ class BotConfig: if "memory" in toml_dict: memory_config = toml_dict["memory"] config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval) + config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval) # 群组配置 if "groups" in toml_dict: @@ -188,6 +190,6 @@ global_config = BotConfig.load_config(config_path=bot_config_path) if not global_config.enable_advance_output: - # logger.remove() + logger.remove() pass diff --git a/src/plugins/chat/del.message_send_control.py b/src/plugins/chat/del.message_send_control.py deleted file mode 100644 index 30ade9cd4..000000000 --- a/src/plugins/chat/del.message_send_control.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Union, List, Optional, Deque, Dict -from nonebot.adapters.onebot.v11 import Bot, MessageSegment -import asyncio -import random -import os -from .message import Message, Message_Thinking, MessageSet -from .cq_code import CQCode -from collections import deque -import time -from .storage import MessageStorage -from .config import global_config -from .cq_code import cq_code_tool - -if os.name == "nt": - from .message_visualizer import message_visualizer - - - -class SendTemp: - """单个群组的临时消息队列管理器""" - def __init__(self, group_id: int, max_size: int = 100): - self.group_id = group_id - self.max_size = max_size - self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size) - self.last_send_time = 0 - - def add(self, message: Message) -> None: - """按时间顺序添加消息到队列""" - if not self.messages: - self.messages.append(message) - return - - # 按时间顺序插入 - if message.time >= self.messages[-1].time: - self.messages.append(message) - return - - # 使用二分查找找到合适的插入位置 - messages_list = list(self.messages) - left, right = 0, len(messages_list) - - while left < right: - mid = (left + right) // 2 - if messages_list[mid].time < message.time: - left = mid + 1 - else: - right = mid - - # 重建消息队列,保持时间顺序 - new_messages = deque(maxlen=self.max_size) - new_messages.extend(messages_list[:left]) - new_messages.append(message) - new_messages.extend(messages_list[left:]) - self.messages = new_messages - def get_earliest_message(self) -> Optional[Message]: - """获取时间最早的消息""" - message = self.messages.popleft() if self.messages else None - return message - - def clear(self) -> None: - """清空队列""" - self.messages.clear() - - def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]: - """获取所有待发送的消息""" - if group_id is None: - return list(self.messages) - return [msg for msg in self.messages if msg.group_id == group_id] - - def peek_next(self) -> Optional[Union[Message, Message_Thinking]]: - """查看下一条要发送的消息(不移除)""" - return self.messages[0] if self.messages else None - - def has_messages(self) -> bool: - """检查是否有待发送的消息""" - return bool(self.messages) - - def count(self, group_id: Optional[int] = None) -> int: - """获取待发送消息数量""" - if group_id is None: - return len(self.messages) - return len([msg for msg in self.messages if msg.group_id == group_id]) - - def get_last_send_time(self) -> float: - """获取最后一次发送时间""" - return self.last_send_time - - def update_send_time(self): - """更新最后发送时间""" - self.last_send_time = time.time() - -class SendTempContainer: - """管理所有群组的消息缓存容器""" - def __init__(self): - self.temp_queues: Dict[int, SendTemp] = {} - - def get_queue(self, group_id: int) -> SendTemp: - """获取或创建群组的消息队列""" - if group_id not in self.temp_queues: - self.temp_queues[group_id] = SendTemp(group_id) - return self.temp_queues[group_id] - - def add_message(self, message: Message) -> None: - """添加消息到对应群组的队列""" - queue = self.get_queue(message.group_id) - queue.add(message) - - def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]: - """获取指定群组的所有待发送消息""" - queue = self.get_queue(group_id) - return queue.get_all() - - def has_messages(self, group_id: int) -> bool: - """检查指定群组是否有待发送消息""" - queue = self.get_queue(group_id) - return queue.has_messages() - - def get_all_groups(self) -> List[int]: - """获取所有有待发送消息的群组ID""" - return list(self.temp_queues.keys()) - - def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool: - queue = self.get_queue(message_obj.group_id) - # 使用列表解析找到匹配的消息索引 - matching_indices = [ - i for i, msg in enumerate(queue.messages) - if msg.message_id == message_obj.message_id - ] - - if not matching_indices: - return False - - index = matching_indices[0] # 获取第一个匹配的索引 - - # 将消息转换为列表以便修改 - messages = list(queue.messages) - - # 根据消息类型处理 - if isinstance(message_obj, MessageSet): - messages.pop(index) - # 在原位置插入新消息组 - for i, single_message in enumerate(message_obj.messages): - messages.insert(index + i, single_message) - # print(f"\033[1;34m[调试]\033[0m 添加消息组中的第{i+1}条消息: {single_message}") - else: - # 直接替换原消息 - messages[index] = message_obj - # print(f"\033[1;34m[调试]\033[0m 已更新消息: {message_obj}") - - # 重建队列 - queue.messages.clear() - for msg in messages: - queue.messages.append(msg) - - return True - - -class MessageSendControl: - """消息发送控制器""" - def __init__(self): - self.typing_speed = (0.1, 0.3) # 每个字符的打字时间范围(秒) - self.message_interval = (0.5, 1) # 多条消息间的间隔时间范围(秒) - self.max_retry = 3 # 最大重试次数 - self.send_temp_container = SendTempContainer() - self._running = True - self._paused = False - self._current_bot = None - self.storage = MessageStorage() # 添加存储实例 - try: - message_visualizer.start() - except(NameError): - pass - - async def process_group_messages(self, group_id: int): - queue = self.send_temp_container.get_queue(group_id) - if queue.has_messages(): - message = queue.peek_next() - # 处理消息的逻辑 - if isinstance(message, Message_Thinking): - message.update_thinking_time() - thinking_time = message.thinking_time - if message.interupt: - print(f"\033[1;34m[调试]\033[0m 思考不打算回复,移除") - queue.get_earliest_message() - return - elif thinking_time < 90: # 最少思考2秒 - if int(thinking_time) % 15 == 0: - print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{thinking_time:.1f}秒") - return - else: - print(f"\033[1;34m[调试]\033[0m 思考消息超时,移除") - queue.get_earliest_message() # 移除超时的思考消息 - return - elif isinstance(message, Message): - message = queue.get_earliest_message() - if message and message.processed_plain_text: - print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}") - cost_time = round(time.time(), 2) - message.time - if cost_time > 40: - message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text - cur_time = time.time() - await self._current_bot.send_group_msg( - group_id=group_id, - message=str(message.processed_plain_text), - auto_escape=False - ) - cost_time = round(time.time(), 2) - cur_time - print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒") - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) - print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}") - - if message.is_emoji: - message.processed_plain_text = "[表情包]" - await self.storage.store_message(message, None) - else: - await self.storage.store_message(message, None) - - - - queue.update_send_time() - if queue.has_messages(): - await asyncio.sleep( - random.uniform( - self.message_interval[0], - self.message_interval[1] - ) - ) - - async def start_processor(self, bot: Bot): - """启动消息处理器""" - self._current_bot = bot - - while self._running: - await asyncio.sleep(1.5) - tasks = [] - for group_id in self.send_temp_container.get_all_groups(): - tasks.append(self.process_group_messages(group_id)) - - # 并行处理所有群组的消息 - await asyncio.gather(*tasks) - try: - message_visualizer.update_content(self.send_temp_container) - except(NameError): - pass - - def set_typing_speed(self, min_speed: float, max_speed: float): - """设置打字速度范围""" - self.typing_speed = (min_speed, max_speed) - -# 创建全局实例 -message_sender_control = MessageSendControl() diff --git a/src/plugins/chat/del.message_stream.py b/src/plugins/chat/del.message_stream.py deleted file mode 100644 index 07809caa7..000000000 --- a/src/plugins/chat/del.message_stream.py +++ /dev/null @@ -1,271 +0,0 @@ -from typing import List, Optional, Dict -from .message import Message -import time -from collections import deque -from datetime import datetime, timedelta -import os -import json -import asyncio - -class MessageStream: - """单个群组的消息流容器""" - def __init__(self, group_id: int, max_size: int = 1000): - self.group_id = group_id - self.messages = deque(maxlen=max_size) - self.max_size = max_size - self.last_save_time = time.time() - - # 确保日志目录存在 - self.log_dir = os.path.join("log", str(self.group_id)) - os.makedirs(self.log_dir, exist_ok=True) - - # 启动自动保存任务 - asyncio.create_task(self._auto_save()) - - async def _auto_save(self): - """每30秒自动保存一次消息记录""" - while True: - await asyncio.sleep(30) # 等待30秒 - await self.save_to_log() - - async def save_to_log(self): - """将消息保存到日志文件""" - try: - current_time = time.time() - # 只有有新消息时才保存 - if not self.messages or self.last_save_time == current_time: - return - - # 生成日志文件名 (使用当前日期) - date_str = time.strftime("%Y-%m-%d", time.localtime(current_time)) - log_file = os.path.join(self.log_dir, f"chat_{date_str}.log") - - # 获取需要保存的新消息 - new_messages = [ - msg for msg in self.messages - if msg.time > self.last_save_time - ] - - if not new_messages: - return - - # 将消息转换为可序列化的格式 - message_logs = [] - for msg in new_messages: - message_logs.append({ - "time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)), - "user_id": msg.user_id, - "user_nickname": msg.user_nickname, - "user_cardname": msg.user_cardname, - "message_id": msg.message_id, - "raw_message": msg.raw_message, - "processed_text": msg.processed_plain_text - }) - - # 追加写入日志文件 - with open(log_file, "a", encoding="utf-8") as f: - for log in message_logs: - f.write(json.dumps(log, ensure_ascii=False) + "\n") - - self.last_save_time = current_time - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}") - - def add_message(self, message: Message) -> None: - """按时间顺序添加新消息到队列 - - 使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。 - - Args: - message: Message对象,要添加的新消息 - """ - - # 空队列或消息应该添加到末尾的情况 - if (not self.messages or - message.time >= self.messages[-1].time): - self.messages.append(message) - return - - # 消息应该添加到开头的情况 - if message.time <= self.messages[0].time: - self.messages.appendleft(message) - return - - # 使用二分查找在现有队列中找到合适的插入位置 - left, right = 0, len(self.messages) - 1 - while left <= right: - mid = (left + right) // 2 - if self.messages[mid].time < message.time: - left = mid + 1 - else: - right = mid - 1 - - temp = list(self.messages) - temp.insert(left, message) - - # 如果超出最大长度,移除多余的消息 - if len(temp) > self.max_size: - temp = temp[-self.max_size:] - - # 重建队列 - self.messages = deque(temp, maxlen=self.max_size) - - async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]: - """从数据库中获取最近的消息记录 - - Args: - count: 需要获取的消息数量 - - Returns: - List[Message]: 最近的消息列表 - """ - try: - from ...common.database import Database - db = Database.get_instance() - - # 从数据库中查询最近的消息 - recent_messages = list(db.db.messages.find( - {"group_id": self.group_id}, - # { - # "time": 1, - # "user_id": 1, - # "user_nickname": 1, - # # "user_cardname": 1, - # "message_id": 1, - # "raw_message": 1, - # "processed_text": 1 - # } - ).sort("time", -1).limit(count)) - - if not recent_messages: - return [] - - # 转换为 Message 对象 - from .message import Message - messages = [] - for msg_data in recent_messages: - try: - msg = Message( - time=msg_data["time"], - user_id=msg_data["user_id"], - user_nickname=msg_data.get("user_nickname", ""), - user_cardname=msg_data.get("user_cardname", ""), - message_id=msg_data["message_id"], - raw_message=msg_data["raw_message"], - processed_plain_text=msg_data.get("processed_text", ""), - group_id=self.group_id - ) - messages.append(msg) - except KeyError: - print("[WARNING] 数据库中存在无效的消息") - continue - - return list(reversed(messages)) # 返回按时间正序的消息 - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}") - return [] - - def get_recent_messages(self, count: int = 10) -> List[Message]: - """获取最近的n条消息(从内存队列)""" - print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录") - return list(self.messages)[-count:] - - def get_messages_in_timerange(self, - start_time: Optional[float] = None, - end_time: Optional[float] = None) -> List[Message]: - """获取时间范围内的消息""" - if start_time is None: - start_time = time.time() - 3600 - if end_time is None: - end_time = time.time() - - return [ - msg for msg in self.messages - if start_time <= msg.time <= end_time - ] - - def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]: - """获取特定用户的最近消息""" - user_messages = [msg for msg in self.messages if msg.user_id == user_id] - return user_messages[-count:] - - def clear_old_messages(self, hours: int = 24) -> None: - """清理旧消息""" - cutoff_time = time.time() - (hours * 3600) - self.messages = deque( - [msg for msg in self.messages if msg.time > cutoff_time], - maxlen=self.max_size - ) - -class MessageStreamContainer: - """管理所有群组的消息流容器""" - def __init__(self, max_size: int = 1000): - self.streams: Dict[int, MessageStream] = {} - self.max_size = max_size - - async def save_all_logs(self): - """保存所有群组的消息日志""" - for stream in self.streams.values(): - await stream.save_to_log() - - def add_message(self, message: Message) -> None: - """添加消息到对应群组的消息流""" - if not message.group_id: - return - - if message.group_id not in self.streams: - self.streams[message.group_id] = MessageStream(message.group_id, self.max_size) - - self.streams[message.group_id].add_message(message) - - def get_stream(self, group_id: int) -> Optional[MessageStream]: - """获取特定群组的消息流""" - return self.streams.get(group_id) - - def get_all_streams(self) -> Dict[int, MessageStream]: - """获取所有群组的消息流""" - return self.streams - - def clear_old_messages(self, hours: int = 24) -> None: - """清理所有群组的旧消息""" - for stream in self.streams.values(): - stream.clear_old_messages(hours) - - def get_group_stats(self, group_id: int) -> Dict: - """获取群组的消息统计信息""" - stream = self.streams.get(group_id) - if not stream: - return { - "total_messages": 0, - "unique_users": 0, - "active_hours": [], - "most_active_user": None - } - - messages = stream.messages - user_counts = {} - hour_counts = {} - - for msg in messages: - user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1 - hour = datetime.fromtimestamp(msg.time).hour - hour_counts[hour] = hour_counts.get(hour, 0) + 1 - - most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None - active_hours = sorted( - hour_counts.items(), - key=lambda x: x[1], - reverse=True - )[:5] - - return { - "total_messages": len(messages), - "unique_users": len(user_counts), - "active_hours": active_hours, - "most_active_user": most_active_user - } - -# 创建全局实例 -message_stream_container = MessageStreamContainer() diff --git a/src/plugins/chat/del.message_visualizer.py b/src/plugins/chat/del.message_visualizer.py deleted file mode 100644 index 0469af8f6..000000000 --- a/src/plugins/chat/del.message_visualizer.py +++ /dev/null @@ -1,138 +0,0 @@ -import subprocess -import threading -import queue -import os -import time -from typing import Dict -from .message import Message_Thinking - -class MessageVisualizer: - def __init__(self): - self.process = None - self.message_queue = queue.Queue() - self.is_running = False - self.content_file = "message_queue_content.txt" - - def start(self): - if self.process is None: - # 创建用于显示的批处理文件 - with open("message_queue_window.bat", "w", encoding="utf-8") as f: - f.write('@echo off\n') - f.write('chcp 65001\n') # 设置UTF-8编码 - f.write('title Message Queue Visualizer\n') - f.write('echo Waiting for message queue updates...\n') - f.write(':loop\n') - f.write('if exist "queue_update.txt" (\n') - f.write(' type "queue_update.txt" > "message_queue_content.txt"\n') - f.write(' del "queue_update.txt"\n') - f.write(' cls\n') - f.write(' type "message_queue_content.txt"\n') - f.write(')\n') - f.write('timeout /t 1 /nobreak >nul\n') - f.write('goto loop\n') - - # 清空内容文件 - with open(self.content_file, "w", encoding="utf-8") as f: - f.write("") - - # 启动新窗口 - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - self.process = subprocess.Popen( - ['cmd', '/c', 'start', 'message_queue_window.bat'], - shell=True, - startupinfo=startupinfo - ) - self.is_running = True - - # 启动处理线程 - threading.Thread(target=self._process_messages, daemon=True).start() - - def _process_messages(self): - while self.is_running: - try: - # 获取新消息 - text = self.message_queue.get(timeout=1) - # 写入更新文件 - with open("queue_update.txt", "w", encoding="utf-8") as f: - f.write(text) - except queue.Empty: - continue - except Exception as e: - print(f"处理队列可视化内容时出错: {e}") - - def update_content(self, send_temp_container): - """更新显示内容""" - if not self.is_running: - return - - current_time = time.strftime("%Y-%m-%d %H:%M:%S") - display_text = f"Message Queue Status - {current_time}\n" - display_text += "=" * 50 + "\n\n" - - # 遍历所有群组的队列 - for group_id, queue in send_temp_container.temp_queues.items(): - display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n" - display_text += f"消息队列长度: {len(queue.messages)}\n" - display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n" - display_text += "\n消息队列内容:\n" - - # 显示队列中的消息 - if not queue.messages: - display_text += " [空队列]\n" - else: - for i, msg in enumerate(queue.messages): - msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time)) - display_text += f"\n--- 消息 {i+1} ---\n" - - if isinstance(msg, Message_Thinking): - display_text += f"类型: \033[1;33m思考中消息\033[0m\n" - display_text += f"时间: {msg_time}\n" - display_text += f"消息ID: {msg.message_id}\n" - display_text += f"群组: {msg.group_id}\n" - display_text += f"用户: {msg.user_nickname}({msg.user_id})\n" - display_text += f"内容: {msg.thinking_text}\n" - display_text += f"思考时间: {int(msg.thinking_time)}秒\n" - else: - display_text += f"类型: 普通消息\n" - display_text += f"时间: {msg_time}\n" - display_text += f"消息ID: {msg.message_id}\n" - display_text += f"群组: {msg.group_id}\n" - display_text += f"用户: {msg.user_nickname}({msg.user_id})\n" - if hasattr(msg, 'is_emoji') and msg.is_emoji: - display_text += f"内容: [表情包消息]\n" - else: - # 显示原始消息和处理后的消息 - display_text += f"原始内容: {msg.raw_message[:50]}...\n" - display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n" - - if msg.reply_message: - display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n" - - display_text += f"\n{'-' * 50}\n" - - # 添加统计信息 - display_text += "\n总体统计:\n" - display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n" - total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values()) - display_text += f"总消息数: {total_messages}\n" - thinking_messages = sum( - sum(1 for msg in q.messages if isinstance(msg, Message_Thinking)) - for q in send_temp_container.temp_queues.values() - ) - display_text += f"思考中消息数: {thinking_messages}\n" - - self.message_queue.put(display_text) - - def stop(self): - self.is_running = False - if self.process: - self.process.terminate() - self.process = None - # 清理文件 - for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]: - if os.path.exists(file): - os.remove(file) - -# 创建全局单例 -message_visualizer = MessageVisualizer() diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index f90889f77..ab8c5ee25 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -9,7 +9,7 @@ class WillingManager: async def _decay_reply_willing(self): """定期衰减回复意愿""" while True: - await asyncio.sleep(3) + await asyncio.sleep(5) for group_id in self.group_reply_willing: self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6) @@ -39,11 +39,11 @@ class WillingManager: if interested_rate > 0.65: print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") - current_willing += interested_rate-0.5 + current_willing += interested_rate-0.6 self.group_reply_willing[group_id] = min(current_willing, 3.0) - reply_probability = max((current_willing - 0.5) * 2, 0) + reply_probability = max((current_willing - 0.55) * 1.9, 0) if group_id not in config.talk_allowed_groups: current_willing = 0 reply_probability = 0 @@ -65,7 +65,7 @@ class WillingManager: """发送消息后提高群组的回复意愿""" current_willing = self.group_reply_willing.get(group_id, 0) if current_willing < 1: - self.group_reply_willing[group_id] = min(1, current_willing + 0.3) + self.group_reply_willing[group_id] = min(1, current_willing + 0.2) async def ensure_started(self): """确保衰减任务已启动""" diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index ddb11d574..fad3f5f30 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -22,63 +22,6 @@ from src.common.database import Database # 使用正确的导入语法 env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev') load_dotenv(env_path) -class LLMModel: - def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs): - self.model_name = model_name - self.params = kwargs - self.api_key = os.getenv("SILICONFLOW_KEY") - self.base_url = os.getenv("SILICONFLOW_BASE_URL") - - async def generate_response(self, prompt: str) -> Tuple[str, str]: - """根据输入的提示生成模型的响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - # 构建请求体 - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - "temperature": 0.5, - **self.params - } - - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - print(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - content = result["choices"][0]["message"]["content"] - reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") - await asyncio.sleep(wait_time) - else: - return f"请求失败: {str(e)}", "" - - return "达到最大重试次数,请求仍然失败", "" - class Memory_graph: def __init__(self): @@ -232,19 +175,10 @@ def main(): ) memory_graph = Memory_graph() - # 创建LLM模型实例 - memory_graph.load_graph_from_db() - # 展示两种不同的可视化方式 - print("\n按连接数量着色的图谱:") - # visualize_graph(memory_graph, color_by_memory=False) - visualize_graph_lite(memory_graph, color_by_memory=False) - print("\n按记忆数量着色的图谱:") - # visualize_graph(memory_graph, color_by_memory=True) - visualize_graph_lite(memory_graph, color_by_memory=True) - - # memory_graph.save_graph_to_db() + # 只显示一次优化后的图形 + visualize_graph_lite(memory_graph) while True: query = input("请输入新的查询概念(输入'退出'以结束):") @@ -327,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): nx.draw(G, pos, with_labels=True, node_color=node_colors, - node_size=2000, + node_size=200, font_size=10, font_family='SimHei', font_weight='bold') @@ -353,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal memory_items = H.nodes[node].get('memory_items', []) memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) degree = H.degree(node) - if memory_count <= 2 or degree <= 2: + if memory_count < 5 or degree < 2: # 改为小于2而不是小于等于2 nodes_to_remove.append(node) H.remove_nodes_from(nodes_to_remove) @@ -366,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal # 保存图到本地 nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式 - # 根据连接条数或记忆数量设置节点颜色 + # 计算节点大小和颜色 node_colors = [] - nodes = list(H.nodes()) # 获取图中实际的节点列表 + node_sizes = [] + nodes = list(H.nodes()) - if color_by_memory: - # 计算每个节点的记忆数量 - memory_counts = [] - for node in nodes: - memory_items = H.nodes[node].get('memory_items', []) - if isinstance(memory_items, list): - count = len(memory_items) - else: - count = 1 if memory_items else 0 - memory_counts.append(count) - max_memories = max(memory_counts) if memory_counts else 1 + # 获取最大记忆数和最大度数用于归一化 + max_memories = 1 + max_degree = 1 + for node in nodes: + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + degree = H.degree(node) + max_memories = max(max_memories, memory_count) + max_degree = max(max_degree, degree) + + # 计算每个节点的大小和颜色 + for node in nodes: + # 计算节点大小(基于记忆数量) + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + # 使用指数函数使变化更明显 + ratio = memory_count / max_memories + size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显 + node_sizes.append(size) - for count in memory_counts: - # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 - if max_memories > 0: - intensity = min(1.0, count / max_memories) - color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 - else: - color = (0, 0, 1) # 如果没有记忆,则为蓝色 - node_colors.append(color) - else: - # 使用原来的连接数量着色方案 - max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1 - for node in nodes: - degree = H.degree(node) - if max_degree > 0: - red = min(1.0, degree / max_degree) - blue = 1.0 - red - color = (red, 0, blue) - else: - color = (0, 0, 1) - node_colors.append(color) + # 计算节点颜色(基于连接数) + degree = H.degree(node) + # 红色分量随着度数增加而增加 + red = min(1.0, degree / max_degree) + # 蓝色分量随着度数减少而增加 + blue = 1.0 - red + color = (red, 0, blue) + node_colors.append(color) # 绘制图形 plt.figure(figsize=(12, 8)) - pos = nx.spring_layout(H, k=1, iterations=50) + pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开 nx.draw(H, pos, with_labels=True, node_color=node_colors, - node_size=2000, + node_size=node_sizes, font_size=10, font_family='SimHei', - font_weight='bold') + font_weight='bold', + edge_color='gray', + width=0.5, + alpha=0.7) - title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') + title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数' plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index e0095dada..9ad740844 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -17,7 +17,12 @@ class Memory_graph: self.db = Database.get_instance() def connect_dot(self, concept1, concept2): - self.G.add_edge(concept1, concept2) + # 如果边已存在,增加 strength + if self.G.has_edge(concept1, concept2): + self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1 + else: + # 如果是新边,初始化 strength 为 1 + self.G.add_edge(concept1, concept2, strength=1) def add_dot(self, concept, memory): if concept in self.G: @@ -38,9 +43,7 @@ class Memory_graph: if concept in self.G: # 从图中获取节点数据 node_data = self.G.nodes[concept] - # print(node_data) - # 创建新的Memory_dot对象 - return concept,node_data + return concept, node_data return None def get_related_item(self, topic, depth=1): @@ -52,7 +55,6 @@ class Memory_graph: # 获取相邻节点 neighbors = list(self.G.neighbors(topic)) - # print(f"第一层: {topic}") # 获取当前节点的记忆项 node_data = self.get_dot(topic) @@ -69,7 +71,6 @@ class Memory_graph: if depth >= 2: # 获取相邻节点的记忆项 for neighbor in neighbors: - # print(f"第二层: {neighbor}") node_data = self.get_dot(neighbor) if node_data: concept, data = node_data @@ -87,79 +88,38 @@ class Memory_graph: # 返回所有节点对应的 Memory_dot 对象 return [self.get_dot(node) for node in self.G.nodes()] - def save_graph_to_db(self): - # 保存节点 - for node in self.G.nodes(data=True): - concept = node[0] - memory_items = node[1].get('memory_items', []) + def forget_topic(self, topic): + """随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点""" + if topic not in self.G: + return None - # 查找是否存在同名节点 - existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept}) - if existing_node: - # 如果存在,合并memory_items并去重 - existing_items = existing_node.get('memory_items', []) - if not isinstance(existing_items, list): - existing_items = [existing_items] if existing_items else [] - - # 合并并去重 - all_items = list(set(existing_items + memory_items)) - - # 更新节点 - self.db.db.graph_data.nodes.update_one( - {'concept': concept}, - {'$set': {'memory_items': all_items}} - ) - else: - # 如果不存在,创建新节点 - node_data = { - 'concept': concept, - 'memory_items': memory_items - } - self.db.db.graph_data.nodes.insert_one(node_data) + # 获取话题节点数据 + node_data = self.G.nodes[topic] - # 保存边 - for edge in self.G.edges(): - source, target = edge + # 如果节点存在memory_items + if 'memory_items' in node_data: + memory_items = node_data['memory_items'] - # 查找是否存在同样的边 - existing_edge = self.db.db.graph_data.edges.find_one({ - 'source': source, - 'target': target - }) - - if existing_edge: - # 如果存在,增加num属性 - num = existing_edge.get('num', 1) + 1 - self.db.db.graph_data.edges.update_one( - {'source': source, 'target': target}, - {'$set': {'num': num}} - ) - else: - # 如果不存在,创建新边 - edge_data = { - 'source': source, - 'target': target, - 'num': 1 - } - self.db.db.graph_data.edges.insert_one(edge_data) - - def load_graph_from_db(self): - # 清空当前图 - self.G.clear() - # 加载节点 - nodes = self.db.db.graph_data.nodes.find() - for node in nodes: - memory_items = node.get('memory_items', []) + # 确保memory_items是列表 if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] - self.G.add_node(node['concept'], memory_items=memory_items) - # 加载边 - edges = self.db.db.graph_data.edges.find() - for edge in edges: - self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1)) - - - + + # 如果有记忆项可以删除 + if memory_items: + # 随机选择一个记忆项删除 + removed_item = random.choice(memory_items) + memory_items.remove(removed_item) + + # 更新节点的记忆项 + if memory_items: + self.G.nodes[topic]['memory_items'] = memory_items + else: + # 如果没有记忆项了,删除整个节点 + self.G.remove_node(topic) + + return removed_item + + return None # 海马体 @@ -169,23 +129,33 @@ class Hippocampus: self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5) self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5) + def calculate_node_hash(self, concept, memory_items): + """计算节点的特征值""" + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + sorted_items = sorted(memory_items) + content = f"{concept}:{'|'.join(sorted_items)}" + return hash(content) + + def calculate_edge_hash(self, source, target): + """计算边的特征值""" + nodes = sorted([source, target]) + return hash(f"{nodes[0]}:{nodes[1]}") + def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}): current_timestamp = datetime.datetime.now().timestamp() chat_text = [] #短期:1h 中期:4h 长期:24h for _ in range(time_frequency.get('near')): # 循环10次 random_time = current_timestamp - random.randint(1, 3600) # 随机时间 - # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) chat_text.append(chat_) for _ in range(time_frequency.get('mid')): # 循环10次 random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间 - # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) chat_text.append(chat_) for _ in range(time_frequency.get('far')): # 循环10次 random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间 - # print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}") chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) chat_text.append(chat_) return chat_text @@ -207,8 +177,8 @@ class Hippocampus: topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt) compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 return compressed_memory - - async def build_memory(self,chat_size=12): + + async def operation_build_memory(self,chat_size=12): #最近消息获取频率 time_frequency = {'near':1,'mid':2,'far':2} memory_sample = self.get_memory_sample(chat_size,time_frequency) @@ -236,7 +206,247 @@ class Hippocampus: self.memory_graph.connect_dot(split_topic, other_split_topic) else: print(f"空消息 跳过") - self.memory_graph.save_graph_to_db() + self.sync_memory_to_db() + + def sync_memory_to_db(self): + """检查并同步内存中的图结构与数据库""" + # 获取数据库中所有节点和内存中所有节点 + db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find()) + memory_nodes = list(self.memory_graph.G.nodes(data=True)) + + # 转换数据库节点为字典格式,方便查找 + db_nodes_dict = {node['concept']: node for node in db_nodes} + + # 检查并更新节点 + for concept, data in memory_nodes: + memory_items = data.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 计算内存中节点的特征值 + memory_hash = self.calculate_node_hash(concept, memory_items) + + if concept not in db_nodes_dict: + # 数据库中缺少的节点,添加 + node_data = { + 'concept': concept, + 'memory_items': memory_items, + 'hash': memory_hash + } + self.memory_graph.db.db.graph_data.nodes.insert_one(node_data) + else: + # 获取数据库中节点的特征值 + db_node = db_nodes_dict[concept] + db_hash = db_node.get('hash', None) + + # 如果特征值不同,则更新节点 + if db_hash != memory_hash: + self.memory_graph.db.db.graph_data.nodes.update_one( + {'concept': concept}, + {'$set': { + 'memory_items': memory_items, + 'hash': memory_hash + }} + ) + + # 检查并删除数据库中多余的节点 + memory_concepts = set(node[0] for node in memory_nodes) + for db_node in db_nodes: + if db_node['concept'] not in memory_concepts: + self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) + + # 处理边的信息 + db_edges = list(self.memory_graph.db.db.graph_data.edges.find()) + memory_edges = list(self.memory_graph.G.edges()) + + # 创建边的哈希值字典 + db_edge_dict = {} + for edge in db_edges: + edge_hash = self.calculate_edge_hash(edge['source'], edge['target']) + db_edge_dict[(edge['source'], edge['target'])] = { + 'hash': edge_hash, + 'strength': edge.get('strength', 1) + } + + # 检查并更新边 + for source, target in memory_edges: + edge_hash = self.calculate_edge_hash(source, target) + edge_key = (source, target) + strength = self.memory_graph.G[source][target].get('strength', 1) + + if edge_key not in db_edge_dict: + # 添加新边 + edge_data = { + 'source': source, + 'target': target, + 'strength': strength, + 'hash': edge_hash + } + self.memory_graph.db.db.graph_data.edges.insert_one(edge_data) + else: + # 检查边的特征值是否变化 + if db_edge_dict[edge_key]['hash'] != edge_hash: + self.memory_graph.db.db.graph_data.edges.update_one( + {'source': source, 'target': target}, + {'$set': { + 'hash': edge_hash, + 'strength': strength + }} + ) + + # 删除多余的边 + memory_edge_set = set(memory_edges) + for edge_key in db_edge_dict: + if edge_key not in memory_edge_set: + source, target = edge_key + self.memory_graph.db.db.graph_data.edges.delete_one({ + 'source': source, + 'target': target + }) + + def sync_memory_from_db(self): + """从数据库同步数据到内存中的图结构""" + # 清空当前图 + self.memory_graph.G.clear() + + # 从数据库加载所有节点 + nodes = self.memory_graph.db.db.graph_data.nodes.find() + for node in nodes: + concept = node['concept'] + memory_items = node.get('memory_items', []) + # 确保memory_items是列表 + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + # 添加节点到图中 + self.memory_graph.G.add_node(concept, memory_items=memory_items) + + # 从数据库加载所有边 + edges = self.memory_graph.db.db.graph_data.edges.find() + for edge in edges: + source = edge['source'] + target = edge['target'] + strength = edge.get('strength', 1) # 获取 strength,默认为 1 + # 只有当源节点和目标节点都存在时才添加边 + if source in self.memory_graph.G and target in self.memory_graph.G: + self.memory_graph.G.add_edge(source, target, strength=strength) + + async def operation_forget_topic(self, percentage=0.1): + """随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘""" + # 获取所有节点 + all_nodes = list(self.memory_graph.G.nodes()) + # 计算要检查的节点数量 + check_count = max(1, int(len(all_nodes) * percentage)) + # 随机选择节点 + nodes_to_check = random.sample(all_nodes, check_count) + + forgotten_nodes = [] + for node in nodes_to_check: + # 获取节点的连接数 + connections = self.memory_graph.G.degree(node) + + # 获取节点的内容条数 + memory_items = self.memory_graph.G.nodes[node].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + + # 检查连接强度 + weak_connections = True + if connections > 1: # 只有当连接数大于1时才检查强度 + for neighbor in self.memory_graph.G.neighbors(node): + strength = self.memory_graph.G[node][neighbor].get('strength', 1) + if strength > 2: + weak_connections = False + break + + # 如果满足遗忘条件 + if (connections <= 1 and weak_connections) or content_count <= 2: + removed_item = self.memory_graph.forget_topic(node) + if removed_item: + forgotten_nodes.append((node, removed_item)) + print(f"遗忘节点 {node} 的记忆: {removed_item}") + + # 同步到数据库 + if forgotten_nodes: + self.sync_memory_to_db() + print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆") + else: + print("本次检查没有节点满足遗忘条件") + + async def merge_memory(self, topic): + """ + 对指定话题的记忆进行合并压缩 + + Args: + topic: 要合并的话题节点 + """ + # 获取节点的记忆项 + memory_items = self.memory_graph.G.nodes[topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 如果记忆项不足,直接返回 + if len(memory_items) < 10: + return + + # 随机选择10条记忆 + selected_memories = random.sample(memory_items, 10) + + # 拼接成文本 + merged_text = "\n".join(selected_memories) + print(f"\n[合并记忆] 话题: {topic}") + print(f"选择的记忆:\n{merged_text}") + + # 使用memory_compress生成新的压缩记忆 + compressed_memories = await self.memory_compress(merged_text, 0.1) + + # 从原记忆列表中移除被选中的记忆 + for memory in selected_memories: + memory_items.remove(memory) + + # 添加新的压缩记忆 + for _, compressed_memory in compressed_memories: + memory_items.append(compressed_memory) + print(f"添加压缩记忆: {compressed_memory}") + + # 更新节点的记忆项 + self.memory_graph.G.nodes[topic]['memory_items'] = memory_items + print(f"完成记忆合并,当前记忆数量: {len(memory_items)}") + + async def operation_merge_memory(self, percentage=0.1): + """ + 随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并 + + Args: + percentage: 要检查的节点比例,默认为0.1(10%) + """ + # 获取所有节点 + all_nodes = list(self.memory_graph.G.nodes()) + # 计算要检查的节点数量 + check_count = max(1, int(len(all_nodes) * percentage)) + # 随机选择节点 + nodes_to_check = random.sample(all_nodes, check_count) + + merged_nodes = [] + for node in nodes_to_check: + # 获取节点的内容条数 + memory_items = self.memory_graph.G.nodes[node].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + + # 如果内容数量超过100,进行合并 + if content_count > 100: + print(f"\n检查节点: {node}, 当前记忆数量: {content_count}") + await self.merge_memory(node) + merged_nodes.append(node) + + # 同步到数据库 + if merged_nodes: + self.sync_memory_to_db() + print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点") + else: + print("\n本次检查没有需要合并的节点") def segment_text(text): @@ -268,10 +478,10 @@ Database.initialize( ) #创建记忆图 memory_graph = Memory_graph() -#加载数据库中存储的记忆图 -memory_graph.load_graph_from_db() #创建海马体 hippocampus = Hippocampus(memory_graph) +#从数据库加载记忆图 +hippocampus.sync_memory_from_db() end_time = time.time() print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") \ No newline at end of file diff --git a/src/plugins/memory_system/memory_make.py b/src/plugins/memory_system/memory_make.py deleted file mode 100644 index d1757b246..000000000 --- a/src/plugins/memory_system/memory_make.py +++ /dev/null @@ -1,463 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -import jieba -import networkx as nx -import matplotlib.pyplot as plt -import math -from collections import Counter -import datetime -import random -import time -import os -# from chat.config import global_config -sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 -from src.common.database import Database # 使用正确的导入语法 -from src.plugins.memory_system.llm_module import LLMModel - -def calculate_information_content(text): - """计算文本的信息量(熵)""" - # 统计字符频率 - char_count = Counter(text) - total_chars = len(text) - - # 计算熵 - entropy = 0 - for count in char_count.values(): - probability = count / total_chars - entropy -= probability * math.log2(probability) - - return entropy - -def get_cloest_chat_from_db(db, length: int, timestamp: str): - """从数据库中获取最接近指定时间戳的聊天记录""" - chat_text = '' - closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) - - if closest_record: - closest_time = closest_record['time'] - group_id = closest_record['group_id'] # 获取groupid - # 获取该时间戳之后的length条消息,且groupid相同 - chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) - for record in chat_record: - time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) - chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' - return chat_text - - return '' - -class Memory_graph: - def __init__(self): - self.G = nx.Graph() # 使用 networkx 的图结构 - self.db = Database.get_instance() - - def connect_dot(self, concept1, concept2): - self.G.add_edge(concept1, concept2) - - def add_dot(self, concept, memory): - if concept in self.G: - # 如果节点已存在,将新记忆添加到现有列表中 - if 'memory_items' in self.G.nodes[concept]: - if not isinstance(self.G.nodes[concept]['memory_items'], list): - # 如果当前不是列表,将其转换为列表 - self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']] - self.G.nodes[concept]['memory_items'].append(memory) - else: - self.G.nodes[concept]['memory_items'] = [memory] - else: - # 如果是新节点,创建新的记忆列表 - self.G.add_node(concept, memory_items=[memory]) - - def get_dot(self, concept): - # 检查节点是否存在于图中 - if concept in self.G: - # 从图中获取节点数据 - node_data = self.G.nodes[concept] - # print(node_data) - # 创建新的Memory_dot对象 - return concept,node_data - return None - - def get_related_item(self, topic, depth=1): - if topic not in self.G: - return [], [] - - first_layer_items = [] - second_layer_items = [] - - # 获取相邻节点 - neighbors = list(self.G.neighbors(topic)) - # print(f"第一层: {topic}") - - # 获取当前节点的记忆项 - node_data = self.get_dot(topic) - if node_data: - concept, data = node_data - if 'memory_items' in data: - memory_items = data['memory_items'] - if isinstance(memory_items, list): - first_layer_items.extend(memory_items) - else: - first_layer_items.append(memory_items) - - # 只在depth=2时获取第二层记忆 - if depth >= 2: - # 获取相邻节点的记忆项 - for neighbor in neighbors: - # print(f"第二层: {neighbor}") - node_data = self.get_dot(neighbor) - if node_data: - concept, data = node_data - if 'memory_items' in data: - memory_items = data['memory_items'] - if isinstance(memory_items, list): - second_layer_items.extend(memory_items) - else: - second_layer_items.append(memory_items) - - return first_layer_items, second_layer_items - - def store_memory(self): - for node in self.G.nodes(): - dot_data = { - "concept": node - } - self.db.db.store_memory_dots.insert_one(dot_data) - - @property - def dots(self): - # 返回所有节点对应的 Memory_dot 对象 - return [self.get_dot(node) for node in self.G.nodes()] - - - def get_random_chat_from_db(self, length: int, timestamp: str): - # 从数据库中根据时间戳获取离其最近的聊天记录 - chat_text = '' - closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 - - # print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") - - if closest_record: - closest_time = closest_record['time'] - group_id = closest_record['group_id'] # 获取groupid - # 获取该时间戳之后的length条消息,且groupid相同 - chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) - for record in chat_record: - if record: - time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) - try: - displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"]) - except: - displayname=record["user_nickname"] or "用户" + str(record["user_id"]) - chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 - return chat_text - - return [] # 如果没有找到记录,返回空列表 - - def save_graph_to_db(self): - # 保存节点 - for node in self.G.nodes(data=True): - concept = node[0] - memory_items = node[1].get('memory_items', []) - - # 查找是否存在同名节点 - existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept}) - if existing_node: - # 如果存在,合并memory_items并去重 - existing_items = existing_node.get('memory_items', []) - if not isinstance(existing_items, list): - existing_items = [existing_items] if existing_items else [] - - # 合并并去重 - all_items = list(set(existing_items + memory_items)) - - # 更新节点 - self.db.db.graph_data.nodes.update_one( - {'concept': concept}, - {'$set': {'memory_items': all_items}} - ) - else: - # 如果不存在,创建新节点 - node_data = { - 'concept': concept, - 'memory_items': memory_items - } - self.db.db.graph_data.nodes.insert_one(node_data) - - # 保存边 - for edge in self.G.edges(): - source, target = edge - - # 查找是否存在同样的边 - existing_edge = self.db.db.graph_data.edges.find_one({ - 'source': source, - 'target': target - }) - - if existing_edge: - # 如果存在,增加num属性 - num = existing_edge.get('num', 1) + 1 - self.db.db.graph_data.edges.update_one( - {'source': source, 'target': target}, - {'$set': {'num': num}} - ) - else: - # 如果不存在,创建新边 - edge_data = { - 'source': source, - 'target': target, - 'num': 1 - } - self.db.db.graph_data.edges.insert_one(edge_data) - - def load_graph_from_db(self): - # 清空当前图 - self.G.clear() - # 加载节点 - nodes = self.db.db.graph_data.nodes.find() - for node in nodes: - memory_items = node.get('memory_items', []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - self.G.add_node(node['concept'], memory_items=memory_items) - # 加载边 - edges = self.db.db.graph_data.edges.find() - for edge in edges: - self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1)) - -# 海马体 -class Hippocampus: - def __init__(self,memory_graph:Memory_graph): - self.memory_graph = memory_graph - self.llm_model = LLMModel() - self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") - - def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}): - current_timestamp = datetime.datetime.now().timestamp() - chat_text = [] - #短期:1h 中期:4h 长期:24h - for _ in range(time_frequency.get('near')): # 循环10次 - random_time = current_timestamp - random.randint(1, 3600) # 随机时间 - chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) - chat_text.append(chat_) - for _ in range(time_frequency.get('mid')): # 循环10次 - random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间 - chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) - chat_text.append(chat_) - for _ in range(time_frequency.get('far')): # 循环10次 - random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间 - chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) - chat_text.append(chat_) - return chat_text - - def build_memory(self,chat_size=12): - #最近消息获取频率 - time_frequency = {'near':1,'mid':2,'far':2} - memory_sample = self.get_memory_sample(chat_size,time_frequency) - - #加载进度可视化 - for i, input_text in enumerate(memory_sample, 1): - progress = (i / len(memory_sample)) * 100 - bar_length = 30 - filled_length = int(bar_length * i // len(memory_sample)) - bar = '█' * filled_length + '-' * (bar_length - filled_length) - print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") - # print(f"第{i}条消息: {input_text}") - if input_text: - # 生成压缩后记忆 - first_memory = set() - first_memory = self.memory_compress(input_text, 2.5) - #将记忆加入到图谱中 - for topic, memory in first_memory: - topics = segment_text(topic) - print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}") - for split_topic in topics: - self.memory_graph.add_dot(split_topic,memory) - for split_topic in topics: - for other_split_topic in topics: - if split_topic != other_split_topic: - self.memory_graph.connect_dot(split_topic, other_split_topic) - else: - print(f"空消息 跳过") - - self.memory_graph.save_graph_to_db() - - def memory_compress(self, input_text, rate=1): - information_content = calculate_information_content(input_text) - print(f"文本的信息量(熵): {information_content:.4f} bits") - topic_num = max(1, min(5, int(information_content * rate / 4))) - topic_prompt = find_topic(input_text, topic_num) - topic_response = self.llm_model.generate_response(topic_prompt) - # 检查 topic_response 是否为元组 - if isinstance(topic_response, tuple): - topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 - else: - topics = topic_response.split(",") - compressed_memory = set() - for topic in topics: - topic_what_prompt = topic_what(input_text,topic) - topic_what_response = self.llm_model_small.generate_response(topic_what_prompt) - compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储 - return compressed_memory - -def segment_text(text): - seg_text = list(jieba.cut(text)) - return seg_text - -def find_topic(text, topic_num): - prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。' - return prompt - -def topic_what(text, topic): - prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' - return prompt - -def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False): - # 设置中文字体 - plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 - plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 - - G = memory_graph.G - - # 创建一个新图用于可视化 - H = G.copy() - - # 移除只有一条记忆的节点和连接数少于3的节点 - nodes_to_remove = [] - for node in H.nodes(): - memory_items = H.nodes[node].get('memory_items', []) - memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) - degree = H.degree(node) - if memory_count <= 1 or degree <= 2: - nodes_to_remove.append(node) - - H.remove_nodes_from(nodes_to_remove) - - # 如果过滤后没有节点,则返回 - if len(H.nodes()) == 0: - print("过滤后没有符合条件的节点可显示") - return - - # 保存图到本地 - nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式 - - # 根据连接条数或记忆数量设置节点颜色 - node_colors = [] - nodes = list(H.nodes()) # 获取图中实际的节点列表 - - if color_by_memory: - # 计算每个节点的记忆数量 - memory_counts = [] - for node in nodes: - memory_items = H.nodes[node].get('memory_items', []) - if isinstance(memory_items, list): - count = len(memory_items) - else: - count = 1 if memory_items else 0 - memory_counts.append(count) - max_memories = max(memory_counts) if memory_counts else 1 - - for count in memory_counts: - # 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少 - if max_memories > 0: - intensity = min(1.0, count / max_memories) - color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色 - else: - color = (0, 0, 1) # 如果没有记忆,则为蓝色 - node_colors.append(color) - else: - # 使用原来的连接数量着色方案 - max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1 - for node in nodes: - degree = H.degree(node) - if max_degree > 0: - red = min(1.0, degree / max_degree) - blue = 1.0 - red - color = (red, 0, blue) - else: - color = (0, 0, 1) - node_colors.append(color) - - # 绘制图形 - plt.figure(figsize=(12, 8)) - pos = nx.spring_layout(H, k=1, iterations=50) - nx.draw(H, pos, - with_labels=True, - node_color=node_colors, - node_size=2000, - font_size=10, - font_family='SimHei', - font_weight='bold') - - title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色') - plt.title(title, fontsize=16, fontfamily='SimHei') - plt.show() - -def main(): - # 初始化数据库 - Database.initialize( - host= os.getenv("MONGODB_HOST"), - port= int(os.getenv("MONGODB_PORT")), - db_name= os.getenv("DATABASE_NAME"), - username= os.getenv("MONGODB_USERNAME"), - password= os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE") - ) - - start_time = time.time() - - # 创建记忆图 - memory_graph = Memory_graph() - # 加载数据库中存储的记忆图 - memory_graph.load_graph_from_db() - # 创建海马体 - hippocampus = Hippocampus(memory_graph) - - end_time = time.time() - print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") - - # 构建记忆 - hippocampus.build_memory(chat_size=25) - - # 展示两种不同的可视化方式 - print("\n按连接数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=False) - - print("\n按记忆数量着色的图谱:") - visualize_graph(memory_graph, color_by_memory=True) - - # 交互式查询 - while True: - query = input("请输入新的查询概念(输入'退出'以结束):") - if query.lower() == '退出': - break - items_list = memory_graph.get_related_item(query) - if items_list: - for memory_item in items_list: - print(memory_item) - else: - print("未找到相关记忆。") - - while True: - query = input("请输入问题:") - - if query.lower() == '退出': - break - - topic_prompt = find_topic(query, 3) - topic_response = hippocampus.llm_model.generate_response(topic_prompt) - # 检查 topic_response 是否为元组 - if isinstance(topic_response, tuple): - topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串 - else: - topics = topic_response.split(",") - print(topics) - - for keyword in topics: - items_list = memory_graph.get_related_item(keyword) - if items_list: - print(items_list) - -if __name__ == "__main__": - main() - - diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py new file mode 100644 index 000000000..66933dd04 --- /dev/null +++ b/src/plugins/memory_system/memory_manual_build.py @@ -0,0 +1,805 @@ +# -*- coding: utf-8 -*- +import sys +import jieba +import networkx as nx +import matplotlib.pyplot as plt +import math +from collections import Counter +import datetime +import random +import time +import os +from dotenv import load_dotenv +import pymongo +from loguru import logger +from pathlib import Path +from snownlp import SnowNLP +# from chat.config import global_config +sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 +from src.common.database import Database +from src.plugins.memory_system.offline_llm import LLMModel + +# 获取当前文件的目录 +current_dir = Path(__file__).resolve().parent +# 获取项目根目录(上三层目录) +project_root = current_dir.parent.parent.parent +# env.dev文件路径 +env_path = project_root / ".env.dev" + +# 加载环境变量 +if env_path.exists(): + logger.info(f"从 {env_path} 加载环境变量") + load_dotenv(env_path) +else: + logger.warning(f"未找到环境变量文件: {env_path}") + logger.info("将使用默认配置") + +class Database: + _instance = None + db = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + if not Database.db: + Database.initialize( + host=os.getenv("MONGODB_HOST"), + port=int(os.getenv("MONGODB_PORT")), + db_name=os.getenv("DATABASE_NAME"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE") + ) + + @classmethod + def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"): + try: + if username and password: + uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}" + else: + uri = f"mongodb://{host}:{port}" + + client = pymongo.MongoClient(uri) + cls.db = client[db_name] + # 测试连接 + client.server_info() + logger.success("MongoDB连接成功!") + + except Exception as e: + logger.error(f"初始化MongoDB失败: {str(e)}") + raise + + + +def calculate_information_content(text): + """计算文本的信息量(熵)""" + char_count = Counter(text) + total_chars = len(text) + + entropy = 0 + for count in char_count.values(): + probability = count / total_chars + entropy -= probability * math.log2(probability) + + return entropy + +def get_cloest_chat_from_db(db, length: int, timestamp: str): + """从数据库中获取最接近指定时间戳的聊天记录""" + chat_text = '' + closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) + + if closest_record: + closest_time = closest_record['time'] + group_id = closest_record['group_id'] # 获取groupid + # 获取该时间戳之后的length条消息,且groupid相同 + chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) + for record in chat_record: + chat_text += record["detailed_plain_text"] + return chat_text + + return '' + +class Memory_graph: + def __init__(self): + self.G = nx.Graph() # 使用 networkx 的图结构 + self.db = Database.get_instance() + + def connect_dot(self, concept1, concept2): + # 如果边已存在,增加 strength + if self.G.has_edge(concept1, concept2): + self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1 + else: + # 如果是新边,初始化 strength 为 1 + self.G.add_edge(concept1, concept2, strength=1) + + def add_dot(self, concept, memory): + if concept in self.G: + # 如果节点已存在,将新记忆添加到现有列表中 + if 'memory_items' in self.G.nodes[concept]: + if not isinstance(self.G.nodes[concept]['memory_items'], list): + # 如果当前不是列表,将其转换为列表 + self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']] + self.G.nodes[concept]['memory_items'].append(memory) + else: + self.G.nodes[concept]['memory_items'] = [memory] + else: + # 如果是新节点,创建新的记忆列表 + self.G.add_node(concept, memory_items=[memory]) + + def get_dot(self, concept): + # 检查节点是否存在于图中 + if concept in self.G: + # 从图中获取节点数据 + node_data = self.G.nodes[concept] + return concept, node_data + return None + + def get_related_item(self, topic, depth=1): + if topic not in self.G: + return [], [] + + first_layer_items = [] + second_layer_items = [] + + # 获取相邻节点 + neighbors = list(self.G.neighbors(topic)) + + # 获取当前节点的记忆项 + node_data = self.get_dot(topic) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + first_layer_items.extend(memory_items) + else: + first_layer_items.append(memory_items) + + # 只在depth=2时获取第二层记忆 + if depth >= 2: + # 获取相邻节点的记忆项 + for neighbor in neighbors: + node_data = self.get_dot(neighbor) + if node_data: + concept, data = node_data + if 'memory_items' in data: + memory_items = data['memory_items'] + if isinstance(memory_items, list): + second_layer_items.extend(memory_items) + else: + second_layer_items.append(memory_items) + + return first_layer_items, second_layer_items + + @property + def dots(self): + # 返回所有节点对应的 Memory_dot 对象 + return [self.get_dot(node) for node in self.G.nodes()] + +# 海马体 +class Hippocampus: + def __init__(self, memory_graph: Memory_graph): + self.memory_graph = memory_graph + self.llm_model = LLMModel() + self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5") + + def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}): + current_timestamp = datetime.datetime.now().timestamp() + chat_text = [] + #短期:1h 中期:4h 长期:24h + for _ in range(time_frequency.get('near')): # 循环10次 + random_time = current_timestamp - random.randint(1, 3600) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('mid')): # 循环10次 + random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + for _ in range(time_frequency.get('far')): # 循环10次 + random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间 + chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + chat_text.append(chat_) + return chat_text + + def calculate_topic_num(self,text, compress_rate): + """计算文本的话题数量""" + information_content = calculate_information_content(text) + topic_by_length = text.count('\n')*compress_rate + topic_by_information_content = max(1, min(5, int((information_content-3) * 2))) + topic_num = int((topic_by_length + topic_by_information_content)/2) + print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}") + return topic_num + + async def memory_compress(self, input_text, compress_rate=0.1): + print(input_text) + + #获取topics + topic_num = self.calculate_topic_num(input_text, compress_rate) + topics_response = await self.llm_model_small.generate_response_async(self.find_topic_llm(input_text, topic_num)) + topics = topics_response[0].split(",") + print(f"话题: {topics}") + + # 创建所有话题的请求任务 + tasks = [] + for topic in topics: + topic_what_prompt = self.topic_what(input_text, topic) + # 创建异步任务 + task = self.llm_model_small.generate_response_async(topic_what_prompt) + tasks.append((topic.strip(), task)) + + # 等待所有任务完成 + compressed_memory = set() + for topic, task in tasks: + response = await task + if response: + compressed_memory.add((topic, response[0])) + + return compressed_memory + + async def operation_build_memory(self, chat_size=12): + #最近消息获取频率 + time_frequency = {'near':1,'mid':2,'far':2} + memory_sample = self.get_memory_sample(chat_size,time_frequency) + + for i, input_text in enumerate(memory_sample, 1): + #加载进度可视化 + progress = (i / len(memory_sample)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(memory_sample)) + bar = '█' * filled_length + '-' * (bar_length - filled_length) + print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") + + if input_text: + # 生成压缩后记忆 ,表现为 (话题,记忆) 的元组 + compressed_memory = set() + compress_rate = 0.15 + compressed_memory = await self.memory_compress(input_text,compress_rate) + print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}") + + #将记忆加入到图谱中 + for topic, memory in compressed_memory: + # 将jieba分词结果转换为列表以便多次使用 + topics = list(jieba.cut(topic)) + print(f"\033[1;34m话题\033[0m: {topic}") + print(f"\033[1;34m分词结果\033[0m: {topics}") + print(f"\033[1;34m记忆\033[0m: {memory}") + + # 如果分词结果少于2个词,跳过连接 + if len(topics) < 2: + print(f"\033[1;31m分词结果少于2个词,跳过连接\033[0m") + # 仍然添加单个节点 + for split_topic in topics: + self.memory_graph.add_dot(split_topic, memory) + continue + + # 先添加所有节点 + for split_topic in topics: + print(f"\033[1;32m添加节点\033[0m: {split_topic}") + self.memory_graph.add_dot(split_topic, memory) + + # 再添加节点之间的连接 + for i, split_topic in enumerate(topics): + for j, other_split_topic in enumerate(topics): + if i < j: # 只连接一次,避免重复连接 + print(f"\033[1;32m连接节点\033[0m: {split_topic} 和 {other_split_topic}") + self.memory_graph.connect_dot(split_topic, other_split_topic) + else: + print(f"空消息 跳过") + + # 每处理完一条消息就同步一次到数据库 + self.sync_memory_to_db_2() + + def sync_memory_from_db(self): + """ + 从数据库同步数据到内存中的图结构 + 将清空当前内存中的图,并从数据库重新加载所有节点和边 + """ + # 清空当前图 + self.memory_graph.G.clear() + + # 从数据库加载所有节点 + nodes = self.memory_graph.db.db.graph_data.nodes.find() + for node in nodes: + concept = node['concept'] + memory_items = node.get('memory_items', []) + # 确保memory_items是列表 + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + # 添加节点到图中 + self.memory_graph.G.add_node(concept, memory_items=memory_items) + + # 从数据库加载所有边 + edges = self.memory_graph.db.db.graph_data.edges.find() + for edge in edges: + source = edge['source'] + target = edge['target'] + strength = edge.get('strength', 1) # 获取 strength,默认为 1 + # 只有当源节点和目标节点都存在时才添加边 + if source in self.memory_graph.G and target in self.memory_graph.G: + self.memory_graph.G.add_edge(source, target, strength=strength) + + logger.success("从数据库同步记忆图谱完成") + + def calculate_node_hash(self, concept, memory_items): + """ + 计算节点的特征值 + """ + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + # 将记忆项排序以确保相同内容生成相同的哈希值 + sorted_items = sorted(memory_items) + # 组合概念和记忆项生成特征值 + content = f"{concept}:{'|'.join(sorted_items)}" + return hash(content) + + def calculate_edge_hash(self, source, target): + """ + 计算边的特征值 + """ + # 对源节点和目标节点排序以确保相同的边生成相同的哈希值 + nodes = sorted([source, target]) + return hash(f"{nodes[0]}:{nodes[1]}") + + def sync_memory_to_db_2(self): + """ + 检查并同步内存中的图结构与数据库 + 使用特征值(哈希值)快速判断是否需要更新 + """ + # 获取数据库中所有节点和内存中所有节点 + db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find()) + memory_nodes = list(self.memory_graph.G.nodes(data=True)) + + # 转换数据库节点为字典格式,方便查找 + db_nodes_dict = {node['concept']: node for node in db_nodes} + + # 检查并更新节点 + for concept, data in memory_nodes: + memory_items = data.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 计算内存中节点的特征值 + memory_hash = self.calculate_node_hash(concept, memory_items) + + if concept not in db_nodes_dict: + # 数据库中缺少的节点,添加 + logger.info(f"添加新节点: {concept}") + node_data = { + 'concept': concept, + 'memory_items': memory_items, + 'hash': memory_hash + } + self.memory_graph.db.db.graph_data.nodes.insert_one(node_data) + else: + # 获取数据库中节点的特征值 + db_node = db_nodes_dict[concept] + db_hash = db_node.get('hash', None) + + # 如果特征值不同,则更新节点 + if db_hash != memory_hash: + logger.info(f"更新节点内容: {concept}") + self.memory_graph.db.db.graph_data.nodes.update_one( + {'concept': concept}, + {'$set': { + 'memory_items': memory_items, + 'hash': memory_hash + }} + ) + + # 检查并删除数据库中多余的节点 + memory_concepts = set(node[0] for node in memory_nodes) + for db_node in db_nodes: + if db_node['concept'] not in memory_concepts: + logger.info(f"删除多余节点: {db_node['concept']}") + self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) + + # 处理边的信息 + db_edges = list(self.memory_graph.db.db.graph_data.edges.find()) + memory_edges = list(self.memory_graph.G.edges()) + + # 创建边的哈希值字典 + db_edge_dict = {} + for edge in db_edges: + edge_hash = self.calculate_edge_hash(edge['source'], edge['target']) + db_edge_dict[(edge['source'], edge['target'])] = { + 'hash': edge_hash, + 'num': edge.get('num', 1) + } + + # 检查并更新边 + for source, target in memory_edges: + edge_hash = self.calculate_edge_hash(source, target) + edge_key = (source, target) + + if edge_key not in db_edge_dict: + # 添加新边 + logger.info(f"添加新边: {source} - {target}") + edge_data = { + 'source': source, + 'target': target, + 'num': 1, + 'hash': edge_hash + } + self.memory_graph.db.db.graph_data.edges.insert_one(edge_data) + else: + # 检查边的特征值是否变化 + if db_edge_dict[edge_key]['hash'] != edge_hash: + logger.info(f"更新边: {source} - {target}") + self.memory_graph.db.db.graph_data.edges.update_one( + {'source': source, 'target': target}, + {'$set': {'hash': edge_hash}} + ) + + # 删除多余的边 + memory_edge_set = set(memory_edges) + for edge_key in db_edge_dict: + if edge_key not in memory_edge_set: + source, target = edge_key + logger.info(f"删除多余边: {source} - {target}") + self.memory_graph.db.db.graph_data.edges.delete_one({ + 'source': source, + 'target': target + }) + + logger.success("完成记忆图谱与数据库的差异同步") + + def find_topic_llm(self,text, topic_num): + prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。' + return prompt + + def topic_what(self,text, topic): + prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' + return prompt + + def remove_node_from_db(self, topic): + """ + 从数据库中删除指定节点及其相关的边 + + Args: + topic: 要删除的节点概念 + """ + # 删除节点 + self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic}) + # 删除所有涉及该节点的边 + self.memory_graph.db.db.graph_data.edges.delete_many({ + '$or': [ + {'source': topic}, + {'target': topic} + ] + }) + + def forget_topic(self, topic): + """ + 随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点 + 只在内存中的图上操作,不直接与数据库交互 + + Args: + topic: 要删除记忆的话题 + + Returns: + removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None + """ + if topic not in self.memory_graph.G: + return None + + # 获取话题节点数据 + node_data = self.memory_graph.G.nodes[topic] + + # 如果节点存在memory_items + if 'memory_items' in node_data: + memory_items = node_data['memory_items'] + + # 确保memory_items是列表 + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 如果有记忆项可以删除 + if memory_items: + # 随机选择一个记忆项删除 + removed_item = random.choice(memory_items) + memory_items.remove(removed_item) + + # 更新节点的记忆项 + if memory_items: + self.memory_graph.G.nodes[topic]['memory_items'] = memory_items + else: + # 如果没有记忆项了,删除整个节点 + self.memory_graph.G.remove_node(topic) + + return removed_item + + return None + + async def operation_forget_topic(self, percentage=0.1): + """ + 随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘 + + Args: + percentage: 要检查的节点比例,默认为0.1(10%) + """ + # 获取所有节点 + all_nodes = list(self.memory_graph.G.nodes()) + # 计算要检查的节点数量 + check_count = max(1, int(len(all_nodes) * percentage)) + # 随机选择节点 + nodes_to_check = random.sample(all_nodes, check_count) + + forgotten_nodes = [] + for node in nodes_to_check: + # 获取节点的连接数 + connections = self.memory_graph.G.degree(node) + + # 获取节点的内容条数 + memory_items = self.memory_graph.G.nodes[node].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + + # 检查连接强度 + weak_connections = True + if connections > 1: # 只有当连接数大于1时才检查强度 + for neighbor in self.memory_graph.G.neighbors(node): + strength = self.memory_graph.G[node][neighbor].get('strength', 1) + if strength > 2: + weak_connections = False + break + + # 如果满足遗忘条件 + if (connections <= 1 and weak_connections) or content_count <= 2: + removed_item = self.forget_topic(node) + if removed_item: + forgotten_nodes.append((node, removed_item)) + logger.info(f"遗忘节点 {node} 的记忆: {removed_item}") + + # 同步到数据库 + if forgotten_nodes: + self.sync_memory_to_db_2() + logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆") + else: + logger.info("本次检查没有节点满足遗忘条件") + + async def merge_memory(self, topic): + """ + 对指定话题的记忆进行合并压缩 + + Args: + topic: 要合并的话题节点 + """ + # 获取节点的记忆项 + memory_items = self.memory_graph.G.nodes[topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 如果记忆项不足,直接返回 + if len(memory_items) < 10: + return + + # 随机选择10条记忆 + selected_memories = random.sample(memory_items, 10) + + # 拼接成文本 + merged_text = "\n".join(selected_memories) + print(f"\n[合并记忆] 话题: {topic}") + print(f"选择的记忆:\n{merged_text}") + + # 使用memory_compress生成新的压缩记忆 + compressed_memories = await self.memory_compress(merged_text, 0.1) + + # 从原记忆列表中移除被选中的记忆 + for memory in selected_memories: + memory_items.remove(memory) + + # 添加新的压缩记忆 + for _, compressed_memory in compressed_memories: + memory_items.append(compressed_memory) + print(f"添加压缩记忆: {compressed_memory}") + + # 更新节点的记忆项 + self.memory_graph.G.nodes[topic]['memory_items'] = memory_items + print(f"完成记忆合并,当前记忆数量: {len(memory_items)}") + + async def operation_merge_memory(self, percentage=0.1): + """ + 随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并 + + Args: + percentage: 要检查的节点比例,默认为0.1(10%) + """ + # 获取所有节点 + all_nodes = list(self.memory_graph.G.nodes()) + # 计算要检查的节点数量 + check_count = max(1, int(len(all_nodes) * percentage)) + # 随机选择节点 + nodes_to_check = random.sample(all_nodes, check_count) + + merged_nodes = [] + for node in nodes_to_check: + # 获取节点的内容条数 + memory_items = self.memory_graph.G.nodes[node].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + + # 如果内容数量超过100,进行合并 + if content_count > 100: + print(f"\n检查节点: {node}, 当前记忆数量: {content_count}") + await self.merge_memory(node) + merged_nodes.append(node) + + # 同步到数据库 + if merged_nodes: + self.sync_memory_to_db_2() + print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点") + else: + print("\n本次检查没有需要合并的节点") + + +def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False): + # 设置中文字体 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 + plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 + + G = memory_graph.G + + # 创建一个新图用于可视化 + H = G.copy() + + # 计算节点大小和颜色 + node_colors = [] + node_sizes = [] + nodes = list(H.nodes()) + + # 获取最大记忆数用于归一化节点大小 + max_memories = 1 + for node in nodes: + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + max_memories = max(max_memories, memory_count) + + # 计算每个节点的大小和颜色 + for node in nodes: + # 计算节点大小(基于记忆数量) + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + # 使用指数函数使变化更明显 + ratio = memory_count / max_memories + size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显 + node_sizes.append(size) + + # 计算节点颜色(基于连接数) + degree = H.degree(node) + if degree >= 30: + node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000) + else: + # 将1-10映射到0-1的范围 + color_ratio = (degree - 1) / 29.0 if degree > 1 else 0 + # 使用蓝到红的渐变 + red = min(0.9, color_ratio) + blue = max(0.0, 1.0 - color_ratio) + node_colors.append((red, 0, blue)) + + # 获取边的权重和透明度 + edge_colors = [] + max_strength = 1 + + # 找出最大强度值 + for (u, v) in H.edges(): + strength = H[u][v].get('strength', 1) + max_strength = max(max_strength, strength) + + # 创建边权重字典用于布局 + edge_weights = {} + + # 计算每条边的透明度和权重 + for (u, v) in H.edges(): + strength = H[u][v].get('strength', 1) + # 将强度映射到透明度范围 [0.05, 0.8] + alpha = 0.02 + 0.55 * (strength / max_strength) + # 使用统一的蓝色,但透明度不同 + edge_colors.append((0, 0, 1, alpha)) + # 设置边的权重(强度越大,权重越大,节点间距离越小) + edge_weights[(u, v)] = strength + + # 绘制图形 + plt.figure(figsize=(20, 16)) # 增加图形尺寸 + # 调整弹簧布局参数,使用边权重影响布局 + pos = nx.spring_layout(H, + k=2.0, # 增加节点间斥力 + iterations=100, # 增加迭代次数 + scale=2.0, # 增加布局尺寸 + weight='strength') # 使用边的strength属性作为权重 + + nx.draw(H, pos, + with_labels=True, + node_color=node_colors, + node_size=node_sizes, + font_size=8, # 稍微减小字体大小 + font_family='SimHei', + font_weight='bold', + edge_color=edge_colors, + width=1.5) # 统一的边宽度 + + title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近' + plt.title(title, fontsize=16, fontfamily='SimHei') + plt.show() + +async def main(): + # 初始化数据库 + logger.info("正在初始化数据库连接...") + db = Database.get_instance() + start_time = time.time() + + test_pare = {'do_build_memory':False,'do_forget_topic':True,'do_visualize_graph':True,'do_query':False,'do_merge_memory':True} + + # 创建记忆图 + memory_graph = Memory_graph() + + # 创建海马体 + hippocampus = Hippocampus(memory_graph) + + # 从数据库同步数据 + hippocampus.sync_memory_from_db() + + end_time = time.time() + logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m") + + # 构建记忆 + if test_pare['do_build_memory']: + logger.info("开始构建记忆...") + chat_size = 25 + await hippocampus.operation_build_memory(chat_size=chat_size) + + end_time = time.time() + logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = {chat_size}]\033[0m") + + if test_pare['do_forget_topic']: + logger.info("开始遗忘记忆...") + await hippocampus.operation_forget_topic(percentage=0.1) + + end_time = time.time() + logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") + + if test_pare['do_merge_memory']: + logger.info("开始合并记忆...") + await hippocampus.operation_merge_memory(percentage=0.1) + + end_time = time.time() + logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m") + + if test_pare['do_visualize_graph']: + # 展示优化后的图形 + logger.info("生成记忆图谱可视化...") + print("\n生成优化后的记忆图谱:") + visualize_graph_lite(memory_graph) + + if test_pare['do_query']: + # 交互式查询 + while True: + query = input("\n请输入新的查询概念(输入'退出'以结束):") + if query.lower() == '退出': + break + + items_list = memory_graph.get_related_item(query) + if items_list: + first_layer, second_layer = items_list + if first_layer: + print("\n直接相关的记忆:") + for item in first_layer: + print(f"- {item}") + if second_layer: + print("\n间接相关的记忆:") + for item in second_layer: + print(f"- {item}") + else: + print("未找到相关记忆。") + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + + diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/offline_llm.py similarity index 50% rename from src/plugins/memory_system/llm_module_memory_make.py rename to src/plugins/memory_system/offline_llm.py index 41a5d7c0f..5e877dceb 100644 --- a/src/plugins/memory_system/llm_module_memory_make.py +++ b/src/plugins/memory_system/offline_llm.py @@ -2,28 +2,23 @@ import os import requests from typing import Tuple, Union import time -from nonebot import get_driver import aiohttp import asyncio from loguru import logger -from src.plugins.chat.config import BotConfig, global_config - -driver = get_driver() -config = driver.config class LLMModel: - def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs): + def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name self.params = kwargs - self.api_key = config.siliconflow_key - self.base_url = config.siliconflow_base_url + self.api_key = os.getenv("SILICONFLOW_KEY") + self.base_url = os.getenv("SILICONFLOW_BASE_URL") if not self.api_key or not self.base_url: raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置") logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url - async def generate_response(self, prompt: str) -> Tuple[str, str]: + def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]: """根据输入的提示生成模型的响应""" headers = { "Authorization": f"Bearer {self.api_key}", @@ -47,7 +42,60 @@ class LLMModel: for retry in range(max_retries): try: - async with aiohttp.ClientSession() as session: + response = requests.post(api_url, headers=headers, json=data) + + if response.status_code == 429: + wait_time = base_wait_time * (2 ** retry) # 指数退避 + logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + response.raise_for_status() # 检查其他响应状态 + + result = response.json() + if "choices" in result and len(result["choices"]) > 0: + content = result["choices"][0]["message"]["content"] + reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") + return content, reasoning_content + return "没有返回结果", "" + + except Exception as e: + if retry < max_retries - 1: # 如果还有重试机会 + wait_time = base_wait_time * (2 ** retry) + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + time.sleep(wait_time) + else: + logger.error(f"请求失败: {str(e)}") + return f"请求失败: {str(e)}", "" + + logger.error("达到最大重试次数,请求仍然失败") + return "达到最大重试次数,请求仍然失败", "" + + async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: + """异步方式根据输入的提示生成模型的响应""" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 构建请求体 + data = { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.5, + **self.params + } + + # 发送请求到完整的 chat/completions 端点 + api_url = f"{self.base_url.rstrip('/')}/chat/completions" + logger.info(f"Request URL: {api_url}") # 记录请求的 URL + + max_retries = 3 + base_wait_time = 15 + + async with aiohttp.ClientSession() as session: + for retry in range(max_retries): + try: async with session.post(api_url, headers=headers, json=data) as response: if response.status == 429: wait_time = base_wait_time * (2 ** retry) # 指数退避 @@ -63,15 +111,15 @@ class LLMModel: reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") return content, reasoning_content return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") - await asyncio.sleep(wait_time) - else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" - - logger.error("达到最大重试次数,请求仍然失败") - return "达到最大重试次数,请求仍然失败", "" + + except Exception as e: + if retry < max_retries - 1: # 如果还有重试机会 + wait_time = base_wait_time * (2 ** retry) + logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + await asyncio.sleep(wait_time) + else: + logger.error(f"请求失败: {str(e)}") + return f"请求失败: {str(e)}", "" + + logger.error("达到最大重试次数,请求仍然失败") + return "达到最大重试次数,请求仍然失败", ""