Merge remote-tracking branch 'upstream/debug' into debug
34  .github/workflows/docker-image.yml  (vendored)
@@ -3,10 +3,11 @@ name: Docker Build and Push
on:
  push:
    branches:
      - main # triggered on pushes to main
      - main
      - debug # new: also trigger on the debug branch
    tags:
      - 'v*' # triggered on pushing tags starting with v (e.g. v1.0.0)
  workflow_dispatch: # allow manual triggering
      - 'v*'
  workflow_dispatch:

jobs:
  build-and-push:
@@ -24,15 +25,24 @@ jobs:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Determine Image Tags
        id: tags
        run: |
          if [[ "${{ github.ref }}" == refs/tags/* ]]; then
            echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
          elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
            echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
          elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
            echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
          fi

      - name: Build and Push Docker Image
        uses: docker/build-push-action@v5
        with:
          context: . # Docker build context path
          file: ./Dockerfile # path to the Dockerfile
          platforms: linux/amd64,linux/arm64 # includes arm support
          tags: |
            ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }}
            ${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
          push: true
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
          cache-to: type=inline
          context: .
          file: ./Dockerfile
          platforms: linux/amd64,linux/arm64
          tags: ${{ steps.tags.outputs.tags }}
          push: true
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max
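The new "Determine Image Tags" step maps each git ref to a comma-separated Docker Hub tag list via $GITHUB_OUTPUT. A minimal Python sketch of the same mapping (the "user" placeholder stands in for the DOCKERHUB_USERNAME secret; illustration only, not part of the workflow):

    def docker_tags(ref: str, ref_name: str, user: str = "user") -> str:
        # Mirrors the workflow's branching: v* tags and main also move :latest,
        # while debug publishes only to its own channel.
        if ref.startswith("refs/tags/"):
            return f"{user}/maimbot:{ref_name},{user}/maimbot:latest"
        if ref == "refs/heads/main":
            return f"{user}/maimbot:main,{user}/maimbot:latest"
        if ref == "refs/heads/debug":
            return f"{user}/maimbot:debug"
        return ""  # no tags -> nothing is pushed

    assert docker_tags("refs/tags/v1.0.0", "v1.0.0") == "user/maimbot:v1.0.0,user/maimbot:latest"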
10  README.md
@@ -42,22 +42,22 @@
## 🎯 Features

### 💬 Chat
- Keyword-triggered proactive replies: the topic of each message is identified, and 麦麦 speaks up when it matches a topic she has stored; currently buggy, so topics are only detected, not stored
- Keyword-triggered proactive replies: the topic of each message is identified, and 麦麦 speaks up when it matches a topic she has stored
- Replies when called by name: detecting "麦麦" triggers a reply (configurable)
- Uses the SiliconFlow API for reply generation, randomly picking among R1, V3, R1-distill and other models; official-API support is planned
- Supports custom configuration of multiple models and providers
- Dynamic prompt builder for more human-like replies
- Recognizes images, forwarded messages, and replied-to messages
- Typos and multi-part replies: 麦麦 can randomly introduce typos, split a reply across several messages, and reply to a specific message

### 😊 Stickers
- Sends a sticker matching the emotion of the message: unfinished but usable
- Automatically saves stickers from group members (unfinished, disabled for now); currently buggy
- Sends a sticker matching the emotion of the message
- Automatically saves stickers from group members

### 📅 Schedule
- 麦麦 automatically generates a daily schedule, making replies more human-like

### 🧠 Memory
- Summarizes and stores chat history for recall when needed; unfinished
- Summarizes and stores chat history for recall when needed; to be refined

### 📚 Knowledge base
- Embedding-based knowledge base; txt files dropped in manually are picked up automatically; finished but disabled for now
@@ -11,7 +11,7 @@ prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的

[message]
min_text_length = 2 # 麦麦 only replies to messages whose text is at least this long
max_context_size = 15 # amount of context 麦麦 receives; older messages beyond this are dropped
max_context_size = 15 # amount of preceding context 麦麦 receives
emoji_chance = 0.2 # probability that 麦麦 uses a sticker
ban_words = [
    # "403","张三"
@@ -31,6 +31,7 @@ model_r1_distill_probability = 0.1 # probability that 麦麦 answers with the R1-distill model

[memory]
build_memory_interval = 300 # memory build interval, in seconds
forget_memory_interval = 300 # memory forgetting interval, in seconds

[others]
enable_advance_output = true # whether to enable advanced output
@@ -98,7 +98,19 @@ async def monitor_relationships():
async def build_memory_task():
    """Periodically build memory"""
    print("\033[1;32m[Memory Build]\033[0m building memory...")
    await hippocampus.build_memory(chat_size=30)
    await hippocampus.operation_build_memory(chat_size=30)
    print("\033[1;32m[Memory Build]\033[0m memory build finished")

@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
    """Periodically forget memory"""
    print("\033[1;32m[Memory Forget]\033[0m forgetting memory...")
    await hippocampus.operation_forget_topic(percentage=0.1)
    print("\033[1;32m[Memory Forget]\033[0m memory forgetting finished")

@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="build_memory")
async def build_memory_task():
    """Periodically consolidate memory"""
    print("\033[1;32m[Memory Merge]\033[0m starting consolidation")
    await hippocampus.operation_merge_memory(percentage=0.1)
    print("\033[1;32m[Memory Merge]\033[0m memory consolidation finished")
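These decorators assume an APScheduler-style scheduler object (in NoneBot this typically comes from nonebot_plugin_apscheduler). A minimal standalone sketch of the same wiring, with a stand-in coroutine instead of the real hippocampus call:

    import asyncio
    from apscheduler.schedulers.asyncio import AsyncIOScheduler

    scheduler = AsyncIOScheduler()

    @scheduler.scheduled_job("interval", seconds=300, id="build_memory_demo")
    async def build_memory_demo():
        # stand-in for hippocampus.operation_build_memory(chat_size=30)
        print("building memory...")

    async def main():
        scheduler.start()          # must be started inside a running event loop
        await asyncio.sleep(3600)  # keep the loop alive so jobs can fire

    asyncio.run(main())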
@@ -27,6 +27,7 @@ class BotConfig:
    ban_user_id = set()

    build_memory_interval: int = 60  # memory build interval (seconds)
    forget_memory_interval: int = 300  # memory forgetting interval (seconds)
    EMOJI_CHECK_INTERVAL: int = 120  # sticker check interval (minutes)
    EMOJI_REGISTER_INTERVAL: int = 10  # sticker registration interval (minutes)

@@ -155,6 +156,7 @@ class BotConfig:
        if "memory" in toml_dict:
            memory_config = toml_dict["memory"]
            config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
            config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)

        # Group configuration
        if "groups" in toml_dict:
@@ -188,6 +190,6 @@ global_config = BotConfig.load_config(config_path=bot_config_path)


if not global_config.enable_advance_output:
    # logger.remove()
    logger.remove()
    pass
@@ -1,251 +0,0 @@
from typing import Union, List, Optional, Deque, Dict
from nonebot.adapters.onebot.v11 import Bot, MessageSegment
import asyncio
import random
import os
from .message import Message, Message_Thinking, MessageSet
from .cq_code import CQCode
from collections import deque
import time
from .storage import MessageStorage
from .config import global_config
from .cq_code import cq_code_tool

if os.name == "nt":
    from .message_visualizer import message_visualizer


class SendTemp:
    """Per-group temporary outgoing-message queue"""
    def __init__(self, group_id: int, max_size: int = 100):
        self.group_id = group_id
        self.max_size = max_size
        self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
        self.last_send_time = 0

    def add(self, message: Message) -> None:
        """Insert a message into the queue, keeping chronological order"""
        if not self.messages:
            self.messages.append(message)
            return

        # Fast path: the message is newest, append at the end
        if message.time >= self.messages[-1].time:
            self.messages.append(message)
            return

        # Binary search for the insertion point
        messages_list = list(self.messages)
        left, right = 0, len(messages_list)

        while left < right:
            mid = (left + right) // 2
            if messages_list[mid].time < message.time:
                left = mid + 1
            else:
                right = mid

        # Rebuild the deque, preserving time order
        new_messages = deque(maxlen=self.max_size)
        new_messages.extend(messages_list[:left])
        new_messages.append(message)
        new_messages.extend(messages_list[left:])
        self.messages = new_messages
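The ordered insert above is essentially what Python's bisect module provides out of the box; a hedged sketch of the same behavior on a plain list (bisect.insort with a key requires Python 3.10+; Msg is a stand-in for the real Message class):

    import bisect
    from dataclasses import dataclass

    @dataclass
    class Msg:
        time: float

    queue: list[Msg] = []
    for m in (Msg(3.0), Msg(1.0), Msg(2.0)):
        bisect.insort(queue, m, key=lambda x: x.time)  # keeps the list time-ordered

    assert [m.time for m in queue] == [1.0, 2.0, 3.0]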
    def get_earliest_message(self) -> Optional[Message]:
        """Pop and return the earliest message"""
        message = self.messages.popleft() if self.messages else None
        return message

    def clear(self) -> None:
        """Empty the queue"""
        self.messages.clear()

    def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
        """Return all pending messages"""
        if group_id is None:
            return list(self.messages)
        return [msg for msg in self.messages if msg.group_id == group_id]

    def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
        """Look at the next message to send without removing it"""
        return self.messages[0] if self.messages else None

    def has_messages(self) -> bool:
        """Whether any messages are pending"""
        return bool(self.messages)

    def count(self, group_id: Optional[int] = None) -> int:
        """Number of pending messages"""
        if group_id is None:
            return len(self.messages)
        return len([msg for msg in self.messages if msg.group_id == group_id])

    def get_last_send_time(self) -> float:
        """Timestamp of the last send"""
        return self.last_send_time

    def update_send_time(self):
        """Record the current time as the last send time"""
        self.last_send_time = time.time()


class SendTempContainer:
    """Container managing the outgoing queues of all groups"""
    def __init__(self):
        self.temp_queues: Dict[int, SendTemp] = {}

    def get_queue(self, group_id: int) -> SendTemp:
        """Get or create the queue of a group"""
        if group_id not in self.temp_queues:
            self.temp_queues[group_id] = SendTemp(group_id)
        return self.temp_queues[group_id]

    def add_message(self, message: Message) -> None:
        """Add a message to its group's queue"""
        queue = self.get_queue(message.group_id)
        queue.add(message)

    def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
        """All pending messages of the given group"""
        queue = self.get_queue(group_id)
        return queue.get_all()

    def has_messages(self, group_id: int) -> bool:
        """Whether the given group has pending messages"""
        queue = self.get_queue(group_id)
        return queue.has_messages()

    def get_all_groups(self) -> List[int]:
        """IDs of all groups with pending messages"""
        return list(self.temp_queues.keys())

    def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
        queue = self.get_queue(message_obj.group_id)
        # Find the indices of messages with a matching id
        matching_indices = [
            i for i, msg in enumerate(queue.messages)
            if msg.message_id == message_obj.message_id
        ]

        if not matching_indices:
            return False

        index = matching_indices[0]  # first match

        # Copy to a list so it can be modified
        messages = list(queue.messages)

        # Handle by message type
        if isinstance(message_obj, MessageSet):
            messages.pop(index)
            # Insert the new message group at the original position
            for i, single_message in enumerate(message_obj.messages):
                messages.insert(index + i, single_message)
        else:
            # Replace the original message directly
            messages[index] = message_obj

        # Rebuild the queue
        queue.messages.clear()
        for msg in messages:
            queue.messages.append(msg)

        return True


class MessageSendControl:
    """Outgoing-message controller"""
    def __init__(self):
        self.typing_speed = (0.1, 0.3)  # per-character typing time range (seconds)
        self.message_interval = (0.5, 1)  # delay range between consecutive messages (seconds)
        self.max_retry = 3  # maximum retries
        self.send_temp_container = SendTempContainer()
        self._running = True
        self._paused = False
        self._current_bot = None
        self.storage = MessageStorage()  # storage backend
        try:
            message_visualizer.start()
        except NameError:
            pass  # the visualizer is only imported on Windows

    async def process_group_messages(self, group_id: int):
        queue = self.send_temp_container.get_queue(group_id)
        if queue.has_messages():
            message = queue.peek_next()
            if isinstance(message, Message_Thinking):
                message.update_thinking_time()
                thinking_time = message.thinking_time
                if message.interupt:
                    print("\033[1;34m[debug]\033[0m thinking decided not to reply, removing")
                    queue.get_earliest_message()
                    return
                elif thinking_time < 90:  # keep waiting while under the 90s thinking budget
                    if int(thinking_time) % 15 == 0:
                        print(f"\033[1;34m[debug]\033[0m message still thinking, {thinking_time:.1f}s so far")
                    return
                else:
                    print("\033[1;34m[debug]\033[0m thinking message timed out, removing")
                    queue.get_earliest_message()  # drop the timed-out thinking message
                    return
            elif isinstance(message, Message):
                message = queue.get_earliest_message()
                if message and message.processed_plain_text:
                    print(f"- group: {group_id} - content: {message.processed_plain_text}")
                    cost_time = round(time.time(), 2) - message.time
                    if cost_time > 40:
                        message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
                    cur_time = time.time()
                    await self._current_bot.send_group_msg(
                        group_id=group_id,
                        message=str(message.processed_plain_text),
                        auto_escape=False
                    )
                    cost_time = round(time.time(), 2) - cur_time
                    print(f"\033[1;34m[debug]\033[0m send took {cost_time}s")
                    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
                    print(f"\033[1;32mgroup {group_id} message, user {global_config.BOT_NICKNAME}, time {current_time}:\033[0m {str(message.processed_plain_text)}")

                    if message.is_emoji:
                        message.processed_plain_text = "[表情包]"
                        await self.storage.store_message(message, None)
                    else:
                        await self.storage.store_message(message, None)

            queue.update_send_time()
            if queue.has_messages():
                await asyncio.sleep(
                    random.uniform(
                        self.message_interval[0],
                        self.message_interval[1]
                    )
                )

    async def start_processor(self, bot: Bot):
        """Start the message processor loop"""
        self._current_bot = bot

        while self._running:
            await asyncio.sleep(1.5)
            tasks = []
            for group_id in self.send_temp_container.get_all_groups():
                tasks.append(self.process_group_messages(group_id))

            # Process all groups' queues concurrently
            await asyncio.gather(*tasks)
            try:
                message_visualizer.update_content(self.send_temp_container)
            except NameError:
                pass

    def set_typing_speed(self, min_speed: float, max_speed: float):
        """Set the typing-speed range"""
        self.typing_speed = (min_speed, max_speed)


# Global instance
message_sender_control = MessageSendControl()
@@ -1,271 +0,0 @@
from typing import List, Optional, Dict
from .message import Message
import time
from collections import deque
from datetime import datetime, timedelta
import os
import json
import asyncio

class MessageStream:
    """Per-group message stream container"""
    def __init__(self, group_id: int, max_size: int = 1000):
        self.group_id = group_id
        self.messages = deque(maxlen=max_size)
        self.max_size = max_size
        self.last_save_time = time.time()

        # Make sure the log directory exists
        self.log_dir = os.path.join("log", str(self.group_id))
        os.makedirs(self.log_dir, exist_ok=True)

        # Start the auto-save task
        asyncio.create_task(self._auto_save())

    async def _auto_save(self):
        """Save the message log every 30 seconds"""
        while True:
            await asyncio.sleep(30)
            await self.save_to_log()

    async def save_to_log(self):
        """Append new messages to the log file"""
        try:
            current_time = time.time()
            # Only save when there are new messages
            if not self.messages or self.last_save_time == current_time:
                return

            # The log file name uses the current date
            date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
            log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")

            # Collect messages newer than the last save
            new_messages = [
                msg for msg in self.messages
                if msg.time > self.last_save_time
            ]

            if not new_messages:
                return

            # Convert messages to a serializable form
            message_logs = []
            for msg in new_messages:
                message_logs.append({
                    "time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
                    "user_id": msg.user_id,
                    "user_nickname": msg.user_nickname,
                    "user_cardname": msg.user_cardname,
                    "message_id": msg.message_id,
                    "raw_message": msg.raw_message,
                    "processed_text": msg.processed_plain_text
                })

            # Append to the log file
            with open(log_file, "a", encoding="utf-8") as f:
                for log in message_logs:
                    f.write(json.dumps(log, ensure_ascii=False) + "\n")

            self.last_save_time = current_time

        except Exception as e:
            print(f"\033[1;31m[error]\033[0m failed to save the message log of group {self.group_id}: {str(e)}")

    def add_message(self, message: Message) -> None:
        """Insert a new message into the queue in chronological order

        Uses binary search to keep messages time-ordered while bounding
        memory use.

        Args:
            message: the Message object to add
        """
        # Empty queue, or the message belongs at the end
        if (not self.messages or
                message.time >= self.messages[-1].time):
            self.messages.append(message)
            return

        # The message belongs at the front
        if message.time <= self.messages[0].time:
            self.messages.appendleft(message)
            return

        # Binary search for the insertion point
        left, right = 0, len(self.messages) - 1
        while left <= right:
            mid = (left + right) // 2
            if self.messages[mid].time < message.time:
                left = mid + 1
            else:
                right = mid - 1

        temp = list(self.messages)
        temp.insert(left, message)

        # Drop the oldest entries if over the size limit
        if len(temp) > self.max_size:
            temp = temp[-self.max_size:]

        # Rebuild the deque
        self.messages = deque(temp, maxlen=self.max_size)
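Note that asyncio.create_task(self._auto_save()) in __init__ only works when an event loop is already running; constructing a MessageStream from synchronous code raises RuntimeError. A hedged sketch of a lazier pattern (illustrative, not the deleted file's code):

    import asyncio

    class AutoSaver:
        def __init__(self):
            self._task: asyncio.Task | None = None

        def ensure_started(self) -> None:
            # Start the background task on first use, inside a running loop.
            if self._task is None:
                self._task = asyncio.get_running_loop().create_task(self._auto_save())

        async def _auto_save(self) -> None:
            while True:
                await asyncio.sleep(30)
                print("saving...")  # stand-in for save_to_log()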
    async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
        """Fetch the most recent messages from the database

        Args:
            count: number of messages to fetch

        Returns:
            List[Message]: the most recent messages
        """
        try:
            from ...common.database import Database
            db = Database.get_instance()

            # Query the most recent messages of this group
            recent_messages = list(db.db.messages.find(
                {"group_id": self.group_id},
                # projection (commented out in the original):
                # {"time": 1, "user_id": 1, "user_nickname": 1,
                #  "message_id": 1, "raw_message": 1, "processed_text": 1}
            ).sort("time", -1).limit(count))

            if not recent_messages:
                return []

            # Convert to Message objects
            from .message import Message
            messages = []
            for msg_data in recent_messages:
                try:
                    msg = Message(
                        time=msg_data["time"],
                        user_id=msg_data["user_id"],
                        user_nickname=msg_data.get("user_nickname", ""),
                        user_cardname=msg_data.get("user_cardname", ""),
                        message_id=msg_data["message_id"],
                        raw_message=msg_data["raw_message"],
                        processed_plain_text=msg_data.get("processed_text", ""),
                        group_id=self.group_id
                    )
                    messages.append(msg)
                except KeyError:
                    print("[WARNING] invalid message found in the database")
                    continue

            return list(reversed(messages))  # return in ascending time order

        except Exception as e:
            print(f"\033[1;31m[error]\033[0m failed to fetch recent messages of group {self.group_id}: {str(e)}")
            return []

    def get_recent_messages(self, count: int = 10) -> List[Message]:
        """Last n messages from the in-memory queue"""
        print(f"\033[1;34m[debug]\033[0m fetching the last {count} messages of group {self.group_id} from memory")
        return list(self.messages)[-count:]

    def get_messages_in_timerange(self,
                                  start_time: Optional[float] = None,
                                  end_time: Optional[float] = None) -> List[Message]:
        """Messages within a time range"""
        if start_time is None:
            start_time = time.time() - 3600
        if end_time is None:
            end_time = time.time()

        return [
            msg for msg in self.messages
            if start_time <= msg.time <= end_time
        ]

    def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
        """Most recent messages of a specific user"""
        user_messages = [msg for msg in self.messages if msg.user_id == user_id]
        return user_messages[-count:]

    def clear_old_messages(self, hours: int = 24) -> None:
        """Drop messages older than the cutoff"""
        cutoff_time = time.time() - (hours * 3600)
        self.messages = deque(
            [msg for msg in self.messages if msg.time > cutoff_time],
            maxlen=self.max_size
        )

class MessageStreamContainer:
    """Container managing the message streams of all groups"""
    def __init__(self, max_size: int = 1000):
        self.streams: Dict[int, MessageStream] = {}
        self.max_size = max_size

    async def save_all_logs(self):
        """Save the message logs of all groups"""
        for stream in self.streams.values():
            await stream.save_to_log()

    def add_message(self, message: Message) -> None:
        """Route a message to its group's stream"""
        if not message.group_id:
            return

        if message.group_id not in self.streams:
            self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)

        self.streams[message.group_id].add_message(message)

    def get_stream(self, group_id: int) -> Optional[MessageStream]:
        """Stream of a specific group"""
        return self.streams.get(group_id)

    def get_all_streams(self) -> Dict[int, MessageStream]:
        """All group streams"""
        return self.streams

    def clear_old_messages(self, hours: int = 24) -> None:
        """Drop old messages in every group"""
        for stream in self.streams.values():
            stream.clear_old_messages(hours)

    def get_group_stats(self, group_id: int) -> Dict:
        """Message statistics of a group"""
        stream = self.streams.get(group_id)
        if not stream:
            return {
                "total_messages": 0,
                "unique_users": 0,
                "active_hours": [],
                "most_active_user": None
            }

        messages = stream.messages
        user_counts = {}
        hour_counts = {}

        for msg in messages:
            user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
            hour = datetime.fromtimestamp(msg.time).hour
            hour_counts[hour] = hour_counts.get(hour, 0) + 1

        most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
        active_hours = sorted(
            hour_counts.items(),
            key=lambda x: x[1],
            reverse=True
        )[:5]

        return {
            "total_messages": len(messages),
            "unique_users": len(user_counts),
            "active_hours": active_hours,
            "most_active_user": most_active_user
        }

# Global instance
message_stream_container = MessageStreamContainer()
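A minimal usage sketch of the stats container (assuming the classes above are importable; the SimpleNamespace stand-in carries only the attributes add_message and get_group_stats actually read, and the call runs inside a loop because MessageStream starts its auto-save task in __init__):

    import asyncio, time
    from types import SimpleNamespace

    async def demo():
        container = MessageStreamContainer()
        for uid in (1, 1, 2):
            msg = SimpleNamespace(group_id=42, user_id=uid, time=time.time())
            container.add_message(msg)
        print(container.get_group_stats(42))  # total_messages=3, unique_users=2, ...

    asyncio.run(demo())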
@@ -1,138 +0,0 @@
import subprocess
import threading
import queue
import os
import time
from typing import Dict
from .message import Message_Thinking

class MessageVisualizer:
    def __init__(self):
        self.process = None
        self.message_queue = queue.Queue()
        self.is_running = False
        self.content_file = "message_queue_content.txt"

    def start(self):
        if self.process is None:
            # Write the batch file that drives the display window
            with open("message_queue_window.bat", "w", encoding="utf-8") as f:
                f.write('@echo off\n')
                f.write('chcp 65001\n')  # switch the console to UTF-8
                f.write('title Message Queue Visualizer\n')
                f.write('echo Waiting for message queue updates...\n')
                f.write(':loop\n')
                f.write('if exist "queue_update.txt" (\n')
                f.write('    type "queue_update.txt" > "message_queue_content.txt"\n')
                f.write('    del "queue_update.txt"\n')
                f.write('    cls\n')
                f.write('    type "message_queue_content.txt"\n')
                f.write(')\n')
                f.write('timeout /t 1 /nobreak >nul\n')
                f.write('goto loop\n')

            # Clear the content file
            with open(self.content_file, "w", encoding="utf-8") as f:
                f.write("")

            # Launch a new console window
            startupinfo = subprocess.STARTUPINFO()
            startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
            self.process = subprocess.Popen(
                ['cmd', '/c', 'start', 'message_queue_window.bat'],
                shell=True,
                startupinfo=startupinfo
            )
            self.is_running = True

            # Start the worker thread
            threading.Thread(target=self._process_messages, daemon=True).start()

    def _process_messages(self):
        while self.is_running:
            try:
                # Fetch the next snapshot
                text = self.message_queue.get(timeout=1)
                # Write the update file for the batch loop to pick up
                with open("queue_update.txt", "w", encoding="utf-8") as f:
                    f.write(text)
            except queue.Empty:
                continue
            except Exception as e:
                print(f"error while processing queue visualization content: {e}")

    def update_content(self, send_temp_container):
        """Rebuild the display text"""
        if not self.is_running:
            return

        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
        display_text = f"Message Queue Status - {current_time}\n"
        display_text += "=" * 50 + "\n\n"

        # Walk the queues of all groups
        for group_id, queue in send_temp_container.temp_queues.items():
            display_text += f"\n{'='*20} group: {queue.group_id} {'='*20}\n"
            display_text += f"queue length: {len(queue.messages)}\n"
            display_text += f"last send time: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
            display_text += "\nqueue contents:\n"

            # Render the queued messages
            if not queue.messages:
                display_text += "  [empty queue]\n"
            else:
                for i, msg in enumerate(queue.messages):
                    msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
                    display_text += f"\n--- message {i+1} ---\n"

                    if isinstance(msg, Message_Thinking):
                        display_text += "type: \033[1;33mthinking message\033[0m\n"
                        display_text += f"time: {msg_time}\n"
                        display_text += f"message id: {msg.message_id}\n"
                        display_text += f"group: {msg.group_id}\n"
                        display_text += f"user: {msg.user_nickname}({msg.user_id})\n"
                        display_text += f"content: {msg.thinking_text}\n"
                        display_text += f"thinking time: {int(msg.thinking_time)}s\n"
                    else:
                        display_text += "type: normal message\n"
                        display_text += f"time: {msg_time}\n"
                        display_text += f"message id: {msg.message_id}\n"
                        display_text += f"group: {msg.group_id}\n"
                        display_text += f"user: {msg.user_nickname}({msg.user_id})\n"
                        if hasattr(msg, 'is_emoji') and msg.is_emoji:
                            display_text += "content: [sticker message]\n"
                        else:
                            # Show both the raw and the processed message
                            display_text += f"raw content: {msg.raw_message[:50]}...\n"
                            display_text += f"processed content: {msg.processed_plain_text[:50]}...\n"

                        if msg.reply_message:
                            display_text += f"reply target: {str(msg.reply_message)[:50]}...\n"

            display_text += f"\n{'-' * 50}\n"

        # Overall statistics
        display_text += "\noverall stats:\n"
        display_text += f"active groups: {len(send_temp_container.temp_queues)}\n"
        total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
        display_text += f"total messages: {total_messages}\n"
        thinking_messages = sum(
            sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
            for q in send_temp_container.temp_queues.values()
        )
        display_text += f"thinking messages: {thinking_messages}\n"

        self.message_queue.put(display_text)

    def stop(self):
        self.is_running = False
        if self.process:
            self.process.terminate()
            self.process = None
        # Clean up the temp files
        for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
            if os.path.exists(file):
                os.remove(file)

# Global singleton
message_visualizer = MessageVisualizer()
@@ -9,7 +9,7 @@ class WillingManager:
    async def _decay_reply_willing(self):
        """Periodically decay reply willingness"""
        while True:
            await asyncio.sleep(3)
            await asyncio.sleep(5)
            for group_id in self.group_reply_willing:
                self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)

@@ -39,11 +39,11 @@ class WillingManager:

        if interested_rate > 0.65:
            print(f"interest: {interested_rate}, current willingness: {current_willing}")
            current_willing += interested_rate - 0.5
            current_willing += interested_rate - 0.6

        self.group_reply_willing[group_id] = min(current_willing, 3.0)

        reply_probability = max((current_willing - 0.5) * 2, 0)
        reply_probability = max((current_willing - 0.55) * 1.9, 0)
        if group_id not in config.talk_allowed_groups:
            current_willing = 0
            reply_probability = 0
@@ -65,7 +65,7 @@ class WillingManager:
        """Raise a group's reply willingness after sending a message"""
        current_willing = self.group_reply_willing.get(group_id, 0)
        if current_willing < 1:
            self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
            self.group_reply_willing[group_id] = min(1, current_willing + 0.2)

    async def ensure_started(self):
        """Make sure the decay task has started"""
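The change from (willing - 0.5) * 2 to (willing - 0.55) * 1.9 lowers the reply probability across the board; a quick check at a willingness of 1.0 (the helper is illustrative, the constants are from the formulas above):

    def reply_probability(willing: float, offset: float, gain: float) -> float:
        return max((willing - offset) * gain, 0)

    old = reply_probability(1.0, 0.50, 2.0)  # 1.0
    new = reply_probability(1.0, 0.55, 1.9)  # 0.855
    print(old, new)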
@@ -22,63 +22,6 @@ from src.common.database import Database
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)

class LLMModel:
    def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
        self.model_name = model_name
        self.params = kwargs
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

    async def generate_response(self, prompt: str) -> Tuple[str, str]:
        """Generate the model's response for a prompt"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Build the request body
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            **self.params
        }

        # Post to the full chat/completions endpoint
        api_url = f"{self.base_url.rstrip('/')}/chat/completions"

        max_retries = 3
        base_wait_time = 15

        for retry in range(max_retries):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(api_url, headers=headers, json=data) as response:
                        if response.status == 429:
                            wait_time = base_wait_time * (2 ** retry)  # exponential backoff
                            print(f"rate limited (429), retrying in {wait_time}s...")
                            await asyncio.sleep(wait_time)
                            continue

                        response.raise_for_status()  # check other status codes

                        result = await response.json()
                        if "choices" in result and len(result["choices"]) > 0:
                            content = result["choices"][0]["message"]["content"]
                            reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
                            return content, reasoning_content
                        return "no result returned", ""

            except Exception as e:
                if retry < max_retries - 1:  # retries remaining
                    wait_time = base_wait_time * (2 ** retry)
                    print(f"request failed, retrying in {wait_time}s... error: {str(e)}")
                    await asyncio.sleep(wait_time)
                else:
                    return f"request failed: {str(e)}", ""

        return "maximum retries reached, request still failing", ""
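With base_wait_time = 15 and three attempts, the exponential backoff above sleeps 15 s and then 30 s between attempts; the 60 s wait is only reached on the 429 path, since a third exception returns immediately instead of sleeping:

    waits = [15 * 2 ** retry for retry in range(3)]
    print(waits)  # [15, 30, 60]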
class Memory_graph:
    def __init__(self):
@@ -232,19 +175,10 @@ def main():
    )

    memory_graph = Memory_graph()
    # Create the LLM model instance

    memory_graph.load_graph_from_db()
    # Show the two different visualization modes
    print("\ngraph colored by connection count:")
    # visualize_graph(memory_graph, color_by_memory=False)
    visualize_graph_lite(memory_graph, color_by_memory=False)

    print("\ngraph colored by memory count:")
    # visualize_graph(memory_graph, color_by_memory=True)
    visualize_graph_lite(memory_graph, color_by_memory=True)

    # memory_graph.save_graph_to_db()
    # Show the optimized graph only once
    visualize_graph_lite(memory_graph)

    while True:
        query = input("enter a new query concept (type 'exit' to quit): ")
@@ -327,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
    nx.draw(G, pos,
            with_labels=True,
            node_color=node_colors,
            node_size=2000,
            node_size=200,
            font_size=10,
            font_family='SimHei',
            font_weight='bold')
@@ -353,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        degree = H.degree(node)
        if memory_count <= 2 or degree <= 2:
        if memory_count < 5 or degree < 2:  # strictly less-than now, instead of <=
            nodes_to_remove.append(node)

    H.remove_nodes_from(nodes_to_remove)
@@ -366,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
    # Save the graph locally
    nx.write_gml(H, "memory_graph.gml")  # GML format

    # Color nodes by connection count or memory count
    # Compute node sizes and colors
    node_colors = []
    nodes = list(H.nodes())  # the nodes actually left in the graph
    node_sizes = []
    nodes = list(H.nodes())

    if color_by_memory:
        # Count the memories of each node
        memory_counts = []
        for node in nodes:
            memory_items = H.nodes[node].get('memory_items', [])
            if isinstance(memory_items, list):
                count = len(memory_items)
            else:
                count = 1 if memory_items else 0
            memory_counts.append(count)
        max_memories = max(memory_counts) if memory_counts else 1
    # Find the max memory count and max degree for normalization
    max_memories = 1
    max_degree = 1
    for node in nodes:
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        degree = H.degree(node)
        max_memories = max(max_memories, memory_count)
        max_degree = max(max_degree, degree)

    # Compute each node's size and color
    for node in nodes:
        # Node size scales with memory count
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        # Square the ratio so differences stand out
        ratio = memory_count / max_memories
        size = 500 + 5000 * (ratio ** 2)
        node_sizes.append(size)

        for count in memory_counts:
            # Alternate color scheme: red = many memories, blue = few
            if max_memories > 0:
                intensity = min(1.0, count / max_memories)
                color = (intensity, 0, 1.0 - intensity)  # gradient from blue to red
            else:
                color = (0, 0, 1)  # blue when there are no memories
            node_colors.append(color)
    else:
        # Original scheme: color by connection count
        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
        for node in nodes:
            degree = H.degree(node)
            if max_degree > 0:
                red = min(1.0, degree / max_degree)
                blue = 1.0 - red
                color = (red, 0, blue)
            else:
                color = (0, 0, 1)
            node_colors.append(color)
        # Node color scales with degree
        degree = H.degree(node)
        # More red with higher degree
        red = min(1.0, degree / max_degree)
        # More blue with lower degree
        blue = 1.0 - red
        color = (red, 0, blue)
        node_colors.append(color)

    # Draw the graph
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(H, k=1, iterations=50)
    pos = nx.spring_layout(H, k=1.5, iterations=50)  # a larger k spreads the nodes out
    nx.draw(H, pos,
            with_labels=True,
            node_color=node_colors,
            node_size=2000,
            node_size=node_sizes,
            font_size=10,
            font_family='SimHei',
            font_weight='bold')
            font_weight='bold',
            edge_color='gray',
            width=0.5,
            alpha=0.7)

    title = 'memory graph - ' + ('colored by memory count' if color_by_memory else 'colored by connection count')
    title = 'memory graph - node size = memory count, color = connection count'
    plt.title(title, fontsize=16, fontfamily='SimHei')
    plt.show()

@@ -17,7 +17,12 @@ class Memory_graph:
        self.db = Database.get_instance()

    def connect_dot(self, concept1, concept2):
        self.G.add_edge(concept1, concept2)
        # If the edge already exists, bump its strength
        if self.G.has_edge(concept1, concept2):
            self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
        else:
            # New edge: initialize strength to 1
            self.G.add_edge(concept1, concept2, strength=1)

    def add_dot(self, concept, memory):
        if concept in self.G:
@@ -38,9 +43,7 @@ class Memory_graph:
        if concept in self.G:
            # Read the node data from the graph
            node_data = self.G.nodes[concept]
            return concept,node_data
            return concept, node_data
        return None

    def get_related_item(self, topic, depth=1):
@@ -52,7 +55,6 @@ class Memory_graph:

        # Neighboring nodes
        neighbors = list(self.G.neighbors(topic))

        # Memory items of the current node
        node_data = self.get_dot(topic)
@@ -69,7 +71,6 @@ class Memory_graph:
        if depth >= 2:
            # Memory items of the neighbors
            for neighbor in neighbors:
                node_data = self.get_dot(neighbor)
                if node_data:
                    concept, data = node_data
@@ -87,79 +88,38 @@ class Memory_graph:
        # Return the Memory_dot objects of all nodes
        return [self.get_dot(node) for node in self.G.nodes()]

    def save_graph_to_db(self):
        # Save nodes
        for node in self.G.nodes(data=True):
            concept = node[0]
            memory_items = node[1].get('memory_items', [])
    def forget_topic(self, topic):
        """Randomly drop one memory from a topic; remove the node if no memories remain"""
        if topic not in self.G:
            return None

            # Look for a node with the same name
            existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
            if existing_node:
                # Merge memory_items and deduplicate
                existing_items = existing_node.get('memory_items', [])
                if not isinstance(existing_items, list):
                    existing_items = [existing_items] if existing_items else []

                # Merge and deduplicate
                all_items = list(set(existing_items + memory_items))

                # Update the node
                self.db.db.graph_data.nodes.update_one(
                    {'concept': concept},
                    {'$set': {'memory_items': all_items}}
                )
            else:
                # Create a new node
                node_data = {
                    'concept': concept,
                    'memory_items': memory_items
                }
                self.db.db.graph_data.nodes.insert_one(node_data)
        # Node data of the topic
        node_data = self.G.nodes[topic]

        # Save edges
        for edge in self.G.edges():
            source, target = edge
        # If the node has memory_items
        if 'memory_items' in node_data:
            memory_items = node_data['memory_items']

            # Look for the same edge
            existing_edge = self.db.db.graph_data.edges.find_one({
                'source': source,
                'target': target
            })

            if existing_edge:
                # Increment its num attribute
                num = existing_edge.get('num', 1) + 1
                self.db.db.graph_data.edges.update_one(
                    {'source': source, 'target': target},
                    {'$set': {'num': num}}
                )
            else:
                # Create a new edge
                edge_data = {
                    'source': source,
                    'target': target,
                    'num': 1
                }
                self.db.db.graph_data.edges.insert_one(edge_data)

    def load_graph_from_db(self):
        # Clear the current graph
        self.G.clear()
        # Load nodes
        nodes = self.db.db.graph_data.nodes.find()
        for node in nodes:
            memory_items = node.get('memory_items', [])
            # Make sure memory_items is a list
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            self.G.add_node(node['concept'], memory_items=memory_items)
        # Load edges
        edges = self.db.db.graph_data.edges.find()
        for edge in edges:
            self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))

            # If there is a memory item to delete
            if memory_items:
                # Pick one at random and remove it
                removed_item = random.choice(memory_items)
                memory_items.remove(removed_item)

                # Write the memory items back
                if memory_items:
                    self.G.nodes[topic]['memory_items'] = memory_items
                else:
                    # No memories left: remove the whole node
                    self.G.remove_node(topic)

                return removed_item

        return None

# Hippocampus
@@ -169,23 +129,33 @@ class Hippocampus:
        self.llm_model = LLM_request(model=global_config.llm_normal, temperature=0.5)
        self.llm_model_small = LLM_request(model=global_config.llm_normal_minor, temperature=0.5)

    def calculate_node_hash(self, concept, memory_items):
        """Fingerprint of a node"""
        if not isinstance(memory_items, list):
            memory_items = [memory_items] if memory_items else []
        sorted_items = sorted(memory_items)
        content = f"{concept}:{'|'.join(sorted_items)}"
        return hash(content)

    def calculate_edge_hash(self, source, target):
        """Fingerprint of an edge"""
        nodes = sorted([source, target])
        return hash(f"{nodes[0]}:{nodes[1]}")
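One caveat: calculate_node_hash and calculate_edge_hash rely on the built-in hash(), which for strings is salted per process (PYTHONHASHSEED), so fingerprints persisted to the database will generally not match across restarts. A stable alternative would digest the same content with hashlib; a hedged sketch:

    import hashlib

    def stable_hash(content: str) -> str:
        # Deterministic across processes, unlike hash() on strings.
        return hashlib.md5(content.encode("utf-8")).hexdigest()

    print(stable_hash("concept:item1|item2"))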
    def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
        current_timestamp = datetime.datetime.now().timestamp()
        chat_text = []
        # near: last 1h, mid: last 4h, far: last 24h
        for _ in range(time_frequency.get('near')):  # sample the near window
            random_time = current_timestamp - random.randint(1, 3600)  # random timestamp
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('mid')):  # sample the mid window
            random_time = current_timestamp - random.randint(3600, 3600*4)
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('far')):  # sample the far window
            random_time = current_timestamp - random.randint(3600*4, 3600*24)
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        return chat_text
@@ -207,8 +177,8 @@ class Hippocampus:
        topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
        compressed_memory.add((topic.strip(), topic_what_response[0]))  # store topic and memory as a tuple
        return compressed_memory

    async def build_memory(self, chat_size=12):

    async def operation_build_memory(self, chat_size=12):
        # sampling frequency for recent messages
        time_frequency = {'near': 1, 'mid': 2, 'far': 2}
        memory_sample = self.get_memory_sample(chat_size, time_frequency)
@@ -236,7 +206,247 @@ class Hippocampus:
                        self.memory_graph.connect_dot(split_topic, other_split_topic)
            else:
                print("empty message, skipping")
        self.memory_graph.save_graph_to_db()
        self.sync_memory_to_db()

    def sync_memory_to_db(self):
        """Reconcile the in-memory graph with the database"""
        # All nodes in the database and in memory
        db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
        memory_nodes = list(self.memory_graph.G.nodes(data=True))

        # Index the database nodes by concept for lookup
        db_nodes_dict = {node['concept']: node for node in db_nodes}

        # Check and update nodes
        for concept, data in memory_nodes:
            memory_items = data.get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []

            # Fingerprint of the in-memory node
            memory_hash = self.calculate_node_hash(concept, memory_items)

            if concept not in db_nodes_dict:
                # Node missing from the database: insert it
                node_data = {
                    'concept': concept,
                    'memory_items': memory_items,
                    'hash': memory_hash
                }
                self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
            else:
                # Fingerprint of the database node
                db_node = db_nodes_dict[concept]
                db_hash = db_node.get('hash', None)

                # Update the node if the fingerprints differ
                if db_hash != memory_hash:
                    self.memory_graph.db.db.graph_data.nodes.update_one(
                        {'concept': concept},
                        {'$set': {
                            'memory_items': memory_items,
                            'hash': memory_hash
                        }}
                    )

        # Delete database nodes that no longer exist in memory
        memory_concepts = set(node[0] for node in memory_nodes)
        for db_node in db_nodes:
            if db_node['concept'] not in memory_concepts:
                self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})

        # Edges
        db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
        memory_edges = list(self.memory_graph.G.edges())

        # Fingerprint dictionary of the database edges
        db_edge_dict = {}
        for edge in db_edges:
            edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
            db_edge_dict[(edge['source'], edge['target'])] = {
                'hash': edge_hash,
                'strength': edge.get('strength', 1)
            }

        # Check and update edges
        for source, target in memory_edges:
            edge_hash = self.calculate_edge_hash(source, target)
            edge_key = (source, target)
            strength = self.memory_graph.G[source][target].get('strength', 1)

            if edge_key not in db_edge_dict:
                # Insert a new edge
                edge_data = {
                    'source': source,
                    'target': target,
                    'strength': strength,
                    'hash': edge_hash
                }
                self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
            else:
                # Update the edge if its fingerprint changed
                if db_edge_dict[edge_key]['hash'] != edge_hash:
                    self.memory_graph.db.db.graph_data.edges.update_one(
                        {'source': source, 'target': target},
                        {'$set': {
                            'hash': edge_hash,
                            'strength': strength
                        }}
                    )

        # Delete edges that no longer exist in memory
        memory_edge_set = set(memory_edges)
        for edge_key in db_edge_dict:
            if edge_key not in memory_edge_set:
                source, target = edge_key
                self.memory_graph.db.db.graph_data.edges.delete_one({
                    'source': source,
                    'target': target
                })

    def sync_memory_from_db(self):
        """Load the database contents into the in-memory graph"""
        # Clear the current graph
        self.memory_graph.G.clear()

        # Load all nodes
        nodes = self.memory_graph.db.db.graph_data.nodes.find()
        for node in nodes:
            concept = node['concept']
            memory_items = node.get('memory_items', [])
            # Make sure memory_items is a list
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            # Add the node to the graph
            self.memory_graph.G.add_node(concept, memory_items=memory_items)

        # Load all edges
        edges = self.memory_graph.db.db.graph_data.edges.find()
        for edge in edges:
            source = edge['source']
            target = edge['target']
            strength = edge.get('strength', 1)  # default strength is 1
            # Only add the edge when both endpoints exist
            if source in self.memory_graph.G and target in self.memory_graph.G:
                self.memory_graph.G.add_edge(source, target, strength=strength)

    async def operation_forget_topic(self, percentage=0.1):
        """Check a random sample of nodes and forget those matching the criteria"""
        # All nodes
        all_nodes = list(self.memory_graph.G.nodes())
        # Number of nodes to check
        check_count = max(1, int(len(all_nodes) * percentage))
        # Sample nodes at random
        nodes_to_check = random.sample(all_nodes, check_count)

        forgotten_nodes = []
        for node in nodes_to_check:
            # Connection count of the node
            connections = self.memory_graph.G.degree(node)

            # Number of memory items on the node
            memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            content_count = len(memory_items)

            # Check connection strength
            weak_connections = True
            if connections > 1:  # only check strength when there is more than one connection
                for neighbor in self.memory_graph.G.neighbors(node):
                    strength = self.memory_graph.G[node][neighbor].get('strength', 1)
                    if strength > 2:
                        weak_connections = False
                        break

            # Forget if the criteria are met
            if (connections <= 1 and weak_connections) or content_count <= 2:
                removed_item = self.memory_graph.forget_topic(node)
                if removed_item:
                    forgotten_nodes.append((node, removed_item))
                    print(f"forgot a memory of node {node}: {removed_item}")

        # Sync to the database
        if forgotten_nodes:
            self.sync_memory_to_db()
            print(f"forget pass finished, memories forgotten on {len(forgotten_nodes)} nodes")
        else:
            print("no node met the forgetting criteria this pass")

    async def merge_memory(self, topic):
        """
        Merge and compress the memories of a topic

        Args:
            topic: the topic node to merge
        """
        # Memory items of the node
        memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
        if not isinstance(memory_items, list):
            memory_items = [memory_items] if memory_items else []

        # Not enough memories: nothing to do
        if len(memory_items) < 10:
            return

        # Pick 10 memories at random
        selected_memories = random.sample(memory_items, 10)

        # Join them into one text
        merged_text = "\n".join(selected_memories)
        print(f"\n[memory merge] topic: {topic}")
        print(f"selected memories:\n{merged_text}")

        # Compress them with memory_compress
        compressed_memories = await self.memory_compress(merged_text, 0.1)

        # Remove the selected memories from the list
        for memory in selected_memories:
            memory_items.remove(memory)

        # Append the new compressed memories
        for _, compressed_memory in compressed_memories:
            memory_items.append(compressed_memory)
            print(f"added compressed memory: {compressed_memory}")

        # Write the memory items back to the node
        self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
        print(f"merge finished, current memory count: {len(memory_items)}")

    async def operation_merge_memory(self, percentage=0.1):
        """
        Check a random sample of nodes and merge those with more than 100 memory items

        Args:
            percentage: fraction of nodes to check, 0.1 (10%) by default
        """
        # All nodes
        all_nodes = list(self.memory_graph.G.nodes())
        # Number of nodes to check
        check_count = max(1, int(len(all_nodes) * percentage))
        # Sample nodes at random
        nodes_to_check = random.sample(all_nodes, check_count)

        merged_nodes = []
        for node in nodes_to_check:
            # Number of memory items on the node
            memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            content_count = len(memory_items)

            # Merge when there are more than 100 items
            if content_count > 100:
                print(f"\nchecking node: {node}, current memory count: {content_count}")
                await self.merge_memory(node)
                merged_nodes.append(node)

        # Sync to the database
        if merged_nodes:
            self.sync_memory_to_db()
            print(f"\nmerge pass finished, {len(merged_nodes)} nodes processed")
        else:
            print("\nno node needed merging this pass")


def segment_text(text):
@@ -268,10 +478,10 @@ Database.initialize(
)
# Create the memory graph
memory_graph = Memory_graph()
# Load the stored memory graph from the database
memory_graph.load_graph_from_db()
# Create the hippocampus
hippocampus = Hippocampus(memory_graph)
# Load the memory graph from the database
hippocampus.sync_memory_from_db()

end_time = time.time()
print(f"\033[32m[hippocampus load time: {end_time - start_time:.2f} s]\033[0m")
@@ -1,463 +0,0 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot")  # add the project root to the Python path
from src.common.database import Database
from src.plugins.memory_system.llm_module import LLMModel

def calculate_information_content(text):
    """Information content (entropy) of a text"""
    # Character frequencies
    char_count = Counter(text)
    total_chars = len(text)

    # Shannon entropy
    entropy = 0
    for count in char_count.values():
        probability = count / total_chars
        entropy -= probability * math.log2(probability)

    return entropy
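calculate_information_content computes the Shannon entropy of the character distribution:

    H(X) = -\sum_i p_i \log_2 p_i, \qquad p_i = \frac{c_i}{\sum_j c_j}

where c_i is the count of character i; a text uniform over n distinct characters attains the maximum of \log_2 n bits per character.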
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录"""
|
||||
chat_text = ''
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
|
||||
return chat_text
|
||||
|
||||
return ''
|
||||
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
self.G.add_edge(concept1, concept2)
|
||||
|
||||
def add_dot(self, concept, memory):
|
||||
if concept in self.G:
|
||||
# 如果节点已存在,将新记忆添加到现有列表中
|
||||
if 'memory_items' in self.G.nodes[concept]:
|
||||
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||
# 如果当前不是列表,将其转换为列表
|
||||
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||
self.G.nodes[concept]['memory_items'].append(memory)
|
||||
else:
|
||||
self.G.nodes[concept]['memory_items'] = [memory]
|
||||
else:
|
||||
# 如果是新节点,创建新的记忆列表
|
||||
self.G.add_node(concept, memory_items=[memory])
|
||||
|
||||
def get_dot(self, concept):
|
||||
# 检查节点是否存在于图中
|
||||
if concept in self.G:
|
||||
# 从图中获取节点数据
|
||||
node_data = self.G.nodes[concept]
|
||||
# print(node_data)
|
||||
# 创建新的Memory_dot对象
|
||||
return concept,node_data
|
||||
return None
|
||||
|
||||
def get_related_item(self, topic, depth=1):
|
||||
if topic not in self.G:
|
||||
return [], []
|
||||
|
||||
first_layer_items = []
|
||||
second_layer_items = []
|
||||
|
||||
# 获取相邻节点
|
||||
neighbors = list(self.G.neighbors(topic))
|
||||
# print(f"第一层: {topic}")
|
||||
|
||||
# 获取当前节点的记忆项
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if isinstance(memory_items, list):
|
||||
first_layer_items.extend(memory_items)
|
||||
else:
|
||||
first_layer_items.append(memory_items)
|
||||
|
||||
# 只在depth=2时获取第二层记忆
|
||||
if depth >= 2:
|
||||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
# print(f"第二层: {neighbor}")
|
||||
node_data = self.get_dot(neighbor)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
if 'memory_items' in data:
|
||||
memory_items = data['memory_items']
|
||||
if isinstance(memory_items, list):
|
||||
second_layer_items.extend(memory_items)
|
||||
else:
|
||||
second_layer_items.append(memory_items)
|
||||
|
||||
return first_layer_items, second_layer_items
|
||||
|
||||
def store_memory(self):
|
||||
for node in self.G.nodes():
|
||||
dot_data = {
|
||||
"concept": node
|
||||
}
|
||||
self.db.db.store_memory_dots.insert_one(dot_data)
|
||||
|
||||
@property
|
||||
def dots(self):
|
||||
# 返回所有节点对应的 Memory_dot 对象
|
||||
return [self.get_dot(node) for node in self.G.nodes()]
|
||||
|
||||
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
|
||||
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||
for record in chat_record:
|
||||
if record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
try:
|
||||
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
|
||||
except:
|
||||
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
|
||||
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||
return chat_text
|
||||
|
||||
return [] # 如果没有找到记录,返回空列表
|
||||
|
||||
    def save_graph_to_db(self):
        # Save nodes
        for node in self.G.nodes(data=True):
            concept = node[0]
            memory_items = node[1].get('memory_items', [])

            # Look for an existing node with the same concept
            existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
            if existing_node:
                # Merge memory_items and deduplicate
                existing_items = existing_node.get('memory_items', [])
                if not isinstance(existing_items, list):
                    existing_items = [existing_items] if existing_items else []

                all_items = list(set(existing_items + memory_items))

                # Update the node
                self.db.db.graph_data.nodes.update_one(
                    {'concept': concept},
                    {'$set': {'memory_items': all_items}}
                )
            else:
                # Insert a new node
                node_data = {
                    'concept': concept,
                    'memory_items': memory_items
                }
                self.db.db.graph_data.nodes.insert_one(node_data)

        # Save edges
        for edge in self.G.edges():
            source, target = edge

            # Look for an existing edge between the same nodes
            existing_edge = self.db.db.graph_data.edges.find_one({
                'source': source,
                'target': target
            })

            if existing_edge:
                # Increment the num attribute
                num = existing_edge.get('num', 1) + 1
                self.db.db.graph_data.edges.update_one(
                    {'source': source, 'target': target},
                    {'$set': {'num': num}}
                )
            else:
                # Insert a new edge
                edge_data = {
                    'source': source,
                    'target': target,
                    'num': 1
                }
                self.db.db.graph_data.edges.insert_one(edge_data)
    def load_graph_from_db(self):
        # Clear the in-memory graph
        self.G.clear()
        # Load nodes
        nodes = self.db.db.graph_data.nodes.find()
        for node in nodes:
            memory_items = node.get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            self.G.add_node(node['concept'], memory_items=memory_items)
        # Load edges
        edges = self.db.db.graph_data.edges.find()
        for edge in edges:
            self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
# Hippocampus: builds and maintains the memory graph
class Hippocampus:
    def __init__(self, memory_graph: Memory_graph):
        self.memory_graph = memory_graph
        self.llm_model = LLMModel()
        self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
    def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
        current_timestamp = datetime.datetime.now().timestamp()
        chat_text = []
        # Sampling windows -- near: last 1h, mid: 1-4h ago, far: 4-24h ago
        for _ in range(time_frequency.get('near')):
            random_time = current_timestamp - random.randint(1, 3600)  # random moment within the last hour
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('mid')):
            random_time = current_timestamp - random.randint(3600, 3600 * 4)  # random moment 1-4 hours ago
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('far')):
            random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)  # random moment 4-24 hours ago
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        return chat_text
    def build_memory(self, chat_size=12):
        # How many samples to draw from each time window
        time_frequency = {'near': 1, 'mid': 2, 'far': 2}
        memory_sample = self.get_memory_sample(chat_size, time_frequency)

        for i, input_text in enumerate(memory_sample, 1):
            # Progress bar for the build loop
            progress = (i / len(memory_sample)) * 100
            bar_length = 30
            filled_length = int(bar_length * i // len(memory_sample))
            bar = '█' * filled_length + '-' * (bar_length - filled_length)
            print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
            if input_text:
                # Generate compressed memories as (topic, memory) tuples
                first_memory = self.memory_compress(input_text, 2.5)
                # Insert the memories into the graph
                for topic, memory in first_memory:
                    topics = segment_text(topic)
                    print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
                    for split_topic in topics:
                        self.memory_graph.add_dot(split_topic, memory)
                    for split_topic in topics:
                        for other_split_topic in topics:
                            if split_topic != other_split_topic:
                                self.memory_graph.connect_dot(split_topic, other_split_topic)
            else:
                print("空消息 跳过")

        self.memory_graph.save_graph_to_db()
    def memory_compress(self, input_text, rate=1):
        information_content = calculate_information_content(input_text)
        print(f"文本的信息量(熵): {information_content:.4f} bits")
        # Higher entropy yields more topics, clamped to [1, 5]
        topic_num = max(1, min(5, int(information_content * rate / 4)))
        topic_prompt = find_topic(input_text, topic_num)
        topic_response = self.llm_model.generate_response(topic_prompt)
        # generate_response may return a (content, reasoning_content) tuple
        if isinstance(topic_response, tuple):
            topics = topic_response[0].split(",")
        else:
            topics = topic_response.split(",")
        compressed_memory = set()
        for topic in topics:
            topic_what_prompt = topic_what(input_text, topic)
            topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
            compressed_memory.add((topic.strip(), topic_what_response[0]))  # store (topic, memory) tuples
        return compressed_memory
def segment_text(text):
    seg_text = list(jieba.cut(text))
    return seg_text


def find_topic(text, topic_num):
    prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
    return prompt


def topic_what(text, topic):
    prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
    return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
    # Configure fonts so Chinese labels render correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    G = memory_graph.G

    # Work on a copy so filtering does not touch the original graph
    H = G.copy()

    # Drop nodes with at most one memory or at most two connections
    nodes_to_remove = []
    for node in H.nodes():
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        degree = H.degree(node)
        if memory_count <= 1 or degree <= 2:
            nodes_to_remove.append(node)

    H.remove_nodes_from(nodes_to_remove)

    # Nothing left to draw after filtering
    if len(H.nodes()) == 0:
        print("过滤后没有符合条件的节点可显示")
        return

    # Persist the filtered graph locally in GML format
    nx.write_gml(H, "memory_graph.gml")

    # Color nodes by connection count or by memory count
    node_colors = []
    nodes = list(H.nodes())  # the nodes actually left in the graph

    if color_by_memory:
        # Count memories per node
        memory_counts = []
        for node in nodes:
            memory_items = H.nodes[node].get('memory_items', [])
            if isinstance(memory_items, list):
                count = len(memory_items)
            else:
                count = 1 if memory_items else 0
            memory_counts.append(count)
        max_memories = max(memory_counts) if memory_counts else 1

        for count in memory_counts:
            # Gradient from blue (few memories) to red (many memories)
            if max_memories > 0:
                intensity = min(1.0, count / max_memories)
                color = (intensity, 0, 1.0 - intensity)
            else:
                color = (0, 0, 1)  # no memories: plain blue
            node_colors.append(color)
    else:
        # Color by node degree instead
        max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
        for node in nodes:
            degree = H.degree(node)
            if max_degree > 0:
                red = min(1.0, degree / max_degree)
                blue = 1.0 - red
                color = (red, 0, blue)
            else:
                color = (0, 0, 1)
            node_colors.append(color)

    # Draw the graph
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(H, k=1, iterations=50)
    nx.draw(H, pos,
            with_labels=True,
            node_color=node_colors,
            node_size=2000,
            font_size=10,
            font_family='SimHei',
            font_weight='bold')

    title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
    plt.title(title, fontsize=16, fontfamily='SimHei')
    plt.show()
def main():
    # Initialize the database
    Database.initialize(
        host=os.getenv("MONGODB_HOST"),
        port=int(os.getenv("MONGODB_PORT")),
        db_name=os.getenv("DATABASE_NAME"),
        username=os.getenv("MONGODB_USERNAME"),
        password=os.getenv("MONGODB_PASSWORD"),
        auth_source=os.getenv("MONGODB_AUTH_SOURCE")
    )

    start_time = time.time()

    # Create the memory graph
    memory_graph = Memory_graph()
    # Load the persisted graph from the database
    memory_graph.load_graph_from_db()
    # Create the hippocampus
    hippocampus = Hippocampus(memory_graph)

    end_time = time.time()
    print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")

    # Build memories
    hippocampus.build_memory(chat_size=25)

    # Show the two visualization modes
    print("\n按连接数量着色的图谱:")
    visualize_graph(memory_graph, color_by_memory=False)

    print("\n按记忆数量着色的图谱:")
    visualize_graph(memory_graph, color_by_memory=True)

    # Interactive query loop over graph concepts
    while True:
        query = input("请输入新的查询概念(输入'退出'以结束):")
        if query.lower() == '退出':
            break
        items_list = memory_graph.get_related_item(query)
        if items_list:
            for memory_item in items_list:
                print(memory_item)
        else:
            print("未找到相关记忆。")

    # Interactive Q&A loop: extract topics from the question, then look them up
    while True:
        query = input("请输入问题:")
        if query.lower() == '退出':
            break

        topic_prompt = find_topic(query, 3)
        topic_response = hippocampus.llm_model.generate_response(topic_prompt)
        # generate_response may return a (content, reasoning_content) tuple
        if isinstance(topic_response, tuple):
            topics = topic_response[0].split(",")
        else:
            topics = topic_response.split(",")
        print(topics)

        for keyword in topics:
            items_list = memory_graph.get_related_item(keyword)
            if items_list:
                print(items_list)


if __name__ == "__main__":
    main()
805
src/plugins/memory_system/memory_manual_build.py
Normal file
@@ -0,0 +1,805 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
from dotenv import load_dotenv
import pymongo
from loguru import logger
from pathlib import Path
from snownlp import SnowNLP
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot")  # add the project root to the Python path
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
# Directory of the current file
current_dir = Path(__file__).resolve().parent
# Project root (three levels up)
project_root = current_dir.parent.parent.parent
# Path to the .env.dev file
env_path = project_root / ".env.dev"

# Load environment variables
if env_path.exists():
    logger.info(f"从 {env_path} 加载环境变量")
    load_dotenv(env_path)
else:
    logger.warning(f"未找到环境变量文件: {env_path}")
    logger.info("将使用默认配置")
# Note: this local definition shadows the Database imported above
class Database:
    _instance = None
    db = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        if not Database.db:
            Database.initialize(
                host=os.getenv("MONGODB_HOST"),
                port=int(os.getenv("MONGODB_PORT")),
                db_name=os.getenv("DATABASE_NAME"),
                username=os.getenv("MONGODB_USERNAME"),
                password=os.getenv("MONGODB_PASSWORD"),
                auth_source=os.getenv("MONGODB_AUTH_SOURCE")
            )

    @classmethod
    def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
        try:
            if username and password:
                uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
            else:
                uri = f"mongodb://{host}:{port}"

            client = pymongo.MongoClient(uri)
            cls.db = client[db_name]
            # Verify the connection
            client.server_info()
            logger.success("MongoDB连接成功!")

        except Exception as e:
            logger.error(f"初始化MongoDB失败: {str(e)}")
            raise
def calculate_information_content(text):
    """Compute the information content (Shannon entropy) of the text, in bits."""
    char_count = Counter(text)
    total_chars = len(text)

    entropy = 0
    for count in char_count.values():
        probability = count / total_chars
        entropy -= probability * math.log2(probability)

    return entropy
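# A quick sanity check of the entropy measure (a minimal sketch, not part of
# the original file): a string of one repeated character carries 0 bits per
# character, while a uniform distribution over 4 distinct characters carries
# exactly 2 bits per character.
#
#   assert calculate_information_content("aaaa") == 0.0
#   assert calculate_information_content("abcd") == 2.0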
def get_cloest_chat_from_db(db, length: int, timestamp: str):
    """Fetch the chat records closest to the given timestamp from the database."""
    chat_text = ''
    closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])

    if closest_record:
        closest_time = closest_record['time']
        group_id = closest_record['group_id']
        # Fetch up to `length` messages after that timestamp from the same group
        chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
        for record in chat_record:
            chat_text += record["detailed_plain_text"]
        return chat_text

    return ''
class Memory_graph:
    def __init__(self):
        self.G = nx.Graph()  # undirected networkx graph
        self.db = Database.get_instance()

    def connect_dot(self, concept1, concept2):
        # If the edge already exists, increase its strength
        if self.G.has_edge(concept1, concept2):
            self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
        else:
            # New edge: initialize strength to 1
            self.G.add_edge(concept1, concept2, strength=1)

    def add_dot(self, concept, memory):
        if concept in self.G:
            # Node already exists: append the new memory to its list
            if 'memory_items' in self.G.nodes[concept]:
                if not isinstance(self.G.nodes[concept]['memory_items'], list):
                    # Convert a bare item to a list first
                    self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
                self.G.nodes[concept]['memory_items'].append(memory)
            else:
                self.G.nodes[concept]['memory_items'] = [memory]
        else:
            # New node: create it with a fresh memory list
            self.G.add_node(concept, memory_items=[memory])

    def get_dot(self, concept):
        # Check whether the node exists in the graph
        if concept in self.G:
            # Fetch the node's data from the graph
            node_data = self.G.nodes[concept]
            return concept, node_data
        return None

    def get_related_item(self, topic, depth=1):
        if topic not in self.G:
            return [], []

        first_layer_items = []
        second_layer_items = []

        # Neighboring nodes
        neighbors = list(self.G.neighbors(topic))

        # Memory items stored on the queried node itself
        node_data = self.get_dot(topic)
        if node_data:
            concept, data = node_data
            if 'memory_items' in data:
                memory_items = data['memory_items']
                if isinstance(memory_items, list):
                    first_layer_items.extend(memory_items)
                else:
                    first_layer_items.append(memory_items)

        # Only collect second-layer memories when depth >= 2
        if depth >= 2:
            # Memory items from neighboring nodes
            for neighbor in neighbors:
                node_data = self.get_dot(neighbor)
                if node_data:
                    concept, data = node_data
                    if 'memory_items' in data:
                        memory_items = data['memory_items']
                        if isinstance(memory_items, list):
                            second_layer_items.extend(memory_items)
                        else:
                            second_layer_items.append(memory_items)

        return first_layer_items, second_layer_items

    @property
    def dots(self):
        # Return the (concept, node_data) pair for every node in the graph
        return [self.get_dot(node) for node in self.G.nodes()]
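# A minimal usage sketch (not part of the original file; assumes the MongoDB
# environment variables above are set, since Memory_graph() opens a connection):
#
#   graph = Memory_graph()
#   graph.add_dot("猫", "群友昨天发了一张猫的照片")
#   graph.add_dot("照片", "群友昨天发了一张猫的照片")
#   graph.connect_dot("猫", "照片")  # repeated calls raise the edge's 'strength'
#   first, second = graph.get_related_item("猫", depth=2)
#   # first  -> memories stored on the "猫" node itself
#   # second -> memories on neighboring nodes such as "照片"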
# Hippocampus: builds, forgets and merges memories in the graph
class Hippocampus:
    def __init__(self, memory_graph: Memory_graph):
        self.memory_graph = memory_graph
        self.llm_model = LLMModel()
        self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
    def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
        current_timestamp = datetime.datetime.now().timestamp()
        chat_text = []
        # Sampling windows -- near: last 1h, mid: 1-4h ago, far: 4-24h ago
        for _ in range(time_frequency.get('near')):
            random_time = current_timestamp - random.randint(1, 3600)  # random moment within the last hour
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('mid')):
            random_time = current_timestamp - random.randint(3600, 3600 * 4)  # random moment 1-4 hours ago
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        for _ in range(time_frequency.get('far')):
            random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)  # random moment 4-24 hours ago
            chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
            chat_text.append(chat_)
        return chat_text
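    # For example, with the default time_frequency the call below returns a
    # list of 9 sampled conversations (2 near + 4 mid + 3 far), each expanded
    # to up to chat_size consecutive messages by get_cloest_chat_from_db:
    #
    #   samples = hippocampus.get_memory_sample(chat_size=20)
    #   assert len(samples) == 2 + 4 + 3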
    def calculate_topic_num(self, text, compress_rate):
        """Estimate how many topics to extract from the text."""
        information_content = calculate_information_content(text)
        topic_by_length = text.count('\n') * compress_rate
        topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
        topic_num = int((topic_by_length + topic_by_information_content) / 2)
        print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
        return topic_num
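    # A worked example with hypothetical numbers: for a sample containing 40
    # message lines with an entropy of 5.0 bits at compress_rate = 0.15:
    #
    #   topic_by_length              = 40 * 0.15              = 6.0
    #   topic_by_information_content = min(5, int((5-3) * 2)) = 4
    #   topic_num                    = int((6.0 + 4) / 2)     = 5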
    async def memory_compress(self, input_text, compress_rate=0.1):
        print(input_text)

        # Determine the topics
        topic_num = self.calculate_topic_num(input_text, compress_rate)
        topics_response = await self.llm_model_small.generate_response_async(self.find_topic_llm(input_text, topic_num))
        topics = topics_response[0].split(",")
        print(f"话题: {topics}")

        # Create one summarization request per topic
        tasks = []
        for topic in topics:
            topic_what_prompt = self.topic_what(input_text, topic)
            # Create the coroutine for this topic
            task = self.llm_model_small.generate_response_async(topic_what_prompt)
            tasks.append((topic.strip(), task))

        # Await the requests
        compressed_memory = set()
        for topic, task in tasks:
            response = await task
            if response:
                compressed_memory.add((topic, response[0]))

        return compressed_memory
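    # Note: the loop above awaits each coroutine in turn, so the per-topic
    # requests run one after another. A minimal sketch of a concurrent variant
    # (an assumption, not the author's code) would use asyncio.gather, which
    # requires `import asyncio` at module level:
    #
    #   summaries = await asyncio.gather(*(
    #       self.llm_model_small.generate_response_async(self.topic_what(input_text, t))
    #       for t in topics))
    #   compressed_memory = {(t.strip(), r[0]) for t, r in zip(topics, summaries) if r}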
    async def operation_build_memory(self, chat_size=12):
        # How many samples to draw from each time window
        time_frequency = {'near': 1, 'mid': 2, 'far': 2}
        memory_sample = self.get_memory_sample(chat_size, time_frequency)

        for i, input_text in enumerate(memory_sample, 1):
            # Progress bar
            progress = (i / len(memory_sample)) * 100
            bar_length = 30
            filled_length = int(bar_length * i // len(memory_sample))
            bar = '█' * filled_length + '-' * (bar_length - filled_length)
            print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")

            if input_text:
                # Generate compressed memories as (topic, memory) tuples
                compress_rate = 0.15
                compressed_memory = await self.memory_compress(input_text, compress_rate)
                print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")

                # Insert the memories into the graph
                for topic, memory in compressed_memory:
                    # Materialize the jieba segmentation so it can be reused
                    topics = list(jieba.cut(topic))
                    print(f"\033[1;34m话题\033[0m: {topic}")
                    print(f"\033[1;34m分词结果\033[0m: {topics}")
                    print(f"\033[1;34m记忆\033[0m: {memory}")

                    # Fewer than two tokens: nothing to connect
                    if len(topics) < 2:
                        print("\033[1;31m分词结果少于2个词,跳过连接\033[0m")
                        # Still add the single node
                        for split_topic in topics:
                            self.memory_graph.add_dot(split_topic, memory)
                        continue

                    # Add all nodes first
                    for split_topic in topics:
                        print(f"\033[1;32m添加节点\033[0m: {split_topic}")
                        self.memory_graph.add_dot(split_topic, memory)

                    # Then connect the nodes pairwise
                    for a, split_topic in enumerate(topics):
                        for b, other_split_topic in enumerate(topics):
                            if a < b:  # connect each pair only once
                                print(f"\033[1;32m连接节点\033[0m: {split_topic} 和 {other_split_topic}")
                                self.memory_graph.connect_dot(split_topic, other_split_topic)
            else:
                print("空消息 跳过")

            # Sync to the database after each processed sample
            self.sync_memory_to_db_2()
    def sync_memory_from_db(self):
        """
        Sync the database into the in-memory graph.
        Clears the current graph, then reloads every node and edge from the database.
        """
        # Clear the current graph
        self.memory_graph.G.clear()

        # Load all nodes from the database
        nodes = self.memory_graph.db.db.graph_data.nodes.find()
        for node in nodes:
            concept = node['concept']
            memory_items = node.get('memory_items', [])
            # Ensure memory_items is a list
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            # Add the node to the graph
            self.memory_graph.G.add_node(concept, memory_items=memory_items)

        # Load all edges from the database
        edges = self.memory_graph.db.db.graph_data.edges.find()
        for edge in edges:
            source = edge['source']
            target = edge['target']
            strength = edge.get('strength', 1)  # default strength is 1
            # Only add the edge if both endpoints exist
            if source in self.memory_graph.G and target in self.memory_graph.G:
                self.memory_graph.G.add_edge(source, target, strength=strength)

        logger.success("从数据库同步记忆图谱完成")
    def calculate_node_hash(self, concept, memory_items):
        """Compute a fingerprint for a node."""
        if not isinstance(memory_items, list):
            memory_items = [memory_items] if memory_items else []
        # Sort the memory items so identical content yields an identical hash
        sorted_items = sorted(memory_items)
        # Combine the concept and its memory items into one fingerprint
        content = f"{concept}:{'|'.join(sorted_items)}"
        return hash(content)

    def calculate_edge_hash(self, source, target):
        """Compute a fingerprint for an edge."""
        # Sort the endpoints so the same edge always yields the same hash
        nodes = sorted([source, target])
        return hash(f"{nodes[0]}:{nodes[1]}")
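    # Caveat: Python's built-in hash() for strings is salted per process
    # (PYTHONHASHSEED), so fingerprints persisted to MongoDB will not match the
    # values recomputed in a later run, and every stored node will look
    # "changed" on the next sync. A stable alternative (a sketch, not the
    # original code) is a digest from hashlib:
    #
    #   import hashlib
    #
    #   def stable_hash(content: str) -> str:
    #       # md5 of the UTF-8 bytes is deterministic across processes
    #       return hashlib.md5(content.encode("utf-8")).hexdigest()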
    def sync_memory_to_db_2(self):
        """
        Diff the in-memory graph against the database and sync the differences.
        Fingerprints (hash values) are used to decide quickly whether an update is needed.
        """
        # Fetch all nodes from the database and all nodes from memory
        db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
        memory_nodes = list(self.memory_graph.G.nodes(data=True))

        # Index the database nodes by concept for fast lookup
        db_nodes_dict = {node['concept']: node for node in db_nodes}

        # Check and update nodes
        for concept, data in memory_nodes:
            memory_items = data.get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []

            # Fingerprint of the in-memory node
            memory_hash = self.calculate_node_hash(concept, memory_items)

            if concept not in db_nodes_dict:
                # Node missing from the database: insert it
                logger.info(f"添加新节点: {concept}")
                node_data = {
                    'concept': concept,
                    'memory_items': memory_items,
                    'hash': memory_hash
                }
                self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
            else:
                # Compare against the stored fingerprint
                db_node = db_nodes_dict[concept]
                db_hash = db_node.get('hash', None)

                # Update the node only if the fingerprints differ
                if db_hash != memory_hash:
                    logger.info(f"更新节点内容: {concept}")
                    self.memory_graph.db.db.graph_data.nodes.update_one(
                        {'concept': concept},
                        {'$set': {
                            'memory_items': memory_items,
                            'hash': memory_hash
                        }}
                    )

        # Delete database nodes that no longer exist in memory
        memory_concepts = set(node[0] for node in memory_nodes)
        for db_node in db_nodes:
            if db_node['concept'] not in memory_concepts:
                logger.info(f"删除多余节点: {db_node['concept']}")
                self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})

        # Handle the edges
        db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
        memory_edges = list(self.memory_graph.G.edges())

        # Index the database edges by (source, target) with their fingerprints
        db_edge_dict = {}
        for edge in db_edges:
            edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
            db_edge_dict[(edge['source'], edge['target'])] = {
                'hash': edge_hash,
                'num': edge.get('num', 1)
            }

        # Check and update edges
        for source, target in memory_edges:
            edge_hash = self.calculate_edge_hash(source, target)
            edge_key = (source, target)

            if edge_key not in db_edge_dict:
                # Insert the new edge
                logger.info(f"添加新边: {source} - {target}")
                edge_data = {
                    'source': source,
                    'target': target,
                    'num': 1,
                    'hash': edge_hash
                }
                self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
            else:
                # Update the edge if its fingerprint changed
                if db_edge_dict[edge_key]['hash'] != edge_hash:
                    logger.info(f"更新边: {source} - {target}")
                    self.memory_graph.db.db.graph_data.edges.update_one(
                        {'source': source, 'target': target},
                        {'$set': {'hash': edge_hash}}
                    )

        # Delete edges that no longer exist in memory
        memory_edge_set = set(memory_edges)
        for edge_key in db_edge_dict:
            if edge_key not in memory_edge_set:
                source, target = edge_key
                logger.info(f"删除多余边: {source} - {target}")
                self.memory_graph.db.db.graph_data.edges.delete_one({
                    'source': source,
                    'target': target
                })

        logger.success("完成记忆图谱与数据库的差异同步")
    def find_topic_llm(self, text, topic_num):
        prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
        return prompt

    def topic_what(self, text, topic):
        prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
        return prompt
    def remove_node_from_db(self, topic):
        """
        Delete the given node and all of its edges from the database.

        Args:
            topic: concept of the node to delete
        """
        # Delete the node itself
        self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
        # Delete every edge that touches the node
        self.memory_graph.db.db.graph_data.edges.delete_many({
            '$or': [
                {'source': topic},
                {'target': topic}
            ]
        })
    def forget_topic(self, topic):
        """
        Randomly delete one memory from the given topic; if the topic has no
        memories left, remove the topic node itself.
        Operates only on the in-memory graph, never on the database directly.

        Args:
            topic: topic to forget a memory from

        Returns:
            removed_item: the deleted memory item, or None if nothing was deleted
        """
        if topic not in self.memory_graph.G:
            return None

        # Node data of the topic
        node_data = self.memory_graph.G.nodes[topic]

        # If the node carries memory_items
        if 'memory_items' in node_data:
            memory_items = node_data['memory_items']

            # Ensure memory_items is a list
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []

            # If there is anything left to delete
            if memory_items:
                # Pick one memory item at random and remove it
                removed_item = random.choice(memory_items)
                memory_items.remove(removed_item)

                # Write the remaining items back
                if memory_items:
                    self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
                else:
                    # No memories left: drop the whole node
                    self.memory_graph.G.remove_node(topic)

                return removed_item

        return None
    async def operation_forget_topic(self, percentage=0.1):
        """
        Randomly pick a fraction of the graph's nodes and forget the ones that
        meet the forgetting conditions.

        Args:
            percentage: fraction of nodes to check, default 0.1 (10%)
        """
        # All nodes in the graph
        all_nodes = list(self.memory_graph.G.nodes())
        # Number of nodes to check
        check_count = max(1, int(len(all_nodes) * percentage))
        # Random sample of nodes
        nodes_to_check = random.sample(all_nodes, check_count)

        forgotten_nodes = []
        for node in nodes_to_check:
            # Degree of the node
            connections = self.memory_graph.G.degree(node)

            # Number of memory items on the node
            memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            content_count = len(memory_items)

            # Check connection strength
            weak_connections = True
            if connections > 1:  # only check strength when there is more than one connection
                for neighbor in self.memory_graph.G.neighbors(node):
                    strength = self.memory_graph.G[node][neighbor].get('strength', 1)
                    if strength > 2:
                        weak_connections = False
                        break

            # Forget if the node is weakly connected or holds little content
            if (connections <= 1 and weak_connections) or content_count <= 2:
                removed_item = self.forget_topic(node)
                if removed_item:
                    forgotten_nodes.append((node, removed_item))
                    logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")

        # Sync to the database
        if forgotten_nodes:
            self.sync_memory_to_db_2()
            logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
        else:
            logger.info("本次检查没有节点满足遗忘条件")
    async def merge_memory(self, topic):
        """
        Merge and compress the memories of the given topic.

        Args:
            topic: topic node to merge
        """
        # Memory items of the node
        memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
        if not isinstance(memory_items, list):
            memory_items = [memory_items] if memory_items else []

        # Not enough memories to merge
        if len(memory_items) < 10:
            return

        # Randomly pick 10 memories
        selected_memories = random.sample(memory_items, 10)

        # Join them into one text
        merged_text = "\n".join(selected_memories)
        print(f"\n[合并记忆] 话题: {topic}")
        print(f"选择的记忆:\n{merged_text}")

        # Compress them into new memories via memory_compress
        compressed_memories = await self.memory_compress(merged_text, 0.1)

        # Remove the selected memories from the original list
        for memory in selected_memories:
            memory_items.remove(memory)

        # Append the newly compressed memories
        for _, compressed_memory in compressed_memories:
            memory_items.append(compressed_memory)
            print(f"添加压缩记忆: {compressed_memory}")

        # Write the updated list back to the node
        self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
        print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
    async def operation_merge_memory(self, percentage=0.1):
        """
        Randomly check a fraction of the nodes and merge the memories of any
        node holding more than 100 items.

        Args:
            percentage: fraction of nodes to check, default 0.1 (10%)
        """
        # All nodes in the graph
        all_nodes = list(self.memory_graph.G.nodes())
        # Number of nodes to check
        check_count = max(1, int(len(all_nodes) * percentage))
        # Random sample of nodes
        nodes_to_check = random.sample(all_nodes, check_count)

        merged_nodes = []
        for node in nodes_to_check:
            # Number of memory items on the node
            memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []
            content_count = len(memory_items)

            # Merge when the node holds more than 100 items
            if content_count > 100:
                print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
                await self.merge_memory(node)
                merged_nodes.append(node)

        # Sync to the database
        if merged_nodes:
            self.sync_memory_to_db_2()
            print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
        else:
            print("\n本次检查没有需要合并的节点")
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
    # Configure fonts so Chinese labels render correctly
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    G = memory_graph.G

    # Work on a copy so the original graph is untouched
    H = G.copy()

    # Node sizes and colors
    node_colors = []
    node_sizes = []
    nodes = list(H.nodes())

    # Maximum memory count, used to normalize node sizes
    max_memories = 1
    for node in nodes:
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        max_memories = max(max_memories, memory_count)

    # Size and color for each node
    for node in nodes:
        # Node size scales with memory count
        memory_items = H.nodes[node].get('memory_items', [])
        memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
        # Square the ratio so differences stand out
        ratio = memory_count / max_memories
        size = 500 + 5000 * (ratio ** 2)
        node_sizes.append(size)

        # Node color scales with degree
        degree = H.degree(node)
        if degree >= 30:
            node_colors.append((1.0, 0, 0))  # bright red (#FF0000)
        else:
            # Map degrees 1-30 onto the 0-1 range
            color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
            # Blue-to-red gradient
            red = min(0.9, color_ratio)
            blue = max(0.0, 1.0 - color_ratio)
            node_colors.append((red, 0, blue))

    # Edge colors and opacity derived from edge strength
    edge_colors = []
    max_strength = 1

    # Find the maximum strength
    for (u, v) in H.edges():
        strength = H[u][v].get('strength', 1)
        max_strength = max(max_strength, strength)

    # Edge weights used by the layout
    edge_weights = {}

    # Opacity and weight for each edge
    for (u, v) in H.edges():
        strength = H[u][v].get('strength', 1)
        # Map strength onto an opacity range
        alpha = 0.02 + 0.55 * (strength / max_strength)
        # Uniform blue, varying only in opacity
        edge_colors.append((0, 0, 1, alpha))
        # Stronger edges pull their endpoints closer together
        edge_weights[(u, v)] = strength

    # Draw the graph
    plt.figure(figsize=(20, 16))  # larger canvas
    # Spring layout tuned via edge weights
    pos = nx.spring_layout(H,
                           k=2.0,              # stronger repulsion between nodes
                           iterations=100,     # more iterations
                           scale=2.0,          # larger layout scale
                           weight='strength')  # use the edges' strength attribute as weight

    nx.draw(H, pos,
            with_labels=True,
            node_color=node_colors,
            node_size=node_sizes,
            font_size=8,  # slightly smaller font
            font_family='SimHei',
            font_weight='bold',
            edge_color=edge_colors,
            width=1.5)  # uniform edge width

    title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近'
    plt.title(title, fontsize=16, fontfamily='SimHei')
    plt.show()
async def main():
    # Initialize the database
    logger.info("正在初始化数据库连接...")
    db = Database.get_instance()
    start_time = time.time()

    # Which operations to run in this session
    test_pare = {'do_build_memory': False, 'do_forget_topic': True, 'do_visualize_graph': True, 'do_query': False, 'do_merge_memory': True}

    # Create the memory graph
    memory_graph = Memory_graph()

    # Create the hippocampus
    hippocampus = Hippocampus(memory_graph)

    # Sync the graph from the database
    hippocampus.sync_memory_from_db()

    end_time = time.time()
    logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")

    # Build memories
    if test_pare['do_build_memory']:
        logger.info("开始构建记忆...")
        chat_size = 25
        await hippocampus.operation_build_memory(chat_size=chat_size)

        end_time = time.time()
        logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size}]\033[0m")

    if test_pare['do_forget_topic']:
        logger.info("开始遗忘记忆...")
        await hippocampus.operation_forget_topic(percentage=0.1)

        end_time = time.time()
        logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")

    if test_pare['do_merge_memory']:
        logger.info("开始合并记忆...")
        await hippocampus.operation_merge_memory(percentage=0.1)

        end_time = time.time()
        logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")

    if test_pare['do_visualize_graph']:
        # Show the tuned visualization
        logger.info("生成记忆图谱可视化...")
        print("\n生成优化后的记忆图谱:")
        visualize_graph_lite(memory_graph)

    if test_pare['do_query']:
        # Interactive query loop
        while True:
            query = input("\n请输入新的查询概念(输入'退出'以结束):")
            if query.lower() == '退出':
                break

            items_list = memory_graph.get_related_item(query)
            if items_list:
                first_layer, second_layer = items_list
                if first_layer:
                    print("\n直接相关的记忆:")
                    for item in first_layer:
                        print(f"- {item}")
                if second_layer:
                    print("\n间接相关的记忆:")
                    for item in second_layer:
                        print(f"- {item}")
            else:
                print("未找到相关记忆。")
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
@@ -2,28 +2,23 @@ import os
import requests
from typing import Tuple, Union
import time
from nonebot import get_driver
import aiohttp
import asyncio
from loguru import logger
from src.plugins.chat.config import BotConfig, global_config

driver = get_driver()
config = driver.config

class LLMModel:
    def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
    def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
        self.model_name = model_name
        self.params = kwargs
        self.api_key = config.siliconflow_key
        self.base_url = config.siliconflow_base_url
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

        if not self.api_key or not self.base_url:
            raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

        logger.info(f"API URL: {self.base_url}")  # log the base_url
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
@@ -47,7 +42,60 @@ class LLMModel:
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
|
||||
if response.status_code == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
result = response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"请求失败: {str(e)}")
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
**self.params
|
||||
}
|
||||
|
||||
# 发送请求到完整的 chat/completions 端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
@@ -63,15 +111,15 @@ class LLMModel:
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"请求失败: {str(e)}")
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(f"请求失败: {str(e)}")
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
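    # With max_retries = 3 and base_wait_time = 15, the exponential backoff
    # above computes wait_time = 15 * (2 ** retry), i.e. 15, 30, 60 seconds for
    # retry = 0, 1, 2 -- though after the final attempt fails the error is
    # returned immediately instead of sleeping, so only the 15s and 30s waits
    # actually occur.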
@@ -60,8 +60,15 @@ class LLM_request:

                        result = await response.json()
                        if "choices" in result and len(result["choices"]) > 0:
                            content = result["choices"][0]["message"]["content"]
                            reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
                            message = result["choices"][0]["message"]
                            content = message.get("content", "")
                            think_match = None
                            reasoning_content = message.get("reasoning_content", "")
                            if not reasoning_content:
                                # Some models return the reasoning inline inside <think> tags
                                think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
                                if think_match:
                                    reasoning_content = think_match.group(1).strip()
                                content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
                            return content, reasoning_content
                        return "没有返回结果", ""
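# A minimal sketch of what the <think> extraction above does (assuming `re` is
# imported in that module), applied to a sample response string:
#
#   import re
#   raw = "<think>先分析问题</think>好的,我来回答。"
#   reasoning = re.search(r'<think>(.*?)</think>', raw, re.DOTALL).group(1)
#   # reasoning -> "先分析问题"
#   content = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
#   # content -> "好的,我来回答。"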