Merge branch 'debug' into feature

This commit is contained in:
tcmofashi
2025-03-06 01:55:47 +08:00
26 changed files with 1692 additions and 1564 deletions

View File

@@ -3,10 +3,11 @@ name: Docker Build and Push
on:
push:
branches:
- main # trigger on pushes to the main branch
- main
- debug # newly added: also trigger on the debug branch
tags:
- 'v*' # trigger on tags starting with v, e.g. v1.0.0
workflow_dispatch: # allow manual triggering
- 'v*'
workflow_dispatch:
jobs:
build-and-push:
@@ -24,15 +25,24 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Determine Image Tags
id: tags
run: |
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
fi
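The step above routes the triggering ref to a set of image tags; a minimal Python sketch of the same mapping (the function name and USER placeholder are illustrative, not part of the workflow):

def docker_tags(ref: str, ref_name: str, user: str = "USER") -> list:
    # version tags publish an immutable tag plus a rolling latest
    if ref.startswith("refs/tags/"):
        return [f"{user}/maimbot:{ref_name}", f"{user}/maimbot:latest"]
    # main publishes a rolling main tag plus latest
    if ref == "refs/heads/main":
        return [f"{user}/maimbot:main", f"{user}/maimbot:latest"]
    # debug publishes only a debug tag, so experiments never touch latest
    if ref == "refs/heads/debug":
        return [f"{user}/maimbot:debug"]
    return []

assert docker_tags("refs/heads/debug", "debug") == ["USER/maimbot:debug"]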
- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
context: . # Docker build context path
file: ./Dockerfile # path to the Dockerfile
platforms: linux/amd64,linux/arm64 # include ARM architecture support
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }}
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
push: true
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
cache-to: type=inline
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
tags: ${{ steps.tags.outputs.tags }}
push: true
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max

.gitignore vendored
View File

@@ -185,3 +185,6 @@ cython_debug/
# PyPI configuration file
.pypirc
.env
# jieba
jieba.cache

View File

@@ -42,22 +42,22 @@
## 🎯 Features
### 💬 Chat
- Keyword-triggered proactive replies: identifies the topic of each message and speaks up when it matches a topic 麦麦 has stored (currently buggy, so topics are only detected, never stored)
- Keyword-triggered proactive replies: identifies the topic of each message and speaks up when it matches a topic 麦麦 has stored
- Name-triggered replies: speaks up when "麦麦" is detected (configurable)
- Uses the SiliconFlow API to generate replies, randomly choosing among R1, V3, R1-distill and other models (official API support planned)
- Supports multiple models and custom multi-vendor configuration
- Dynamic prompt builder for more human-like replies
- Recognizes images, forwarded messages and quoted replies
- Typos and multi-part replies: 麦麦 can randomly introduce typos, split a reply across several messages, and reply to a specific message
### 😊 Stickers
- Sends a sticker matching the emotion of the message: unpolished but usable
- Automatically steals stickers from group members (unfinished, temporarily disabled, currently buggy)
- Sends a sticker matching the emotion of the message
- Automatically steals stickers from group members
### 📅 Schedule
- 麦麦 automatically generates a daily schedule, enabling more human-like replies
### 🧠 Memory
- Summarizes and stores chat history for later recall (not finished)
- Summarizes and stores chat history for later recall (needs refinement)
### 📚 Knowledge base
- Embedding-based knowledge base: .txt files dropped in manually are indexed automatically (finished, temporarily disabled)

View File

@@ -11,7 +11,7 @@ prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的
[message]
min_text_length = 2 # 麦麦 only replies to messages whose text is at least this long
max_context_size = 15 # how much prior context 麦麦 receives; older messages are dropped automatically
max_context_size = 15 # how much prior context 麦麦 receives
emoji_chance = 0.2 # probability that 麦麦 sends a sticker
ban_words = [
# "403","张三"
@@ -31,6 +31,7 @@ model_r1_distill_probability = 0.1 # probability that 麦麦 picks the R1-distilled model when replying
[memory]
build_memory_interval = 300 # memory build interval, in seconds
forget_memory_interval = 300 # memory forgetting interval, in seconds
[others]
enable_advance_output = true # whether to enable verbose output

View File

@@ -0,0 +1,165 @@
from typing import Dict, List, Union, Optional, Any
import base64
import os
"""
OneBot v11 Message Segment Builder
This module provides classes for building message segments that conform to the
OneBot v11 standard. These segments can be used to construct complex messages
for sending through bots that implement the OneBot interface.
"""
class Segment:
"""Base class for all message segments."""
def __init__(self, type_: str, data: Dict[str, Any]):
self.type = type_
self.data = data
def to_dict(self) -> Dict[str, Any]:
"""Convert the segment to a dictionary format."""
return {
"type": self.type,
"data": self.data
}
class Text(Segment):
"""Text message segment."""
def __init__(self, text: str):
super().__init__("text", {"text": text})
class Face(Segment):
"""Face/emoji message segment."""
def __init__(self, face_id: int):
super().__init__("face", {"id": str(face_id)})
class Image(Segment):
"""Image message segment."""
@classmethod
def from_url(cls, url: str) -> 'Image':
"""Create an Image segment from a URL."""
return cls(url=url)
@classmethod
def from_path(cls, path: str) -> 'Image':
"""Create an Image segment from a file path."""
with open(path, 'rb') as f:
file_b64 = base64.b64encode(f.read()).decode('utf-8')
return cls(file=f"base64://{file_b64}")
def __init__(self, file: str = None, url: str = None, cache: bool = True):
data = {}
if file:
data["file"] = file
if url:
data["url"] = url
if not cache:
data["cache"] = "0"
super().__init__("image", data)
class At(Segment):
"""@Someone message segment."""
def __init__(self, user_id: Union[int, str]):
data = {"qq": str(user_id)}
super().__init__("at", data)
class Record(Segment):
"""Voice message segment."""
def __init__(self, file: str, magic: bool = False, cache: bool = True):
data = {"file": file}
if magic:
data["magic"] = "1"
if not cache:
data["cache"] = "0"
super().__init__("record", data)
class Video(Segment):
"""Video message segment."""
def __init__(self, file: str):
super().__init__("video", {"file": file})
class Reply(Segment):
"""Reply message segment."""
def __init__(self, message_id: int):
super().__init__("reply", {"id": str(message_id)})
class MessageBuilder:
"""Helper class for building complex messages."""
def __init__(self):
self.segments: List[Segment] = []
def text(self, text: str) -> 'MessageBuilder':
"""Add a text segment."""
self.segments.append(Text(text))
return self
def face(self, face_id: int) -> 'MessageBuilder':
"""Add a face/emoji segment."""
self.segments.append(Face(face_id))
return self
def image(self, file: str = None) -> 'MessageBuilder':
"""Add an image segment."""
self.segments.append(Image(file=file))
return self
def at(self, user_id: Union[int, str]) -> 'MessageBuilder':
"""Add an @someone segment."""
self.segments.append(At(user_id))
return self
def record(self, file: str, magic: bool = False) -> 'MessageBuilder':
"""Add a voice record segment."""
self.segments.append(Record(file, magic))
return self
def video(self, file: str) -> 'MessageBuilder':
"""Add a video segment."""
self.segments.append(Video(file))
return self
def reply(self, message_id: int) -> 'MessageBuilder':
"""Add a reply segment."""
self.segments.append(Reply(message_id))
return self
def build(self) -> List[Dict[str, Any]]:
"""Build the message into a list of segment dictionaries."""
return [segment.to_dict() for segment in self.segments]
'''Convenience functions
def text(content: str) -> Dict[str, Any]:
"""Create a text message segment."""
return Text(content).to_dict()
def image_url(url: str) -> Dict[str, Any]:
"""Create an image message segment from URL."""
return Image.from_url(url).to_dict()
def image_path(path: str) -> Dict[str, Any]:
"""Create an image message segment from file path."""
return Image.from_path(path).to_dict()
def at(user_id: Union[int, str]) -> Dict[str, Any]:
"""Create an @someone message segment."""
return At(user_id).to_dict()'''
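For orientation, a minimal usage sketch of the builder above (the message id, user id and text are illustrative):

# Build a reply that quotes a message, @-mentions the sender, and adds text.
message = (
    MessageBuilder()
    .reply(123456)   # quote message 123456
    .at(10001)       # @ the original sender
    .text(" got it, sending the image now")
    .build()
)
# `message` is now a list of OneBot v11 segment dicts, e.g.
# [{"type": "reply", "data": {"id": "123456"}}, {"type": "at", "data": {"qq": "10001"}}, ...]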

View File

@@ -13,6 +13,7 @@ from .willing_manager import willing_manager
from nonebot.rule import to_me
from .bot import chat_bot
from .emoji_manager import emoji_manager
import time
# get the driver
@@ -86,19 +87,27 @@ async def _(bot: Bot):
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
'''
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
async def monitor_relationships():
"""每15秒打印一次关系数据"""
relationship_manager.print_all_relationships()
'''
# register the build_memory scheduled task
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""Run memory construction every build_memory_interval seconds"""
print("\033[1;32m[Memory Build]\033[0m Starting memory construction...")
await hippocampus.build_memory(chat_size=30)
print("\033[1;32m[Memory Build]\033[0m Memory construction finished")
print("\033[1;32m[Memory Build]\033[0m -------------------------------------------Starting memory construction-------------------------------------------")
start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time()
print(f"\033[1;32m[Memory Build]\033[0m -------------------------------------------Memory construction finished, took {end_time - start_time:.2f} s-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""Forget memories every forget_memory_interval seconds (currently disabled)"""
# print("\033[1;32m[Memory Forget]\033[0m Starting to forget memories...")
# await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[Memory Forget]\033[0m Memory forgetting finished")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task():
"""Merge memories shortly after each build pass (currently disabled)"""
# print("\033[1;32m[Memory Merge]\033[0m Starting memory merge")
# await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[Memory Merge]\033[0m Memory merge finished")

View File

@@ -69,11 +69,9 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
identifier=topic_identifier.identify_topic()
if global_config.topic_extract=='llm':
topic=await identifier(message.processed_plain_text)
else:
topic=identifier(message.detailed_plain_text)
topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
# topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
# topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)

View File

@@ -26,7 +26,8 @@ class BotConfig:
talk_frequency_down_groups = set()
ban_user_id = set()
build_memory_interval: int = 60 # memory build interval (seconds)
build_memory_interval: int = 30 # memory build interval (seconds)
forget_memory_interval: int = 300 # memory forgetting interval (seconds)
EMOJI_CHECK_INTERVAL: int = 120 # sticker check interval (minutes)
EMOJI_REGISTER_INTERVAL: int = 10 # sticker registration interval (minutes)
@@ -155,6 +156,7 @@ class BotConfig:
if "memory" in toml_dict:
memory_config = toml_dict["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
# group configuration
if "groups" in toml_dict:
@@ -188,6 +190,6 @@ global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
# logger.remove()
logger.remove()
pass

View File

@@ -1,251 +0,0 @@
from typing import Union, List, Optional, Deque, Dict
from nonebot.adapters.onebot.v11 import Bot, MessageSegment
import asyncio
import random
import os
from .message import Message, Message_Thinking, MessageSet
from .cq_code import CQCode
from collections import deque
import time
from .storage import MessageStorage
from .config import global_config
from .cq_code import cq_code_tool
if os.name == "nt":
from .message_visualizer import message_visualizer
class SendTemp:
"""单个群组的临时消息队列管理器"""
def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id
self.max_size = max_size
self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
self.last_send_time = 0
def add(self, message: Message) -> None:
"""按时间顺序添加消息到队列"""
if not self.messages:
self.messages.append(message)
return
# insert in time order
if message.time >= self.messages[-1].time:
self.messages.append(message)
return
# binary-search for the insertion point
messages_list = list(self.messages)
left, right = 0, len(messages_list)
while left < right:
mid = (left + right) // 2
if messages_list[mid].time < message.time:
left = mid + 1
else:
right = mid
# rebuild the queue, preserving time order
new_messages = deque(maxlen=self.max_size)
new_messages.extend(messages_list[:left])
new_messages.append(message)
new_messages.extend(messages_list[left:])
self.messages = new_messages
def get_earliest_message(self) -> Optional[Message]:
"""获取时间最早的消息"""
message = self.messages.popleft() if self.messages else None
return message
def clear(self) -> None:
"""清空队列"""
self.messages.clear()
def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
"""获取所有待发送的消息"""
if group_id is None:
return list(self.messages)
return [msg for msg in self.messages if msg.group_id == group_id]
def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
"""查看下一条要发送的消息(不移除)"""
return self.messages[0] if self.messages else None
def has_messages(self) -> bool:
"""检查是否有待发送的消息"""
return bool(self.messages)
def count(self, group_id: Optional[int] = None) -> int:
"""获取待发送消息数量"""
if group_id is None:
return len(self.messages)
return len([msg for msg in self.messages if msg.group_id == group_id])
def get_last_send_time(self) -> float:
"""获取最后一次发送时间"""
return self.last_send_time
def update_send_time(self):
"""更新最后发送时间"""
self.last_send_time = time.time()
class SendTempContainer:
"""管理所有群组的消息缓存容器"""
def __init__(self):
self.temp_queues: Dict[int, SendTemp] = {}
def get_queue(self, group_id: int) -> SendTemp:
"""获取或创建群组的消息队列"""
if group_id not in self.temp_queues:
self.temp_queues[group_id] = SendTemp(group_id)
return self.temp_queues[group_id]
def add_message(self, message: Message) -> None:
"""添加消息到对应群组的队列"""
queue = self.get_queue(message.group_id)
queue.add(message)
def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
"""获取指定群组的所有待发送消息"""
queue = self.get_queue(group_id)
return queue.get_all()
def has_messages(self, group_id: int) -> bool:
"""检查指定群组是否有待发送消息"""
queue = self.get_queue(group_id)
return queue.has_messages()
def get_all_groups(self) -> List[int]:
"""获取所有有待发送消息的群组ID"""
return list(self.temp_queues.keys())
def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
queue = self.get_queue(message_obj.group_id)
# find matching message indices with a list comprehension
matching_indices = [
i for i, msg in enumerate(queue.messages)
if msg.message_id == message_obj.message_id
]
if not matching_indices:
return False
index = matching_indices[0] # take the first match
# convert to a list for in-place editing
messages = list(queue.messages)
# handle by message type
if isinstance(message_obj, MessageSet):
messages.pop(index)
# insert the new message set at the original position
for i, single_message in enumerate(message_obj.messages):
messages.insert(index + i, single_message)
# print(f"\033[1;34m[调试]\033[0m 添加消息组中的第{i+1}条消息: {single_message}")
else:
# 直接替换原消息
messages[index] = message_obj
# print(f"\033[1;34m[调试]\033[0m 已更新消息: {message_obj}")
# 重建队列
queue.messages.clear()
for msg in messages:
queue.messages.append(msg)
return True
class MessageSendControl:
"""消息发送控制器"""
def __init__(self):
self.typing_speed = (0.1, 0.3) # typing time per character (seconds)
self.message_interval = (0.5, 1) # interval between consecutive messages (seconds)
self.max_retry = 3 # maximum retry count
self.send_temp_container = SendTempContainer()
self._running = True
self._paused = False
self._current_bot = None
self.storage = MessageStorage() # storage instance
try:
message_visualizer.start()
except(NameError):
pass
async def process_group_messages(self, group_id: int):
queue = self.send_temp_container.get_queue(group_id)
if queue.has_messages():
message = queue.peek_next()
# message-handling logic
if isinstance(message, Message_Thinking):
message.update_thinking_time()
thinking_time = message.thinking_time
if message.interupt:
print("\033[1;34m[Debug]\033[0m Thinking decided not to reply; removing")
queue.get_earliest_message()
return
elif thinking_time < 90: # keep waiting while under 90 seconds of thinking
if int(thinking_time) % 15 == 0:
print(f"\033[1;34m[Debug]\033[0m Message still thinking, {thinking_time:.1f}s so far")
return
else:
print("\033[1;34m[Debug]\033[0m Thinking message timed out; removing")
queue.get_earliest_message() # remove the timed-out thinking message
return
elif isinstance(message, Message):
message = queue.get_earliest_message()
if message and message.processed_plain_text:
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
cost_time = round(time.time(), 2) - message.time
if cost_time > 40:
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
cur_time = time.time()
await self._current_bot.send_group_msg(
group_id=group_id,
message=str(message.processed_plain_text),
auto_escape=False
)
cost_time = round(time.time(), 2) - cur_time
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}")
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
if message.is_emoji:
message.processed_plain_text = "[表情包]"
await self.storage.store_message(message, None)
else:
await self.storage.store_message(message, None)
queue.update_send_time()
if queue.has_messages():
await asyncio.sleep(
random.uniform(
self.message_interval[0],
self.message_interval[1]
)
)
async def start_processor(self, bot: Bot):
"""启动消息处理器"""
self._current_bot = bot
while self._running:
await asyncio.sleep(1.5)
tasks = []
for group_id in self.send_temp_container.get_all_groups():
tasks.append(self.process_group_messages(group_id))
# process all groups' messages in parallel
await asyncio.gather(*tasks)
try:
message_visualizer.update_content(self.send_temp_container)
except(NameError):
pass
def set_typing_speed(self, min_speed: float, max_speed: float):
"""设置打字速度范围"""
self.typing_speed = (min_speed, max_speed)
# create the global instance
message_sender_control = MessageSendControl()

View File

@@ -1,271 +0,0 @@
from typing import List, Optional, Dict
from .message import Message
import time
from collections import deque
from datetime import datetime, timedelta
import os
import json
import asyncio
class MessageStream:
"""单个群组的消息流容器"""
def __init__(self, group_id: int, max_size: int = 1000):
self.group_id = group_id
self.messages = deque(maxlen=max_size)
self.max_size = max_size
self.last_save_time = time.time()
# make sure the log directory exists
self.log_dir = os.path.join("log", str(self.group_id))
os.makedirs(self.log_dir, exist_ok=True)
# start the auto-save task
asyncio.create_task(self._auto_save())
async def _auto_save(self):
"""每30秒自动保存一次消息记录"""
while True:
await asyncio.sleep(30) # wait 30 seconds
await self.save_to_log()
async def save_to_log(self):
"""将消息保存到日志文件"""
try:
current_time = time.time()
# only save when there are new messages
if not self.messages or self.last_save_time == current_time:
return
# build the log file name (based on the current date)
date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")
# collect the new messages to save
new_messages = [
msg for msg in self.messages
if msg.time > self.last_save_time
]
if not new_messages:
return
# convert messages to a serializable format
message_logs = []
for msg in new_messages:
message_logs.append({
"time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
"user_id": msg.user_id,
"user_nickname": msg.user_nickname,
"user_cardname": msg.user_cardname,
"message_id": msg.message_id,
"raw_message": msg.raw_message,
"processed_text": msg.processed_plain_text
})
# append to the log file
with open(log_file, "a", encoding="utf-8") as f:
for log in message_logs:
f.write(json.dumps(log, ensure_ascii=False) + "\n")
self.last_save_time = current_time
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}")
def add_message(self, message: Message) -> None:
"""按时间顺序添加新消息到队列
使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。
Args:
message: Message对象要添加的新消息
"""
# empty queue, or the message belongs at the end
if (not self.messages or
message.time >= self.messages[-1].time):
self.messages.append(message)
return
# the message belongs at the front
if message.time <= self.messages[0].time:
self.messages.appendleft(message)
return
# binary-search the existing queue for the insertion point
left, right = 0, len(self.messages) - 1
while left <= right:
mid = (left + right) // 2
if self.messages[mid].time < message.time:
left = mid + 1
else:
right = mid - 1
temp = list(self.messages)
temp.insert(left, message)
# drop overflow beyond the maximum length
if len(temp) > self.max_size:
temp = temp[-self.max_size:]
# rebuild the deque
self.messages = deque(temp, maxlen=self.max_size)
async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
"""从数据库中获取最近的消息记录
Args:
count: 需要获取的消息数量
Returns:
List[Message]: 最近的消息列表
"""
try:
from ...common.database import Database
db = Database.get_instance()
# query the most recent messages from the database
recent_messages = list(db.db.messages.find(
{"group_id": self.group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# # "user_cardname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(count))
if not recent_messages:
return []
# convert to Message objects
from .message import Message
messages = []
for msg_data in recent_messages:
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
return list(reversed(messages)) # return the messages in ascending time order
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}")
return []
def get_recent_messages(self, count: int = 10) -> List[Message]:
"""获取最近的n条消息从内存队列"""
print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录")
return list(self.messages)[-count:]
def get_messages_in_timerange(self,
start_time: Optional[float] = None,
end_time: Optional[float] = None) -> List[Message]:
"""获取时间范围内的消息"""
if start_time is None:
start_time = time.time() - 3600
if end_time is None:
end_time = time.time()
return [
msg for msg in self.messages
if start_time <= msg.time <= end_time
]
def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
"""获取特定用户的最近消息"""
user_messages = [msg for msg in self.messages if msg.user_id == user_id]
return user_messages[-count:]
def clear_old_messages(self, hours: int = 24) -> None:
"""清理旧消息"""
cutoff_time = time.time() - (hours * 3600)
self.messages = deque(
[msg for msg in self.messages if msg.time > cutoff_time],
maxlen=self.max_size
)
class MessageStreamContainer:
"""管理所有群组的消息流容器"""
def __init__(self, max_size: int = 1000):
self.streams: Dict[int, MessageStream] = {}
self.max_size = max_size
async def save_all_logs(self):
"""保存所有群组的消息日志"""
for stream in self.streams.values():
await stream.save_to_log()
def add_message(self, message: Message) -> None:
"""添加消息到对应群组的消息流"""
if not message.group_id:
return
if message.group_id not in self.streams:
self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)
self.streams[message.group_id].add_message(message)
def get_stream(self, group_id: int) -> Optional[MessageStream]:
"""获取特定群组的消息流"""
return self.streams.get(group_id)
def get_all_streams(self) -> Dict[int, MessageStream]:
"""获取所有群组的消息流"""
return self.streams
def clear_old_messages(self, hours: int = 24) -> None:
"""清理所有群组的旧消息"""
for stream in self.streams.values():
stream.clear_old_messages(hours)
def get_group_stats(self, group_id: int) -> Dict:
"""获取群组的消息统计信息"""
stream = self.streams.get(group_id)
if not stream:
return {
"total_messages": 0,
"unique_users": 0,
"active_hours": [],
"most_active_user": None
}
messages = stream.messages
user_counts = {}
hour_counts = {}
for msg in messages:
user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
hour = datetime.fromtimestamp(msg.time).hour
hour_counts[hour] = hour_counts.get(hour, 0) + 1
most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
active_hours = sorted(
hour_counts.items(),
key=lambda x: x[1],
reverse=True
)[:5]
return {
"total_messages": len(messages),
"unique_users": len(user_counts),
"active_hours": active_hours,
"most_active_user": most_active_user
}
# create the global instance
message_stream_container = MessageStreamContainer()

View File

@@ -1,138 +0,0 @@
import subprocess
import threading
import queue
import os
import time
from typing import Dict
from .message import Message_Thinking
class MessageVisualizer:
def __init__(self):
self.process = None
self.message_queue = queue.Queue()
self.is_running = False
self.content_file = "message_queue_content.txt"
def start(self):
if self.process is None:
# create the batch file used for the display window
with open("message_queue_window.bat", "w", encoding="utf-8") as f:
f.write('@echo off\n')
f.write('chcp 65001\n') # switch the console to UTF-8
f.write('title Message Queue Visualizer\n')
f.write('echo Waiting for message queue updates...\n')
f.write(':loop\n')
f.write('if exist "queue_update.txt" (\n')
f.write(' type "queue_update.txt" > "message_queue_content.txt"\n')
f.write(' del "queue_update.txt"\n')
f.write(' cls\n')
f.write(' type "message_queue_content.txt"\n')
f.write(')\n')
f.write('timeout /t 1 /nobreak >nul\n')
f.write('goto loop\n')
# clear the content file
with open(self.content_file, "w", encoding="utf-8") as f:
f.write("")
# launch the new window
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
self.process = subprocess.Popen(
['cmd', '/c', 'start', 'message_queue_window.bat'],
shell=True,
startupinfo=startupinfo
)
self.is_running = True
# start the worker thread
threading.Thread(target=self._process_messages, daemon=True).start()
def _process_messages(self):
while self.is_running:
try:
# fetch a new message
text = self.message_queue.get(timeout=1)
# write the update file
with open("queue_update.txt", "w", encoding="utf-8") as f:
f.write(text)
except queue.Empty:
continue
except Exception as e:
print(f"处理队列可视化内容时出错: {e}")
def update_content(self, send_temp_container):
"""更新显示内容"""
if not self.is_running:
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
display_text = f"Message Queue Status - {current_time}\n"
display_text += "=" * 50 + "\n\n"
# walk every group's queue
for group_id, queue in send_temp_container.temp_queues.items():
display_text += f"\n{'='*20} Group: {queue.group_id} {'='*20}\n"
display_text += f"Queue length: {len(queue.messages)}\n"
display_text += f"Last send time: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
display_text += "\nQueue contents:\n"
# render the messages in the queue
if not queue.messages:
display_text += " [empty queue]\n"
else:
for i, msg in enumerate(queue.messages):
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
display_text += f"\n--- Message {i+1} ---\n"
if isinstance(msg, Message_Thinking):
display_text += f"Type: \033[1;33mthinking message\033[0m\n"
display_text += f"Time: {msg_time}\n"
display_text += f"Message ID: {msg.message_id}\n"
display_text += f"Group: {msg.group_id}\n"
display_text += f"User: {msg.user_nickname}({msg.user_id})\n"
display_text += f"Content: {msg.thinking_text}\n"
display_text += f"Thinking time: {int(msg.thinking_time)}\n"
else:
display_text += f"Type: normal message\n"
display_text += f"Time: {msg_time}\n"
display_text += f"Message ID: {msg.message_id}\n"
display_text += f"Group: {msg.group_id}\n"
display_text += f"User: {msg.user_nickname}({msg.user_id})\n"
if hasattr(msg, 'is_emoji') and msg.is_emoji:
display_text += f"Content: [sticker message]\n"
else:
# show both the raw and the processed message
display_text += f"Raw content: {msg.raw_message[:50]}...\n"
display_text += f"Processed content: {msg.processed_plain_text[:50]}...\n"
if msg.reply_message:
display_text += f"Replying to: {str(msg.reply_message)[:50]}...\n"
display_text += f"\n{'-' * 50}\n"
# append overall statistics
display_text += "\nOverall stats:\n"
display_text += f"Active groups: {len(send_temp_container.temp_queues)}\n"
total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
display_text += f"Total messages: {total_messages}\n"
thinking_messages = sum(
sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
for q in send_temp_container.temp_queues.values()
)
display_text += f"Thinking messages: {thinking_messages}\n"
self.message_queue.put(display_text)
def stop(self):
self.is_running = False
if self.process:
self.process.terminate()
self.process = None
# clean up the files
for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
if os.path.exists(file):
os.remove(file)
# create the global singleton
message_visualizer = MessageVisualizer()

View File

@@ -95,7 +95,11 @@ class ResponseGenerator:
# return None
# generate the reply
content, reasoning_content = await model.generate_response(prompt)
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception as e:
print(f"生成回复时出错: {e}")
return None
# save to the database
self._save_to_db(
@@ -138,9 +142,12 @@ class ResponseGenerator:
内容:{content}
输出:
'''
content, _ = await self.model_v3.generate_response(prompt)
return [content.strip()] if content else ["neutral"]
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
return [content]
else:
return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")

View File

@@ -52,12 +52,16 @@ class Message_Sender:
await asyncio.sleep(typing_time)
# send the message
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
try:
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer:

View File

@@ -36,7 +36,9 @@ class PromptBuilder:
memory_prompt = ''
start_time = time.time() # record the start time
topic = topic_identifier.identify_topic_jieba(message_txt)
# topic = await topic_identifier.identify_topic_llm(message_txt)
topic = topic_identifier.identify_topic_snownlp(message_txt)
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
all_first_layer_items = [] # 存储所有第一层记忆
@@ -64,15 +66,7 @@ class PromptBuilder:
if overlap:
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
overlapping_second_layer.update(overlap)
# merge all the needed memories
# if all_first_layer_items:
# print(f"\033[1;32m[Prefrontal]\033[0m Merging needed memories 1: {all_first_layer_items}")
# if overlapping_second_layer:
# print(f"\033[1;32m[Prefrontal]\033[0m Merging needed memories 2: {list(overlapping_second_layer)}")
# deduplicate with a set
# randomly pick 2 memories from each source (when available)
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []

View File

@@ -15,16 +15,6 @@ class TopicIdentifier:
self.llm_client = LLM_request(model=global_config.llm_topic_extract)
self.select=global_config.topic_extract
def identify_topic(self):
if self.select=='jieba':
return self.identify_topic_jieba
elif self.select=='snownlp':
return self.identify_topic_snownlp
elif self.select=='llm':
return self.identify_topic_llm
else:
return self.identify_topic_snownlp
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
@@ -48,56 +38,10 @@ class TopicIdentifier:
# parse the topic string into a list
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[Topic ID]\033[0m Topics: {topic_list}")
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""Identify topics with jieba"""
words = jieba.lcut(text)
# remove stop words and punctuation
stop_words = {
# (the set's single-character stop words were lost when this page was extracted)
'因为', '所以', '如果', '虽然', '一个', '我们', '你们',
'他们',
'什么', '怎么', '为什么', '怎样', '如何', '什么样', '这样', '那样', '这么',
'那么', '多少', '哪里', '哪儿', '什么时候', '何时', '为何', '怎么办',
'怎么样', '这些', '那些', '一些', '一点', '一下', '一直', '一定', '一般', '一样',
'一会儿', '一边', '一起',
# more measure words (single-character entries lost in extraction)
# more prepositions (single-character entries lost in extraction)
'按照', '比如', '除了', '对于',
'根据', '关于', '经过', '通过',
'为了', '围绕', '由于', '沿', '沿着',
'依照', '因为', '自从'
}
# filter out stop words and punctuation, keeping mainly nouns and verbs
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
# (the full-width punctuation entries in this set were lost in extraction)
'·', '~', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷'
}:
filtered_words.append(word)
# count word frequencies
word_freq = {}
for word in filtered_words:
word_freq[word] = word_freq.get(word, 0) + 1
# sort by frequency and keep the top 3
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
top_words = [word for word, freq in sorted_words[:3]]
return top_words if top_words else None
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
"""使用 SnowNLP 进行主题识别
@@ -113,7 +57,7 @@ class TopicIdentifier:
try:
s = SnowNLP(text)
# extract the top keywords as the topics
keywords = s.keywords(3)
keywords = s.keywords(5)
return keywords if keywords else None
except Exception as e:
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")

View File

@@ -75,13 +75,11 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2)
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
@@ -90,27 +88,37 @@ def calculate_information_content(text):
return entropy
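To sanity-check the entropy computed above, a small self-contained version with hand-verifiable inputs (illustrative, not part of the diff):

from collections import Counter
import math

def entropy(text: str) -> float:
    counts = Counter(text)
    total = len(text)
    # H = -sum(p * log2(p)) over the character distribution, as in the loop above
    return sum(-(c / total) * math.log2(c / total) for c in counts.values())

print(entropy("aaaa"))  # 0.0: one repeated character carries no information
print(entropy("abab"))  # 1.0: two equally likely characters -> 1 bit per character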
def get_cloest_chat_from_db(db, length: int, timestamp: str):
# fetch the chat records closest in time to the given timestamp
"""Fetch the chat records closest to the given timestamp, tracking how often they have been read"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # get the group id
# fetch length messages after that timestamp from the same group
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
except:
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # include sender and time info
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length))
# update each message's memorized counter
for record in chat_records:
# check the record's current memorized count
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
# print("message already read 3 times, skipping")
return ''
# bump the memorized count
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
return [] # if no record was found, return an empty list
print("Message already read 3 times, skipping")
return ''
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录

View File

@@ -7,6 +7,7 @@ from ...common.database import Database
import zlib # used for CRC32
import base64
from nonebot import get_driver
from loguru import logger
driver = get_driver()
config = driver.config
@@ -213,11 +214,11 @@ def storage_image(image_data: bytes) -> bytes:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def compress_base64_image_by_scale(base64_data: str, scale: float = 0.5) -> str:
"""Compress a base64-encoded image by a fixed ratio
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""Compress a base64-encoded image down to a target size
Args:
base64_data: base64-encoded image data
scale: compression ratio, a float between 0 and 1
target_size: target file size in bytes (default 0.8 MB)
Returns:
str: the compressed image, base64-encoded
"""
@@ -225,34 +226,64 @@ def compress_base64_image_by_scale(base64_data: str, scale: float = 0.5) -> str:
# decode the base64 into raw bytes
image_data = base64.b64decode(base64_data)
# if it is already below the target size, return the original
if len(image_data) <= target_size:
return base64_data
# load the bytes into an image object
img = Image.open(io.BytesIO(image_data))
# if the image is animated, return the original as-is
if getattr(img, 'is_animated', False):
return base64_data
# record the original dimensions
original_width, original_height = img.size
# compute the scale factor
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# compute the new dimensions
new_width = int(img.width * scale)
new_height = int(img.height * scale)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# resize the image
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# create an in-memory buffer
output_buffer = io.BytesIO()
# convert to RGB mode (dropping the alpha channel)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# for GIFs, process every frame
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
frames.append(new_frame)
# save to the buffer
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# handle static images
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# save to the buffer, keeping the original format
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# save the compressed image
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# get the compressed bytes and re-encode as base64
compressed_data = output_buffer.getvalue()
logger.success(f"Compressed image: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"Size before: {len(image_data)/1024:.1f}KB, after: {len(compressed_data)/1024:.1f}KB")
# convert back to base64
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
logger.error(f"压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
logger.error(traceback.format_exc())
return base64_data
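The scale factor above relies on file size growing roughly with pixel count, i.e. with the square of each linear dimension, so scaling both sides by sqrt(target/current) lands near the target; a quick illustrative check with made-up numbers:

image_bytes = 3 * 1024 * 1024            # a 3 MB input (illustrative)
target = int(0.8 * 1024 * 1024)          # the 0.8 MB default from the diff
scale = min(1.0, (target / image_bytes) ** 0.5)
print(f"{scale:.3f}")                                  # ~0.516: roughly halve each side
print(f"{scale * scale * image_bytes / 1024:.0f} KB")  # ~819 KB, right at the target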

View File

@@ -9,7 +9,7 @@ class WillingManager:
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(3)
await asyncio.sleep(5)
for group_id in self.group_reply_willing:
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
@@ -39,11 +39,11 @@ class WillingManager:
if interested_rate > 0.65:
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.5
current_willing += interested_rate-0.6
self.group_reply_willing[group_id] = min(current_willing, 3.0)
reply_probability = max((current_willing - 0.5) * 2, 0)
reply_probability = max((current_willing - 0.55) * 1.9, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
@@ -52,8 +52,8 @@ class WillingManager:
reply_probability = reply_probability / 3.5
reply_probability = min(reply_probability, 1)
if reply_probability < 0.1:
reply_probability = 0.1
if reply_probability < 0:
reply_probability = 0
return reply_probability
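To see what the retuned constants do, a hand computation of the new probability curve (omitting the talk_frequency_down damping branch; values illustrative):

def reply_probability(willing: float, allowed: bool = True) -> float:
    # this commit: threshold 0.55, slope 1.9, floor 0 (the old floor was 0.1)
    p = max((willing - 0.55) * 1.9, 0)
    if not allowed:
        return 0.0
    return min(p, 1.0)

print(reply_probability(0.5))   # 0.0: below threshold (the old code forced 0.1 here)
print(reply_probability(1.0))   # 0.855
print(reply_probability(1.5))   # 1.0: capped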
def change_reply_willing_sent(self, group_id: int):
@@ -65,7 +65,7 @@ class WillingManager:
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""

View File

@@ -79,7 +79,7 @@ class KnowledgeLibrary:
content = f.read()
# split the content into fixed-length segments
segments = [content[i:i+400] for i in range(0, len(content), 400)]
segments = [content[i:i+600] for i in range(0, len(content), 600)]
# process each segment
for segment in segments:

View File

@@ -22,63 +22,6 @@ from src.common.database import Database # use the correct import syntax
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class LLMModel:
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# build the request body
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# send the request to the full chat/completions endpoint
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # exponential backoff
print(f"Hit the rate limit (429); waiting {wait_time}s before retrying...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # check other response statuses
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"Request failed; waiting {wait_time}s before retrying... Error: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"Request failed: {str(e)}", ""
return "Max retries reached; the request still failed", ""
class Memory_graph:
def __init__(self):
@@ -232,19 +175,10 @@ def main():
)
memory_graph = Memory_graph()
# create the LLM model instance
memory_graph.load_graph_from_db()
# show the two different visualization modes
print("\nGraph colored by connection count:")
# visualize_graph(memory_graph, color_by_memory=False)
visualize_graph_lite(memory_graph, color_by_memory=False)
print("\nGraph colored by memory count:")
# visualize_graph(memory_graph, color_by_memory=True)
visualize_graph_lite(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
# show the optimized graph only once
visualize_graph_lite(memory_graph)
while True:
query = input("Enter a new query concept (type '退出' to quit): ")
@@ -327,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
node_size=200,
font_size=10,
font_family='SimHei',
font_weight='bold')
@@ -353,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 2 or degree <= 2:
if memory_count < 5 or degree < 2: # changed to strictly less-than rather than <=
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
@@ -366,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# save the graph locally
nx.write_gml(H, "memory_graph.gml") # save in GML format
# color nodes by connection count or memory count
# compute node sizes and colors
node_colors = []
nodes = list(H.nodes()) # the nodes actually present in the graph
node_sizes = []
nodes = list(H.nodes())
if color_by_memory:
# count each node's memories
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
# find the max memory count and max degree, for normalization
max_memories = 1
max_degree = 1
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# compute each node's size and color
for node in nodes:
# node size (based on memory count)
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# use a power curve so the variation stands out
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio ** 2) # squaring exaggerates the differences
node_sizes.append(size)
for count in memory_counts:
# alternative color scheme: red = many memories, blue = few
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # gradient from blue to red
else:
color = (0, 0, 1) # blue when there are no memories
node_colors.append(color)
else:
# fall back to the original connection-count coloring scheme
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# compute the node color (based on degree)
degree = H.degree(node)
# the red component grows with degree
red = min(1.0, degree / max_degree)
# the blue component fades as degree grows
blue = 1.0 - red
color = (red, 0, blue)
node_colors.append(color)
# draw the graph
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
pos = nx.spring_layout(H, k=1.5, iterations=50) # a larger k spreads the nodes out more
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold')
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.7)
title = 'Memory graph visualization - ' + ('colored by memory count' if color_by_memory else 'colored by connection count')
title = 'Memory graph visualization - node size = memory count, color = connection count'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()

View File

@@ -9,15 +9,26 @@ import random
import time
from ..chat.config import global_config
from ...common.database import Database # use the correct import syntax
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
from ..models.utils_model import LLM_request
import math
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # use a networkx graph structure
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
# if the edge already exists, bump its strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
else:
# for a new edge, initialize strength to 1
self.G.add_edge(concept1, concept2, strength=1)
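Since connect_dot now accumulates a strength on each repeated connection instead of just re-adding the edge, co-occurrence counts survive in the graph; a standalone illustrative check of the same pattern:

import networkx as nx

G = nx.Graph()

def connect(a, b):
    # same pattern as connect_dot: bump strength if present, else start at 1
    if G.has_edge(a, b):
        G[a][b]['strength'] = G[a][b].get('strength', 1) + 1
    else:
        G.add_edge(a, b, strength=1)

connect('cats', 'memes')
connect('cats', 'memes')
connect('cats', 'dogs')
print(G['cats']['memes']['strength'])  # 2: this pair co-occurred twice
print(G['cats']['dogs']['strength'])   # 1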
def add_dot(self, concept, memory):
if concept in self.G:
@@ -38,9 +49,7 @@ class Memory_graph:
if concept in self.G:
# fetch the node's data from the graph
node_data = self.G.nodes[concept]
# print(node_data)
# build a new Memory_dot object
return concept,node_data
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
@@ -52,7 +61,6 @@ class Memory_graph:
# get the neighboring nodes
neighbors = list(self.G.neighbors(topic))
# print(f"first layer: {topic}")
# get the current node's memory items
node_data = self.get_dot(topic)
@@ -69,7 +77,6 @@ class Memory_graph:
if depth >= 2:
# get the neighbors' memory items
for neighbor in neighbors:
# print(f"second layer: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
@@ -87,87 +94,59 @@ class Memory_graph:
# return the Memory_dot objects for all nodes
return [self.get_dot(node) for node in self.G.nodes()]
def save_graph_to_db(self):
# save the nodes
for node in self.G.nodes(data=True):
concept = node[0]
memory_items = node[1].get('memory_items', [])
def forget_topic(self, topic):
"""Randomly delete one memory from the given topic; remove the node if it has no memories left"""
if topic not in self.G:
return None
# check whether a node with the same name already exists
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# if it exists, merge the memory_items and deduplicate
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# merge and deduplicate
all_items = list(set(existing_items + memory_items))
# update the node
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# get the topic node's data
node_data = self.G.nodes[topic]
# save the edges
for edge in self.G.edges():
source, target = edge
# if the node has memory_items
if 'memory_items' in node_data:
memory_items = node_data['memory_items']
# check whether the same edge already exists
existing_edge = self.db.db.graph_data.edges.find_one({
'source': source,
'target': target
})
if existing_edge:
# if it exists, increment the num attribute
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# otherwise create a new edge
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# clear the current graph
self.G.clear()
# load the nodes
nodes = self.db.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
# ensure memory_items is a list
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# load the edges
edges = self.db.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
# if there are memory items that can be deleted
if memory_items:
# randomly pick one memory item to delete
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# update the node's memory items
if memory_items:
self.G.nodes[topic]['memory_items'] = memory_items
else:
# if no memory items remain, delete the whole node
self.G.remove_node(topic)
return removed_item
return None
# the hippocampus
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5)
def calculate_node_hash(self, concept, memory_items):
"""计算节点的特征值"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
sorted_items = sorted(memory_items)
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""计算边的特征值"""
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
@@ -175,82 +154,340 @@ class Hippocampus:
# short term: 1 h, mid term: 4 h, long term: 24 h
for _ in range(time_frequency.get('near')): # sample the 'near' window
random_time = current_timestamp - random.randint(1, 3600) # random time within the last hour
# print(f"random 'near' timestamp: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # sample the 'mid' window
random_time = current_timestamp - random.randint(3600, 3600*4) # random time 1-4 hours ago
# print(f"random 'mid' timestamp: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # sample the 'far' window
random_time = current_timestamp - random.randint(3600*4, 3600*24) # random time 4-24 hours ago
# print(f"random 'far' timestamp: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
return [text for text in chat_text if text]
async def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = await self.llm_model.generate_response(topic_prompt)
# check whether topic_response is a tuple
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # assume the first element is the string we need
else:
topics = topic_response.split(",")
compressed_memory = set()
async def memory_compress(self, input_text, compress_rate=0.1):
print(input_text)
# get the topics
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
# revised topic-handling logic
print(f"Topics: {topics_response[0]}")
topics = [topic.strip() for topic in topics_response[0].replace("、", ",").replace(",", ",").replace(" ", ",").split(",") if topic.strip()]
print(f"Topics: {topics}")
# create a request task for every topic
tasks = []
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # store (topic, memory) tuples
topic_what_prompt = self.topic_what(input_text, topic)
# create the async task
task = self.llm_model_summary.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# wait for all the tasks to finish
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
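memory_compress above creates one summary request per topic before awaiting any of them; one common way to run such a batch concurrently is asyncio.gather, sketched here with a stand-in for the LLM call (illustrative, not this module's API):

import asyncio

async def summarize(topic: str) -> str:
    await asyncio.sleep(1)  # stands in for one LLM round-trip
    return f"summary of {topic}"

async def compress(topics):
    # launch all requests together: total wall time ~1s instead of ~3s
    results = await asyncio.gather(*(summarize(t) for t in topics))
    return set(zip(topics, results))

print(asyncio.run(compress(["weather", "exams", "memes"])))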
async def build_memory(self,chat_size=12):
# sampling frequency for recent messages
time_frequency = {'near':1,'mid':2,'far':2}
def calculate_topic_num(self,text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count('\n')*compress_rate
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content)/2)
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
return topic_num
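To make the blend above concrete, a hand computation with made-up inputs:

# a sampled chat chunk with 40 lines and entropy of about 7 bits
compress_rate = 0.1
topic_by_length = 40 * compress_rate                                # 4.0
topic_by_information_content = max(1, min(5, int((7.0 - 3) * 2)))   # 8, capped at 5
topic_num = int((topic_by_length + topic_by_information_content) / 2)
print(topic_num)  # int(4.5) -> 4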
async def operation_build_memory(self,chat_size=20):
# sampling frequency for recent messages
time_frequency = {'near':2,'mid':4,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
for i, input_text in enumerate(memory_sample, 1):
# progress-bar visualization
all_topics = []
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
if input_text:
# generate the compressed memories
first_memory = set()
first_memory = await self.memory_compress(input_text, 2.5)
# add the memories into the graph
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
# generate the compressed memories, as (topic, memory) tuples
compressed_memory = set()
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33mCompressed memory count\033[0m: {len(compressed_memory)}")
# add the memories into the graph
for topic, memory in compressed_memory:
print(f"\033[1;32mAdding node\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # collect all topics
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_to_db(self):
"""检查并同步内存中的图结构与数据库"""
# fetch all nodes from the database and all nodes in memory
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# index the database nodes by concept for quick lookup
db_nodes_dict = {node['concept']: node for node in db_nodes}
# check and update nodes
for concept, data in memory_nodes:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# compute the in-memory node's hash
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# the node is missing from the database: insert it
node_data = {
'concept': concept,
'memory_items': memory_items,
'hash': memory_hash
}
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
# get the database node's hash
db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None)
# if the hashes differ, update the node
if db_hash != memory_hash:
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
'hash': memory_hash
}}
)
# delete nodes that exist only in the database
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# handle the edge data
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# build a dict of edge hashes
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
db_edge_dict[(edge['source'], edge['target'])] = {
'hash': edge_hash,
'strength': edge.get('strength', 1)
}
# check and update edges
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1)
if edge_key not in db_edge_dict:
# add the new edge
edge_data = {
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash
}
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
else:
# check whether the edge's hash has changed
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
'strength': strength
}}
)
# delete edges that exist only in the database
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
self.memory_graph.db.db.graph_data.edges.delete_one({
'source': source,
'target': target
})
def sync_memory_from_db(self):
"""Sync data from the database into the in-memory graph"""
# clear the current graph
self.memory_graph.G.clear()
# load all nodes from the database
nodes = self.memory_graph.db.db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# ensure memory_items is a list
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
print(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
print("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(merged_text, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例,默认为0.1(10%)
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
def find_topic_llm(self,text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号“,”隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt
def topic_what(self,text, topic):
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
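# A minimal usage sketch of the two prompt builders above; the sample text is
# hypothetical, while real callers pass chat records pulled from MongoDB.
def _demo_prompts():
    sample = "昨天下午小明说他周末要去爬山,顺便拍点照片"
    print(find_topic(sample, 2))      # prompt asking for 2 topics
    print(topic_what(sample, "爬山"))  # prompt summarizing the "爬山" concept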
from nonebot import get_driver
driver = get_driver()
@@ -268,10 +505,10 @@ Database.initialize(
)
#创建记忆图
memory_graph = Memory_graph()
#加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
#创建海马体
hippocampus = Hippocampus(memory_graph)
#从数据库加载记忆图
hippocampus.sync_memory_from_db()
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")

View File

@@ -1,463 +0,0 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: float):
"""从数据库中获取最接近指定时间戳的聊天记录"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
return chat_text
return ''
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
if record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
except:
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return '' # 如果没有找到记录,返回空字符串
def save_graph_to_db(self):
# 保存节点
for node in self.G.nodes(data=True):
concept = node[0]
memory_items = node[1].get('memory_items', [])
# 查找是否存在同名节点
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# 如果存在,合并memory_items并去重
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# 合并并去重
all_items = list(set(existing_items + memory_items))
# 更新节点
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
source, target = edge
# 查找是否存在同样的边
existing_edge = self.db.db.graph_data.edges.find_one({
'source': source,
'target': target
})
if existing_edge:
# 如果存在,增加num属性
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# 如果不存在,创建新边
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 采样 near 次
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 采样 mid 次
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 采样 far 次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
def build_memory(self,chat_size=12):
#最近消息获取频率
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
#加载进度可视化
for i, input_text in enumerate(memory_sample, 1):
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# print(f"第{i}条消息: {input_text}")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
else:
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 1 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
start_time = time.time()
# 创建记忆图
memory_graph = Memory_graph()
# 加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
hippocampus.build_memory(chat_size=25)
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# 交互式查询
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
while True:
query = input("请输入问题:")
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,786 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
from dotenv import load_dotenv
import pymongo
from loguru import logger
from pathlib import Path
from snownlp import SnowNLP
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
# 获取当前文件的目录
current_dir = Path(__file__).resolve().parent
# 获取项目根目录(上三层目录)
project_root = current_dir.parent.parent.parent
# env.dev文件路径
env_path = project_root / ".env.dev"
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
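# A worked example of the entropy estimate above: in the 4-character string
# "aabb" each character appears with probability 0.5, so the entropy is
# -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit; richer chat text scores higher,
# which is what calculate_topic_num later exploits.
def _demo_information_content():
    assert abs(calculate_information_content("aabb") - 1.0) < 1e-9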
def get_cloest_chat_from_db(db, length: int, timestamp: float):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in chat_records:
# 检查当前记录的memorized值
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
print(f"消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
print(f"消息已读取3次跳过")
return ''
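# Note on the guard above: every sampled message has its 'memorized' counter
# incremented, and once a message has been read into memory more than 3 times
# it is skipped, so the same chat log cannot dominate repeated build passes.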
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
# 短期:4h以内 中期:4h~24h 长期:24h~7天
for _ in range(time_frequency.get('near')): # 采样 near 次
random_time = current_timestamp - random.randint(1, 3600*4) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 采样 mid 次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 采样 far 次
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return [chat for chat in chat_text if chat]
def calculate_topic_num(self,text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count('\n')*compress_rate
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content)/2)
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
return topic_num
async def memory_compress(self, input_text, compress_rate=0.1):
print(input_text)
#获取topics
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num))
# 修改话题处理逻辑
topics = [topic.strip() for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
print(f"话题: {topics}")
# 创建所有话题的请求任务
tasks = []
for topic in topics:
topic_what_prompt = self.topic_what(input_text, topic)
# 创建异步任务
task = self.llm_model_small.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# 等待所有任务完成
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
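# Shape of the return value above (a sketch only; actual topics and summaries
# come from the two LLM calls): a set of (topic, summary) tuples such as
# {("爬山", "小明周末打算去爬山拍照"), ("周末", "周末大家讨论了出游计划")}.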
async def operation_build_memory(self, chat_size=12):
# 最近消息获取频率
time_frequency = {'near': 3, 'mid': 8, 'far': 5}
memory_sample = self.get_memory_sample(chat_size, time_frequency)
all_topics = [] # 用于存储所有话题
for i, input_text in enumerate(memory_sample, 1):
# 加载进度可视化
all_topics = []
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set()
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_from_db(self):
"""
从数据库同步数据到内存中的图结构
将清空当前内存中的图,并从数据库重新加载所有节点和边
"""
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
logger.success("从数据库同步记忆图谱完成")
def calculate_node_hash(self, concept, memory_items):
"""
计算节点的特征值
"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 将记忆项排序以确保相同内容生成相同的哈希值
sorted_items = sorted(memory_items)
# 组合概念和记忆项生成特征值
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""
计算边的特征值
"""
# 对源节点和目标节点排序以确保相同的边生成相同的哈希值
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def sync_memory_to_db(self):
"""
检查并同步内存中的图结构与数据库
使用特征值(哈希值)快速判断是否需要更新
"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node['concept']: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
logger.info(f"添加新节点: {concept}")
node_data = {
'concept': concept,
'memory_items': memory_items,
'hash': memory_hash
}
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None)
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
logger.info(f"更新节点内容: {concept}")
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
'hash': memory_hash
}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
logger.info(f"删除多余节点: {db_node['concept']}")
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
db_edge_dict[(edge['source'], edge['target'])] = {
'hash': edge_hash,
'strength': edge.get('strength', 1)
}
# 检查并更新边
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1)
if edge_key not in db_edge_dict:
# 添加新边
logger.info(f"添加新边: {source} - {target}")
edge_data = {
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash
}
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash:
logger.info(f"更新边: {source} - {target}")
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'hash': edge_hash, 'strength': strength}}
)
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
logger.info(f"删除多余边: {source} - {target}")
self.memory_graph.db.db.graph_data.edges.delete_one({
'source': source,
'target': target
})
logger.success("完成记忆图谱与数据库的差异同步")
def find_topic_llm(self,text, topic_num):
# prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号“,”隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt
def topic_what(self,text, topic):
# prompt = f'这是一段文字:{text}。我想知道这段文字里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt
def remove_node_from_db(self, topic):
"""
从数据库中删除指定节点及其相关的边
Args:
topic: 要删除的节点概念
"""
# 删除节点
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
# 删除所有涉及该节点的边
self.memory_graph.db.db.graph_data.edges.delete_many({
'$or': [
{'source': topic},
{'target': topic}
]
})
def forget_topic(self, topic):
"""
随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点
只在内存中的图上操作,不直接与数据库交互
Args:
topic: 要删除记忆的话题
Returns:
removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None
"""
if topic not in self.memory_graph.G:
return None
# 获取话题节点数据
node_data = self.memory_graph.G.nodes[topic]
# 如果节点存在memory_items
if 'memory_items' in node_data:
memory_items = node_data['memory_items']
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果有记忆项可以删除
if memory_items:
# 随机选择一个记忆项删除
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 更新节点的记忆项
if memory_items:
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
else:
# 如果没有记忆项了,删除整个节点
self.memory_graph.G.remove_node(topic)
return removed_item
return None
async def operation_forget_topic(self, percentage=0.1):
"""
随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘
Args:
percentage: 要检查的节点比例,默认为0.1(10%)
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
logger.info("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(merged_text, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例,默认为0.1(10%)
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数用于归一化节点大小
max_memories = 1
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
max_memories = max(max_memories, memory_count)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 400 + 2000 * (ratio ** 2) # 增大节点大小
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
if degree >= 30:
node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000)
else:
# 将1-30映射到0-1的范围
color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
# 使用蓝到红的渐变
red = min(0.9, color_ratio)
blue = max(0.0, 1.0 - color_ratio)
node_colors.append((red, 0, blue))
# 绘制图形
plt.figure(figsize=(16, 12)) # 减小图形尺寸
pos = nx.spring_layout(H,
k=1, # 调整节点间斥力
iterations=100, # 增加迭代次数
scale=1.5, # 减小布局尺寸
weight='strength') # 使用边的strength属性作为权重
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=12, # 保持增大的字体大小
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=1.5) # 统一的边宽度
title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变\n连接强度越大的节点距离越近'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
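# A quick check of the size curve above, assuming max_memories = 20: a node
# with 1 memory gets 400 + 2000 * (1/20) ** 2 = 405, while a node holding all
# 20 gets 2400, so visual size grows quadratically with memory count.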
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare['do_build_memory']:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare['do_query']:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
if first_layer:
print("\n直接相关的记忆:")
for item in first_layer:
print(f"- {item}")
if second_layer:
print("\n间接相关的记忆:")
for item in second_layer:
print(f"- {item}")
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -2,28 +2,23 @@ import os
import requests
from typing import Tuple, Union
import time
from nonebot import get_driver
import aiohttp
import asyncio
from loguru import logger
from src.plugins.chat.config import BotConfig, global_config
driver = get_driver()
config = driver.config
class LLMModel:
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
async def generate_response(self, prompt: str) -> Tuple[str, str]:
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
@@ -47,7 +42,60 @@ class LLMModel:
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15
async with aiohttp.ClientSession() as session:
for retry in range(max_retries):
try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
@@ -63,15 +111,15 @@ class LLMModel:
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""

View File

@@ -55,13 +55,24 @@ class LLM_request:
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
if response.status in [500, 503]:
logger.error(f"服务器错误: {response.status}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
return content, reasoning_content
return "没有返回结果", ""
@@ -117,6 +128,7 @@ class LLM_request:
base_wait_time = 15
current_image_base64 = image_base64
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
for retry in range(max_retries):
@@ -163,6 +175,61 @@ class LLM_request:
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15
async with aiohttp.ClientSession() as session:
for retry in range(max_retries):
try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""同步方法:根据输入的提示和图片生成模型的响应"""
headers = {
@@ -170,6 +237,8 @@ class LLM_request:
"Content-Type": "application/json"
}
image_base64=compress_base64_image_by_scale(image_base64)
# 构建请求体
data = {
"model": self.model_name,

View File

@@ -1,12 +1,12 @@
import datetime
import os
from typing import List, Dict
from typing import List, Dict, Union
from ...common.database import Database # 使用正确的导入语法
from src.plugins.chat.config import global_config
from nonebot import get_driver
from ..models.utils_model import LLM_request
from loguru import logger
import json
driver = get_driver()
config = driver.config
@@ -58,19 +58,20 @@ class ScheduleGenerator:
elif read_only == False:
print(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:"""+\
"""
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床""""
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用JSON格式返回日程表,仅返回内容,不要返回注释,时间采用24小时制,格式为{"时间": "活动","时间": "活动",...}"""
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
# print(self.schedule_text)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
else:
print(f"{date_str}的日程不存在。")
schedule_text = "忘了"
@@ -80,20 +81,15 @@ class ScheduleGenerator:
schedule_form = self._parse_schedule(schedule_text)
return schedule_text,schedule_form
def _parse_schedule(self, schedule_text: str) -> Dict[str, str]:
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
"""解析日程文本,转换为时间和活动的字典"""
schedule_dict = {}
# 按行分割日程文本
lines = schedule_text.strip().split('\n')
for line in lines:
# print(line)
if ',' in line:
# 假设格式为 "时间: 活动"
time_str, activity = line.split(',', 1)
# print(time_str)
# print(activity)
schedule_dict[time_str.strip()] = activity.strip()
return schedule_dict
try:
schedule_dict = json.loads(schedule_text)
return schedule_dict
except json.JSONDecodeError as e:
print(schedule_text)
print(f"解析日程失败: {str(e)}")
return False
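# Expected parser input above (the shape the new prompt requests), e.g.
# {"08:00": "起床", "09:00": "学习心理学", "23:00": "睡觉"}; json.loads turns
# it into the Dict[str, str] used by the current-task lookup, and any
# malformed reply returns False so print_schedule can trigger regeneration.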
def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间"""
@@ -108,6 +104,8 @@ class ScheduleGenerator:
min_diff = float('inf')
# 检查今天的日程
if not self.today_schedule:
return "摸鱼"
for time_str in self.today_schedule.keys():
diff = abs(self._time_diff(current_time, time_str))
if closest_time is None or diff < min_diff:
@@ -148,11 +146,14 @@ class ScheduleGenerator:
def print_schedule(self):
"""打印完整的日程安排"""
print("\n=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n")
if not self._parse_schedule(self.today_schedule_text):
print("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
print("\n=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n")
# def main():
# # 使用示例