v0.5.6 proudly released: Merge pull request #60 from SengokuCola/debug

v0.5.6 proudly released
This commit is contained in:
SengokuCola
2025-03-06 00:36:57 +08:00
committed by GitHub
46 changed files with 3515 additions and 2567 deletions

.env (26 lines changed)

@@ -1,26 +0,0 @@
# Do not modify these defaults; this file is indexed by the repository. Edit .env.prod instead.
ENVIRONMENT=prod
# HOST=127.0.0.1
# PORT=8080
# COMMAND_START=["/"]
# # Plugin configuration
# PLUGINS=["src2.plugins.chat"]
# # Default configuration
# MONGODB_HOST=127.0.0.1
# MONGODB_PORT=27017
# DATABASE_NAME=MegBot
# MONGODB_USERNAME = "" # empty by default
# MONGODB_PASSWORD = "" # empty by default
# MONGODB_AUTH_SOURCE = "" # empty by default
# #key and url
# CHAT_ANY_WHERE_KEY=
# SILICONFLOW_KEY=
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
# DEEP_SEEK_KEY=
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1


@@ -3,10 +3,11 @@ name: Docker Build and Push
on:
push:
branches:
- main # triggered on pushes to main
- main
- debug # newly added: also trigger on the debug branch
tags:
- 'v*' # triggered by tags starting with v, e.g. v1.0.0
workflow_dispatch: # allow manual triggering
- 'v*'
workflow_dispatch:
jobs:
build-and-push:
@@ -24,13 +25,24 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Determine Image Tags
id: tags
run: |
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
fi
- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
context: . # Docker build context path
file: ./Dockerfile # path to the Dockerfile
platforms: linux/amd64,linux/arm64 # also build for the ARM architecture
tags: |
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }}
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
push: true
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
tags: ${{ steps.tags.outputs.tags }}
push: true
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max
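
The Determine Image Tags step above encodes a small ref-to-tags mapping. A minimal Python sketch of the same branching logic, for reference only (the `user` placeholder stands in for `secrets.DOCKERHUB_USERNAME`):

```python
def image_tags(ref: str, ref_name: str, user: str = "user") -> str:
    """Mirror the workflow: version tags publish <ref_name>+latest,
    main publishes main+latest, debug publishes only a debug tag."""
    if ref.startswith("refs/tags/"):
        return f"{user}/maimbot:{ref_name},{user}/maimbot:latest"
    if ref == "refs/heads/main":
        return f"{user}/maimbot:main,{user}/maimbot:latest"
    if ref == "refs/heads/debug":
        return f"{user}/maimbot:debug"
    return ""  # any other ref falls through the if-chain and produces no tags

assert image_tags("refs/tags/v1.0.0", "v1.0.0") == "user/maimbot:v1.0.0,user/maimbot:latest"
```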

.gitignore (vendored, 7 lines changed)

@@ -3,15 +3,17 @@ mongodb/
NapCat.Framework.Windows.Once/
log/
/test
/src/test
message_queue_content.txt
message_queue_content.bat
message_queue_window.bat
message_queue_window.txt
queue_update.txt
memory_graph.gml
.env
.env.*
config/bot_config_dev.toml
config/bot_config.toml
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -183,3 +185,6 @@ cython_debug/
# PyPI configuration file
.pypirc
.env
# jieba
jieba.cache


@@ -32,7 +32,7 @@
> - QQ bots run the risk of being restricted; please do your own research and use with caution
> - Due to continuous iteration, some known or unknown bugs may exist
**Discussion group**: 766798517 (for development and suggestion discussions only)
**Discussion group**: 766798517 (for development and suggestion discussions only; please don't ask deployment questions in the group, I may not have time to reply and will prioritize the docs and code)
## 📚 Documentation
@@ -42,22 +42,22 @@
## 🎯 Features
### 💬 Chat
- Keyword-triggered proactive replies: identifies the topic of each message and speaks up when 麦麦 has stored that topic before (currently buggy, so topics are only detected, never stored)
- Keyword-triggered proactive replies: identifies the topic of each message and speaks up when 麦麦 has stored that topic before
- Replies when called by name: detecting "麦麦" triggers a reply (configurable)
- Uses the SiliconFlow API to generate replies, randomly picking among R1, V3, R1-distill and similar models; official-API support is planned
- Supports multiple models and custom multi-vendor configuration
- Dynamic prompt builder for more human-like replies
- Recognizes images, forwarded messages, and replied-to messages
- Typos and multi-part replies: 麦麦 can randomly produce typos, split replies across several messages, and reply to a specific message
### 😊 Stickers
- Sends a sticker matching the emotion of the reply: unpolished but usable
- Automatically steals group members' stickers (unpolished, temporarily disabled, currently buggy)
- Sends a sticker matching the emotion of the reply
- Automatically steals group members' stickers
### 📅 Schedule
- 麦麦 automatically generates a daily schedule for more human-like replies
### 🧠 Memory
- Summarizes and stores chat history for later recall; unfinished
- Summarizes and stores chat history for later recall; needs polishing
### 📚 Knowledge base
- Embedding-based knowledge base: drop in txt files manually and they are recognized automatically (finished, temporarily disabled)
@@ -66,25 +66,27 @@
- Builds a per-user "relationship" for personalized replies (currently only a very simple affinity value, WIP)
- Builds a per-group "group impression" for personalized replies (WIP)
## 🚧 Features in development
## Development plan (TODO list)
- Personality system (WIP)
- Group-atmosphere system (WIP)
- Image sending and forwarding (WIP)
- Humor and memes (WIP of a WIP)
- Let 麦麦 play Minecraft (WIP of a WIP of a WIP)
## Development plan (TODO list)
- GIF parsing and saving
- Mini-program share-link parsing
- Limit the length of the reasoning chain
- Fix known bugs
- Improve the documentation
- ~~Improve the documentation~~
- Fix forwarding
- Auto-generate and validate the config
- Use a logger instead of print
- A dedicated class for sending messages
- ~~Auto-generate and validate the config~~
- ~~Use a logger instead of print~~
- ~~A dedicated class for sending messages~~
- Improve the sticker-sending logic
- Self-generated reply logic, e.g. self-chosen reply direction and style
- Use truncated generation to speed up 麦麦's responses
- Improve the triggers for sending messages:
## 📌 Notes
I'm a complete programming layman coding with Cursor, so a lot of the code is a mess; please bear with me
@@ -99,4 +101,10 @@
Thanks, everyone!
[![Contributors](https://contributors-img.web.app/image?repo=SengokuCola/MaiMBot)](https://github.com/SengokuCola/MaiMBot/graphs/contributors)
<a href="https://github.com/SengokuCola/MaiMBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=SengokuCola/MaiMBot&time=true" />
</a>
## Stargazers over time
[![Stargazers over time](https://starchart.cc/SengokuCola/MaiMBot.svg?variant=adaptive)](https://starchart.cc/SengokuCola/MaiMBot)

bot.py (52 lines changed)

@@ -6,6 +6,7 @@ from loguru import logger
'''Easter egg'''
from colorama import init, Fore
init()
text = "Many years later, as he faced the firing squad, Zhang San would remember that distant afternoon in 2023 when he discussed artificial intelligence at a meeting"
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
@@ -15,25 +16,47 @@ for i, char in enumerate(text):
print(rainbow_text)
'''Easter egg'''
# First, load the base environment variables
# First-launch detection
if not os.path.exists("config/bot_config.toml") or not os.path.exists(".env"):
logger.info("bot_config.toml not found; copying from the template")
import shutil
shutil.copy("config/bot_config_template.toml", "config/bot_config.toml")
logger.info("Copy complete; edit the settings in config/bot_config.toml and .env.prod, then restart")
# Initialize .env; default ENVIRONMENT=prod
if not os.path.exists(".env"):
with open(".env", "w") as f:
f.write("ENVIRONMENT=prod")
# Check whether .env.prod exists
if not os.path.exists(".env.prod"):
logger.error(".env.prod not found")
shutil.copy("template.env", "./.env.prod")
# First, load the base environment variables from .env
if os.path.exists(".env"):
load_dotenv(".env")
logger.success("Loaded the base environment configuration")
else:
logger.error("Base environment file .env does not exist")
exit(1)
# Load the matching configuration based on ENVIRONMENT
env = os.getenv("ENVIRONMENT")
env_file = f".env.{env}"
if env_file == ".env.dev" and os.path.exists(env_file):
logger.success("Loading the development environment configuration")
load_dotenv(env_file, override=True) # override=True allows overwriting existing env vars
elif os.path.exists(".env.prod"):
logger.success("Loading the environment configuration")
# Load the matching configuration based on ENVIRONMENT
if os.getenv("ENVIRONMENT") == "prod":
logger.success("Loading the production environment configuration")
load_dotenv(".env.prod", override=True) # override=True allows overwriting existing env vars
elif os.getenv("ENVIRONMENT") == "dev":
logger.success("Loading the development environment configuration")
load_dotenv(".env.dev", override=True) # override=True allows overwriting existing env vars
elif os.path.exists(f".env.{os.getenv('ENVIRONMENT')}"):
logger.success(f"Loading the {os.getenv('ENVIRONMENT')} environment configuration")
load_dotenv(f".env.{os.getenv('ENVIRONMENT')}", override=True) # override=True allows overwriting existing env vars
else:
logger.error(f"The environment file {env_file} for {env} does not exist; set the ENVIRONMENT variable in .env to prod.")
logger.error(f"ENVIRONMENT misconfigured; check that the .env.{os.getenv('ENVIRONMENT')} file matching the ENVIRONMENT variable in .env exists")
exit(1)
# Check that the API key exists
if not os.getenv("SILICONFLOW_KEY"):
logger.error("Missing a required API KEY")
logger.error(f"Fill in at least SILICONFLOW_KEY in .env.{os.getenv('ENVIRONMENT')} and restart")
exit(1)
# Get all environment variables
@@ -57,5 +80,4 @@ driver.register_adapter(Adapter)
nonebot.load_plugins("src/plugins")
if __name__ == "__main__":
nonebot.run()
nonebot.run()
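
The loading order above is: .env first, then the ENVIRONMENT-specific file layered on top with override=True. A minimal sketch of that chain, assuming python-dotenv (the helper name is hypothetical):

```python
import os
from dotenv import load_dotenv

def load_env_chain() -> None:
    """Load .env, then layer .env.<ENVIRONMENT> over it (later values win)."""
    if not os.path.exists(".env"):
        raise SystemExit(".env missing")
    load_dotenv(".env")  # provides at least ENVIRONMENT=prod
    env_file = f".env.{os.getenv('ENVIRONMENT', 'prod')}"
    if not os.path.exists(env_file):
        raise SystemExit(f"{env_file} missing; check ENVIRONMENT in .env")
    load_dotenv(env_file, override=True)  # override=True shadows values from .env
```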


@@ -1,61 +0,0 @@
[bot]
qq = 123
nickname = "麦麦"
[message]
min_text_length = 2
max_context_size = 15
emoji_chance = 0.2
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow"
api_paid = true
model_r1_probability = 0.8
model_v3_probability = 0.1
model_r1_distill_probability = 0.1
[memory]
build_memory_interval = 300
[others]
enable_advance_output = true
[groups]
talk_allowed = [
123,
123,
]
talk_frequency_down = []
ban_user_id = []
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor]
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor]
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm]
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"


@@ -0,0 +1,98 @@
[bot]
qq = 123
nickname = "麦麦"
[personality]
prompt_personality = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # Tieba persona
"是一个女大学生,你有黑色头发,你会刷小红书" # Xiaohongshu persona
]
prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"
[message]
min_text_length = 2 # 麦麦 only replies to messages whose text is at least this long
max_context_size = 15 # how much context 麦麦 receives
emoji_chance = 0.2 # probability that 麦麦 uses a sticker
ban_words = [
# "403","张三"
]
[emoji]
check_interval = 120 # interval for checking stickers
register_interval = 10 # interval for registering stickers
[cq_code]
enable_pic_translate = false
[response]
model_r1_probability = 0.8 # probability that 麦麦 picks the R1 model for a reply
model_v3_probability = 0.1 # probability that 麦麦 picks the V3 model for a reply
model_r1_distill_probability = 0.1 # probability that 麦麦 picks the R1-distill model for a reply
[memory]
build_memory_interval = 300 # memory build interval, in seconds
forget_memory_interval = 300 # memory forget interval, in seconds
[others]
enable_advance_output = true # enable verbose output
enable_kuuki_read = true # enable the kuuki-reading (room-reading) feature
[groups]
talk_allowed = [
123,
123,
] # groups where replies are allowed
talk_frequency_down = [] # groups with reduced reply frequency
ban_user_id = [] # QQ ids that are never replied to
#V3
#name = "deepseek-chat"
#base_url = "DEEP_SEEK_BASE_URL"
#key = "DEEP_SEEK_KEY"
#R1
#name = "deepseek-reasoner"
#base_url = "DEEP_SEEK_BASE_URL"
#key = "DEEP_SEEK_KEY"
# With SiliconFlow the models below need no changes; for the official DeepSeek API, switch them to the macros defined in .env.prod; for a custom model, pick the slot with the most similar role and fill it in yourself
[model.llm_reasoning] #R1
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor] #R1 distill
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal] #V3
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor] #V2.5
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm] # vision
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.embedding] # embedding
name = "BAAI/bge-m3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
# Topic extraction: jieba and snownlp need no API, llm does
[topic]
topic_extract='snownlp' # only jieba, snownlp and llm are supported
[topic.llm_topic]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"


@@ -2,7 +2,9 @@
## Deployment
### 🐳 Docker deployment (recommended)
If you don't know what Docker is, find a tutorial or use manual deployment instead
### 🐳 Docker deployment (recommended, but not necessarily the latest build)
1. Fetch the configuration files:
```bash
@@ -25,9 +27,7 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
```bash
# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate # Linux
venv\Scripts\activate # Windows
# Install dependencies
pip install -r requirements.txt
```
@@ -41,33 +41,37 @@ pip install -r requirements.txt
- Add a reverse WS: `ws://localhost:8080/onebot/v11/ws`
4. **Configuration files**
- Copy and edit the environment config: `.env.prod`
- Copy and edit the bot config: `bot_config.toml`
- Edit the environment config file `.env.prod`
- Edit the bot config file `bot_config.toml`
5. **Start the service**
5. **Start the 麦麦 bot**
- Open a terminal and cd to the project directory
```bash
nb run
```
6. **Other components**
- `run_thingking.bat`: starts the visual reasoning interface (unfinished) and the message-queue preview
- `knowledge.bat`: loads the text documents under `/data/raw_info` into the database
- `run_thingking.bat`: starts the visual reasoning interface (unfinished)
- ~~`knowledge.bat`: loads the text documents under `/data/raw_info` into the database~~
- run knowledge.py directly to generate the knowledge base
## ⚙️ Configuration reference
### Environment configuration (.env.prod)
```ini
# API configuration (required)
# API configuration: define your keys and base_url here
# You may define keys from any other provider; this is fully customizable
SILICONFLOW_KEY=your_key
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=your_key
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
# Service configuration; keep the defaults if you don't know what this is
HOST=127.0.0.1
PORT=8080
# Database configuration; keep the defaults if you don't know what this is
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
@@ -80,19 +84,58 @@ qq = "your bot's QQ number"
nickname = "麦麦"
[message]
min_text_length = 2
max_context_size = 15
emoji_chance = 0.2
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow" # 或 "deepseek"
#现已移除deepseek或硅基流动选项可以直接切换分别配置任意模型
model_r1_probability = 0.8 #推理模型权重
model_v3_probability = 0.1 #非推理模型权重
model_r1_distill_probability = 0.1
[memory]
build_memory_interval = 300
[others]
enable_advance_output = false # enable verbose logging
enable_advance_output = true # enable verbose logging
[groups]
talk_allowed = [] # group ids where replies are allowed
talk_frequency_down = [] # group ids with reduced reply frequency
ban_user_id = [] # QQ ids that are never replied to
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor]
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor]
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm]
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
```
## ⚠️ Notes


@@ -1,6 +0,0 @@
@echo off
echo Finding and terminating all MongoDB processes...
taskkill /F /IM mongod.exe
taskkill /F /IM mongo.exe
echo MongoDB processes terminated
pause


@@ -1,3 +1,4 @@
chcp 65001
call conda activate niuniu
cd .

setup.py (new file, 11 lines changed)

@@ -0,0 +1,11 @@
from setuptools import setup, find_packages
setup(
name="maimai-bot",
version="0.1",
packages=find_packages(),
install_requires=[
'python-dotenv',
'pymongo',
],
)


@@ -24,4 +24,25 @@ class Database:
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
return cls._instance
# for testing
def get_random_group_messages(self, group_id: str, limit: int = 5):
# first, sample one random message
random_message = list(self.db.messages.aggregate([
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# then fetch the messages that follow it
subsequent_messages = list(self.db.messages.find({
"group_id": group_id,
"time": {"$gt": random_message["time"]}
}).sort("time", 1).limit(limit))
# merge the random message with the subsequent ones
messages = [random_message] + subsequent_messages
return messages
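
A quick usage sketch of this helper (group id hypothetical). Note it assumes the group has at least one stored message, since the $sample result is indexed with [0] directly:

```python
db = Database.get_instance()
msgs = db.get_random_group_messages(group_id="123456", limit=5)
for m in msgs:
    # field names follow the documents stored by the message storage layer
    print(m["time"], m.get("processed_plain_text", ""))
```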


@@ -0,0 +1,165 @@
from typing import Dict, List, Union, Optional, Any
import base64
import os
"""
OneBot v11 Message Segment Builder
This module provides classes for building message segments that conform to the
OneBot v11 standard. These segments can be used to construct complex messages
for sending through bots that implement the OneBot interface.
"""
class Segment:
"""Base class for all message segments."""
def __init__(self, type_: str, data: Dict[str, Any]):
self.type = type_
self.data = data
def to_dict(self) -> Dict[str, Any]:
"""Convert the segment to a dictionary format."""
return {
"type": self.type,
"data": self.data
}
class Text(Segment):
"""Text message segment."""
def __init__(self, text: str):
super().__init__("text", {"text": text})
class Face(Segment):
"""Face/emoji message segment."""
def __init__(self, face_id: int):
super().__init__("face", {"id": str(face_id)})
class Image(Segment):
"""Image message segment."""
@classmethod
def from_url(cls, url: str) -> 'Image':
"""Create an Image segment from a URL."""
return cls(url=url)
@classmethod
def from_path(cls, path: str) -> 'Image':
"""Create an Image segment from a file path."""
with open(path, 'rb') as f:
file_b64 = base64.b64encode(f.read()).decode('utf-8')
return cls(file=f"base64://{file_b64}")
def __init__(self, file: str = None, url: str = None, cache: bool = True):
data = {}
if file:
data["file"] = file
if url:
data["url"] = url
if not cache:
data["cache"] = "0"
super().__init__("image", data)
class At(Segment):
"""@Someone message segment."""
def __init__(self, user_id: Union[int, str]):
data = {"qq": str(user_id)}
super().__init__("at", data)
class Record(Segment):
"""Voice message segment."""
def __init__(self, file: str, magic: bool = False, cache: bool = True):
data = {"file": file}
if magic:
data["magic"] = "1"
if not cache:
data["cache"] = "0"
super().__init__("record", data)
class Video(Segment):
"""Video message segment."""
def __init__(self, file: str):
super().__init__("video", {"file": file})
class Reply(Segment):
"""Reply message segment."""
def __init__(self, message_id: int):
super().__init__("reply", {"id": str(message_id)})
class MessageBuilder:
"""Helper class for building complex messages."""
def __init__(self):
self.segments: List[Segment] = []
def text(self, text: str) -> 'MessageBuilder':
"""Add a text segment."""
self.segments.append(Text(text))
return self
def face(self, face_id: int) -> 'MessageBuilder':
"""Add a face/emoji segment."""
self.segments.append(Face(face_id))
return self
def image(self, file: str = None) -> 'MessageBuilder':
"""Add an image segment."""
self.segments.append(Image(file=file))
return self
def at(self, user_id: Union[int, str]) -> 'MessageBuilder':
"""Add an @someone segment."""
self.segments.append(At(user_id))
return self
def record(self, file: str, magic: bool = False) -> 'MessageBuilder':
"""Add a voice record segment."""
self.segments.append(Record(file, magic))
return self
def video(self, file: str) -> 'MessageBuilder':
"""Add a video segment."""
self.segments.append(Video(file))
return self
def reply(self, message_id: int) -> 'MessageBuilder':
"""Add a reply segment."""
self.segments.append(Reply(message_id))
return self
def build(self) -> List[Dict[str, Any]]:
"""Build the message into a list of segment dictionaries."""
return [segment.to_dict() for segment in self.segments]
'''Convenience functions
def text(content: str) -> Dict[str, Any]:
"""Create a text message segment."""
return Text(content).to_dict()
def image_url(url: str) -> Dict[str, Any]:
"""Create an image message segment from URL."""
return Image.from_url(url).to_dict()
def image_path(path: str) -> Dict[str, Any]:
"""Create an image message segment from file path."""
return Image.from_path(path).to_dict()
def at(user_id: Union[int, str]) -> Dict[str, Any]:
"""Create an @someone message segment."""
return At(user_id).to_dict()'''
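
A short usage sketch of the builder above, chaining segments into an OneBot v11 message payload:

```python
payload = (
    MessageBuilder()
    .reply(123456)    # quote an earlier message by id
    .at(10001)        # @ a group member
    .text(" hello!")  # plain text
    .face(178)        # a QQ face by id
    .build()          # -> list of {"type": ..., "data": ...} dicts
)
# the list can be passed as the message field of an OneBot send_group_msg call
```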


@@ -10,6 +10,10 @@ import random
from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager
from nonebot.rule import to_me
from .bot import chat_bot
from .emoji_manager import emoji_manager
import time
# Get the driver
@@ -17,12 +21,12 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
)
print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -30,8 +34,9 @@ print("\033[1;32m[初始化数据库完成]\033[0m")
# 导入其他模块
from .bot import ChatBot
from .emoji_manager import emoji_manager
from .message_send_control import message_sender
# from .message_send_control import message_sender
from .relationship_manager import relationship_manager
from .message_sender import message_manager,message_sender
from ..memory_system.memory import memory_graph,hippocampus
# Initialize the sticker manager
@@ -40,8 +45,8 @@ emoji_manager.initialize()
print(f"\033[1;32mWaking up {global_config.BOT_NICKNAME}......\033[0m")
# Create the bot instance
chat_bot = ChatBot()
# Register the message handler
group_msg = on_message()
# Register the message handler
group_msg = on_message(priority=5)
# Create the scheduled tasks
scheduler = require("nonebot_plugin_apscheduler").scheduler
@@ -66,10 +71,13 @@ async def init_relationships():
async def _(bot: Bot):
"""Handle a successful bot connection"""
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME} connected!-----------\033[0m")
message_sender.set_bot(bot)
asyncio.create_task(message_sender.start_processor(bot))
await willing_manager.ensure_started()
message_sender.set_bot(bot)
print("\033[1;38;5;208m-----------Message sender started!-----------\033[0m")
asyncio.create_task(message_manager.start_processor())
print("\033[1;38;5;208m-----------Message processor started!-----------\033[0m")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------Started stealing stickers!-----------\033[0m")
@@ -79,19 +87,27 @@ async def _(bot: Bot):
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
'''
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
async def monitor_relationships():
"""Print relationship data periodically"""
relationship_manager.print_all_relationships()
'''
# Add the build_memory scheduled task
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""Build memory at the interval set by build_memory_interval"""
print("\033[1;32m[Memory build]\033[0m building memory...")
await hippocampus.build_memory(chat_size=30)
print("\033[1;32m[Memory build]\033[0m memory build finished")
print("\033[1;32m[Memory build]\033[0m -------------------------------------------building memory-------------------------------------------")
start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time()
print(f"\033[1;32m[Memory build]\033[0m -------------------------------------------memory build finished, took {end_time - start_time:.2f}s-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""Forget memory at the interval set by forget_memory_interval (body currently disabled)"""
# print("\033[1;32m[Memory forget]\033[0m forgetting memory...")
# await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[Memory forget]\033[0m memory forgetting finished")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task():
"""Merge memory shortly after each build (body currently disabled)"""
# print("\033[1;32m[Memory merge]\033[0m merging memory...")
# await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[Memory merge]\033[0m memory merging finished")


@@ -1,21 +1,22 @@
from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot
from .message import Message,MessageSet
from .message import Message, MessageSet, Message_Sending
from .config import BotConfig, global_config
from .storage import MessageStorage
from .llm_generator import ResponseGenerator
from .message_stream import MessageStream, MessageStreamContainer
# from .message_stream import MessageStream, MessageStreamContainer
from .topic_identifier import topic_identifier
from random import random, choice
from .emoji_manager import emoji_manager # import the sticker manager
import time
import os
from .cq_code import CQCode # import the CQCode module
from .message_send_control import message_sender # import message send control
from .message_sender import message_manager # import the new message manager
from .message import Message_Thinking # import the Message_Thinking class
from .relationship_manager import relationship_manager
from .willing_manager import willing_manager # import the willingness manager
from .utils import is_mentioned_bot_in_txt, calculate_typing_time
from ..memory_system.memory import memory_graph
from loguru import logger
class ChatBot:
def __init__(self):
@@ -25,15 +26,12 @@ class ChatBot:
self._started = False
self.emoji_chance = 0.2 # base probability of sending a sticker
self.message_streams = MessageStreamContainer()
self.message_sender = message_sender
# self.message_streams = MessageStreamContainer()
async def _ensure_started(self):
"""Ensure all tasks have started"""
if not self._started:
# keep only the necessary tasks
self._started = True
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
@@ -44,60 +42,41 @@ class ChatBot:
if event.user_id in global_config.ban_user_id:
return
# Print the raw message content
'''
print(f"\n\033[1;33m[Message details]\033[0m")
# print(f"- raw message: {str(event.raw_message)}")
print(f"- post_type: {event.post_type}")
print(f"- sub_type: {event.sub_type}")
print(f"- user_id: {event.user_id}")
print(f"- message_type: {event.message_type}")
# print(f"- message_id: {event.message_id}")
# print(f"- message: {event.message}")
print(f"- original_message: {event.original_message}")
print(f"- raw_message: {event.raw_message}")
# print(f"- font: {event.font}")
print(f"- sender: {event.sender}")
# print(f"- to_me: {event.to_me}")
if event.reply:
print(f"\n\033[1;33m[Reply details]\033[0m")
# print(f"- message_id: {event.reply.message_id}")
print(f"- message_type: {event.reply.message_type}")
print(f"- sender: {event.reply.sender}")
# print(f"- time: {event.reply.time}")
print(f"- message: {event.reply.message}")
print(f"- raw_message: {event.reply.raw_message}")
# print(f"- original_message: {event.reply.original_message}")
'''
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
# print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
message = Message(
group_id=event.group_id,
user_id=event.user_id,
message_id=event.message_id,
user_cardname=sender_info['card'],
raw_message=str(event.original_message),
plain_text=event.get_plaintext(),
reply_message=event.reply,
)
# Banned-word filter
for word in global_config.ban_words:
if word in message.detailed_plain_text:
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[Banned word]\033[0m message contains {word}, filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[Topic]\033[0m topic: {topic}")
topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
# topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
# topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
# topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
logger.info(f"\033[1;32m[Topic]\033[0m using {global_config.topic_extract}, topic: {topic}")
all_num = 0
interested_num = 0
@@ -110,9 +89,7 @@ class ChatBot:
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
interested_rate = interested_num / all_num if all_num > 0 else 0
await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
reply_probability = willing_manager.change_reply_willing_received(
@@ -127,54 +104,71 @@ class ChatBot:
current_willing = willing_manager.get_willing(event.group_id)
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability:.1f}]\033[0m")
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
response = ""
# Create a thinking message
if random() < reply_probability:
tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message,message_id=think_id)
message_sender.send_temp_container.add_message(thinking_message)
message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id)
response, emotion = await self.gpt.generate_response(message)
# If a reply was generated, send and record it
'''
Post-generation handling
'''
# if response is None:
# thinking_message.interupt=True
if response:
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
# print(f"\033[1;32m[思考结束]\033[0m 思考结束,已得到回复,开始回复")
# 找到并删除对应的thinking消息
container = message_manager.get_container(event.group_id)
thinking_message = None
# 找到message,删除
for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
thinking_message = msg
container.messages.remove(msg)
print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break
# Record when thinking started, to avoid too long a gap between thinking and replying
thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # the send-message id matches the message_thinking that produced it
# Compute typing time: (1) to simulate typing, (2) to keep multi-part replies in order
accu_typing_time = 0
# print(f"\033[1;32m[Reply start]\033[0m loading reply 1 into the send container")
for msg in response:
print(f"当前消息: {msg}")
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
#通过时间改变时间戳
typing_time = calculate_typing_time(msg)
accu_typing_time += typing_time
timepoint = tinking_time_point+accu_typing_time
# print(f"\033[1;32m[调试]\033[0m 消息: {msg},添加!, 累计打字时间: {accu_typing_time:.2f}秒")
timepoint = tinking_time_point + accu_typing_time
bot_message = Message(
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=timepoint
time=timepoint, # records when the reply was generated
thinking_start_time=thinking_start_time, # records when thinking started
reply_message_id=message.message_id
)
message_set.add_message(bot_message)
message_sender.send_temp_container.update_thinking_message(message_set)
#message_set can be added to message_manager directly
print(f"\033[1;32m[Reply]\033[0m loading the reply into the send container")
message_manager.add_message(message_set)
bot_response_time = tinking_time_point
if random() < global_config.emoji_chance:
@@ -187,20 +181,24 @@ class ChatBot:
else:
bot_response_time = bot_response_time + 1
bot_message = Message(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
processed_plain_text=emoji_cq,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
is_emoji=True,
translate_cq=False
)
message_sender.send_temp_container.add_message(bot_message)
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
processed_plain_text=emoji_cq,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
)
message_manager.add_message(bot_message)
# When a new message comes in, raise the reply willingness
willing_manager.change_reply_willing_after_sent(event.group_id)
willing_manager.change_reply_willing_after_sent(event.group_id)
# Create the global ChatBot instance
chat_bot = ChatBot()
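
The timestamp staggering above gives each part of a multi-part reply a send time offset by the simulated typing time of everything before it, which both mimics typing and keeps the parts ordered. A minimal sketch of the idea (calculate_typing_time here is a stand-in for the project's helper; the per-character speed is an assumption):

```python
import time

def calculate_typing_time(text: str, chars_per_sec: float = 8.0) -> float:
    # stand-in for the project's helper: longer text takes longer to "type"
    return len(text) / chars_per_sec

def stagger(parts: list[str], start: float) -> list[tuple[float, str]]:
    """Assign each reply part a timestamp offset by cumulative typing time."""
    accu, out = 0.0, []
    for part in parts:
        accu += calculate_typing_time(part)
        out.append((start + accu, part))
    return out  # sorted by construction, so parts can never be sent out of order

print(stagger(["first part", "a much longer second part"], time.time()))
```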


@@ -1,8 +1,6 @@
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, Set
import os
from nonebot.log import logger, default_format
import logging
import configparser
import tomli
import sys
@@ -28,16 +26,24 @@ class BotConfig:
talk_frequency_down_groups = set()
ban_user_id = set()
build_memory_interval: int = 60 # memory build interval (seconds)
build_memory_interval: int = 30 # memory build interval (seconds)
forget_memory_interval: int = 300 # memory forget interval (seconds)
EMOJI_CHECK_INTERVAL: int = 120 # sticker check interval (minutes)
EMOJI_REGISTER_INTERVAL: int = 10 # sticker registration interval (minutes)
ban_words = set()
# Model configuration
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
embedding: Dict[str, str] = field(default_factory=lambda: {})
vlm: Dict[str, str] = field(default_factory=lambda: {})
# Topic extraction configuration
topic_extract: str = 'snownlp' # only jieba, snownlp, llm are supported
llm_topic_extract: Dict[str, str] = field(default_factory=lambda: {})
API_USING: str = "siliconflow" # which API to use
API_PAID: bool = False # whether to use the paid API
@@ -47,6 +53,13 @@ class BotConfig:
enable_advance_output: bool = False # enable verbose output
enable_kuuki_read: bool = True # enable the kuuki-reading (room-reading) feature
# Default persona
PROMPT_PERSONALITY=[
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书"
]
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
@staticmethod
def get_config_dir() -> str:
@@ -66,6 +79,15 @@ class BotConfig:
if os.path.exists(config_path):
with open(config_path, "rb") as f:
toml_dict = tomli.load(f)
if 'personality' in toml_dict:
personality_config=toml_dict['personality']
personality=personality_config.get('prompt_personality')
if len(personality) >= 2:
logger.info(f"Loading custom persona: {personality}")
config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY)
logger.info(f"Loading custom schedule prompt: {personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)
if "emoji" in toml_dict:
emoji_config = toml_dict["emoji"]
@@ -103,12 +125,25 @@ class BotConfig:
if "llm_normal" in model_config:
config.llm_normal = model_config["llm_normal"]
config.llm_topic_extract = config.llm_normal
if "llm_normal_minor" in model_config:
config.llm_normal_minor = model_config["llm_normal_minor"]
if "vlm" in model_config:
config.vlm = model_config["vlm"]
if "embedding" in model_config:
config.embedding = model_config["embedding"]
if 'topic' in toml_dict:
topic_config=toml_dict['topic']
if 'topic_extract' in topic_config:
config.topic_extract=topic_config.get('topic_extract',config.topic_extract)
logger.info(f"Loaded custom topic extraction: {config.topic_extract}")
if config.topic_extract=='llm' and 'llm_topic' in topic_config:
config.llm_topic_extract=topic_config['llm_topic']
logger.info(f"Loaded custom topic extraction model: {config.llm_topic_extract['name']}")
# Message configuration
if "message" in toml_dict:
@@ -116,10 +151,12 @@ class BotConfig:
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
config.ban_words=msg_config.get("ban_words",config.ban_words)
if "memory" in toml_dict:
memory_config = toml_dict["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
# Group configuration
if "groups" in toml_dict:
@@ -131,6 +168,7 @@ class BotConfig:
if "others" in toml_dict:
others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
logger.success(f"成功加载配置文件: {config_path}")
@@ -144,32 +182,14 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
logger.info("使用bot配置文件")
else:
logger.info("已找到开发环境配置文件")
logger.info("已找到开发bot配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""机器人配置类"""
# 基础配置
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
logger.remove()
pass


@@ -12,6 +12,8 @@ import base64
import shutil
import asyncio
import time
from PIL import Image
import io
from nonebot import get_driver
from ..chat.config import global_config
@@ -235,37 +237,107 @@ class EmojiManager:
except Exception as e:
print(f"\033[1;31m[Error]\033[0m failed to get tag: {str(e)}")
return "skip"
print(f"\033[1;32m[Debug]\033[0m using default tag: neutral")
return "skip" # default tag
async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]:
"""Compress an image and return it base64-encoded
Args:
image_path: path to the image file
target_size: target file size in bytes, default 0.8MB
Returns:
Optional[str]: base64-encoded image data on success, None on failure
"""
try:
file_size = os.path.getsize(image_path)
if file_size <= target_size:
# already under the target size; read and return as base64 directly
with open(image_path, 'rb') as f:
return base64.b64encode(f.read()).decode('utf-8')
# open the image
with Image.open(image_path) as img:
# original dimensions
original_width, original_height = img.size
# scale factor: file size grows roughly with pixel area, hence the square root
scale = min(1.0, (target_size / file_size) ** 0.5)
# new dimensions
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# in-memory buffer
output_buffer = io.BytesIO()
# for GIFs, process every frame
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
frames.append(new_frame)
# save to the buffer
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# static image
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# save to the buffer, keeping the original format
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# grab the compressed bytes and base64-encode them
compressed_data = output_buffer.getvalue()
print(f"\033[1;32m[OK]\033[0m compressed image: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
print(f"\033[1;31m[Error]\033[0m failed to compress image: {os.path.basename(image_path)}, error: {str(e)}")
return None
async def scan_new_emojis(self):
"""Scan for new stickers"""
try:
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# get all jpg files
files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')]
# get all supported image files
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename)
# skip files that are already registered
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
image_path = os.path.join(emoji_dir, filename)
# read the image data
with open(image_path, 'rb') as f:
image_data = f.read()
# convert the image to base64
image_base64 = base64.b64encode(image_data).decode('utf-8')
# compress the image and get its base64 encoding
image_base64 = await self._compress_image(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# get the sticker's emotion tag
tag = await self._get_emoji_tag(image_base64)
if not tag == "skip":
# prepare the database record
# prepare the database record
emoji_record = {
'filename': filename,
'path': image_path,
@@ -279,7 +351,6 @@ class EmojiManager:
print(f"tag: {tag}")
else:
print(f"\033[1;33m[Warning]\033[0m skipping sticker: {filename}")
except Exception as e:
print(f"\033[1;31m[Error]\033[0m failed to scan stickers: {str(e)}")


@@ -21,9 +21,9 @@ config = driver.config
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1
@@ -50,12 +50,19 @@ class ResponseGenerator:
model_response, emotion = await self._process_response(model_response)
if model_response:
print(f"'{model_response}' 获取到的情感标签为:{emotion}")
valuedict={
'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
}
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
return model_response, emotion
return None, []
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
"""Generate a reply with the given model"""
sender_name = message.user_nickname or f"用户{message.user_id}"
if message.user_cardname:
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
# Fetch the relationship value
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
@@ -70,25 +77,29 @@ class ResponseGenerator:
group_id=message.group_id
)
# Kuuki-read (room-reading) module
if global_config.enable_kuuki_read:
content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
print(f"\033[1;32m[Kuuki read]\033[0m result: {content_check}")
if 'yes' not in content_check.lower() and random.random() < 0.3:
self._save_to_db(
message=message,
sender_name=sender_name,
prompt=prompt,
prompt_check=prompt_check,
content="",
content_check=content_check,
reasoning_content="",
reasoning_content_check=reasoning_content_check
)
return None
# Kuuki-read module: logic simplified, disabled for now
# if global_config.enable_kuuki_read:
# content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
# print(f"\033[1;32m[Kuuki read]\033[0m result: {content_check}")
# if 'yes' not in content_check.lower() and random.random() < 0.3:
# self._save_to_db(
# message=message,
# sender_name=sender_name,
# prompt=prompt,
# prompt_check=prompt_check,
# content="",
# content_check=content_check,
# reasoning_content="",
# reasoning_content_check=reasoning_content_check
# )
# return None
# Generate the reply
content, reasoning_content = await model.generate_response(prompt)
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception as e:
print(f"Error while generating the reply: {e}")
return None
# Save to the database
self._save_to_db(
@@ -97,15 +108,17 @@ class ResponseGenerator:
prompt=prompt,
prompt_check=prompt_check,
content=content,
content_check=content_check if global_config.enable_kuuki_read else "",
# content_check=content_check if global_config.enable_kuuki_read else "",
reasoning_content=reasoning_content,
reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
)
return content
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
content: str, reasoning_content: str,):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
@@ -113,8 +126,8 @@ class ResponseGenerator:
'user': sender_name,
'message': message.processed_plain_text,
'model': self.current_model_type,
'reasoning_check': reasoning_content_check,
'response_check': content_check,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
@@ -129,9 +142,12 @@ class ResponseGenerator:
内容:{content}
输出:
'''
content, _ = await self.model_v3.generate_response(prompt)
return [content.strip()] if content else ["neutral"]
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
return [content]
else:
return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
@@ -145,4 +161,42 @@ class ResponseGenerator:
emotion_tags = await self._get_emotion_tags(content)
processed_response = process_llm_response(content)
return processed_response, emotion_tags
return processed_response, emotion_tags
class InitiativeMessageGenerate:
def __init__(self):
self.db = Database.get_instance()
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7
)
def gen_response(self, message: Message):
topic_select_prompt, dots_for_select, prompt_template = (
prompt_builder._build_initiative_prompt_select(message.group_id)
)
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
print(f"[DEBUG] {content_select} {reasoning}")
topics_list = [dot[0] for dot in dots_for_select]
if content_select:
if content_select in topics_list:
select_dot = dots_for_select[topics_list.index(content_select)]
else:
return None
else:
return None
prompt_check, memory = prompt_builder._build_initiative_prompt_check(
select_dot[1], prompt_template
)
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
print(f"[DEBUG] {content_check} {reasoning_check}")
if "yes" not in content_check.lower():
return None
prompt = prompt_builder._build_initiative_prompt(
select_dot, prompt_template, memory
)
content, reasoning = self.model_r1.generate_response(prompt)
print(f"[DEBUG] {content} {reasoning}")
return content
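
The valuedict mapping in generate_response turns the detected emotion of 麦麦's own reply into a relationship-value delta for the message sender. A compact sketch of that update rule:

```python
# emotion of the generated reply -> relationship delta applied to the sender
EMOTION_VALUE = {
    'happy': 0.5, 'angry': -1, 'sad': -0.5, 'surprised': 0.5,
    'disgusted': -1.5, 'fearful': -0.25, 'neutral': 0.25,
}

def relationship_delta(emotion_tags: list[str]) -> float:
    # only the first tag is used, matching the emotion[0] lookup above
    return EMOTION_VALUE.get(emotion_tags[0], 0.25) if emotion_tags else 0.25

assert relationship_delta(['disgusted']) == -1.5
```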


@@ -8,7 +8,7 @@ from ...common.database import Database
from PIL import Image
from .config import global_config
import urllib3
from .utils_user import get_user_nickname
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode
@@ -21,46 +21,46 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# It also defines two helper properties: keywords, which extracts the message's keywords, and is_plain_text, which tells whether the message is plain text.
@dataclass
class Message:
"""Message data class"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # group name
user_id: int = None
user_nickname: str = None # user nickname
group_name: str = None # group name
user_cardname: str=None # user's group card name
message_id: int = None
raw_message: str = None
plain_text: str = None
message_based_id: int = None
reply_message: Dict = None # stores the replied-to message
raw_message: str = None # raw message, including unparsed CQ codes
plain_text: str = None # plain text
message_segments: List[Dict] = None # stores the parsed message segments
processed_plain_text: str = None # stores the processed plain_text
detailed_plain_text: str = None # stores the detailed human-readable text
time: float = None
reply_message: Dict = None # stores the replied-to source message
is_emoji: bool = False # whether the message is a sticker
has_emoji: bool = False # whether it contains a sticker
translate_cq: bool = True # whether to translate CQ codes
reply_benefits: float = 0.0
type: str = 'received' # message type: received or send
def __post_init__(self):
if self.time is None:
self.time = int(time.time())
if not self.group_name:
self.group_name = get_groupname(self.group_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id)
if not self.group_name:
self.group_name = self.get_groupname(self.group_id)
if not self.user_cardname:
self.user_cardname=get_user_cardname(self.user_id)
if not self.processed_plain_text:
if self.raw_message:
@@ -71,24 +71,12 @@ class Message:
)
# Render the message as detailed human-readable text
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = self.user_nickname or f"用户{self.user_id}"
try:
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
except:
name = self.user_nickname or f"用户{self.user_id}"
content = self.processed_plain_text
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n"
def get_groupname(self, group_id: int) -> str:
if not group_id:
return "未知群"
group_id = int(group_id)
# use the database singleton
db = Database.get_instance()
# look up the group; print the query condition and result
query = {'group_id': group_id}
group = db.db.group_info.find_one(query)
if group:
return group.get('group_name')
else:
return f"{group_id}"
def parse_message_segments(self, message: str) -> List[CQCode]:
"""
@@ -159,49 +147,58 @@ class Message_Thinking:
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
self.message_id = message_id
# Thinking-state attributes
self.thinking_text = "正在思考..."
self.time = int(time.time())
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.interupt=False
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.time
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
@property
def processed_plain_text(self) -> str:
"""Return the processed text"""
return self.thinking_text
@dataclass
class Message_Sending(Message):
"""A message in the process of being sent"""
thinking_start_time: float = None # when thinking started
thinking_time: float = None # thinking duration
def __str__(self) -> str:
return f"[Thinking] group:{self.group_id} user:{self.user_nickname} time:{self.time} message id:{self.message_id}"
reply_message_id: int = None # id of the replied-to source message
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
return self.thinking_time
class MessageSet:
"""Message set: stores several related messages"""
"""Message set: stores several outgoing messages"""
def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id
self.user_id = user_id
self.message_id = message_id
self.messages: List[Message] = []
self.messages: List[Message_Sending] = [] # type annotation updated
self.time = round(time.time(), 2)
def add_message(self, message: Message) -> None:
"""Add a message to the set"""
def add_message(self, message: Message_Sending) -> None:
"""Add a message to the set; only Message_Sending is accepted"""
if not isinstance(message, Message_Sending):
raise TypeError("MessageSet only accepts Message_Sending messages")
self.messages.append(message)
# keep sorted by time
self.messages.sort(key=lambda x: x.time)
def get_message_by_index(self, index: int) -> Optional[Message]:
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
"""Get a message by index"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[Message]:
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
"""Get the message closest to the given time"""
if not self.messages:
return None
@@ -222,7 +219,7 @@ class MessageSet:
"""Clear all messages"""
self.messages.clear()
def remove_message(self, message: Message) -> bool:
def remove_message(self, message: Message_Sending) -> bool:
"""Remove the given message"""
if message in self.messages:
self.messages.remove(message)
@@ -236,8 +233,4 @@ class MessageSet:
return len(self.messages)
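
A short sketch of the new type guard in action: adding anything other than a Message_Sending to a MessageSet now fails fast instead of silently entering the send queue (ids hypothetical):

```python
# illustrative only: MessageSet's constructor just sets attributes, so this runs standalone
ms = MessageSet(group_id=123, user_id=10001, message_id="mt1714.25")
try:
    ms.add_message("not a Message_Sending")  # any non-Message_Sending value
except TypeError as e:
    print(e)  # MessageSet only accepts Message_Sending messages
```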


@@ -1,251 +0,0 @@
from typing import Union, List, Optional, Deque, Dict
from nonebot.adapters.onebot.v11 import Bot, MessageSegment
import asyncio
import random
import os
from .message import Message, Message_Thinking, MessageSet
from .cq_code import CQCode
from collections import deque
import time
from .storage import MessageStorage
from .config import global_config
from .cq_code import cq_code_tool
if os.name == "nt":
from .message_visualizer import message_visualizer
class SendTemp:
"""单个群组的临时消息队列管理器"""
def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id
self.max_size = max_size
self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
self.last_send_time = 0
def add(self, message: Message) -> None:
"""Insert a message into the queue in time order"""
if not self.messages:
self.messages.append(message)
return
# insert in time order
if message.time >= self.messages[-1].time:
self.messages.append(message)
return
# binary-search for the insertion point
messages_list = list(self.messages)
left, right = 0, len(messages_list)
while left < right:
mid = (left + right) // 2
if messages_list[mid].time < message.time:
left = mid + 1
else:
right = mid
# rebuild the queue, preserving time order
new_messages = deque(maxlen=self.max_size)
new_messages.extend(messages_list[:left])
new_messages.append(message)
new_messages.extend(messages_list[left:])
self.messages = new_messages
def get_earliest_message(self) -> Optional[Message]:
"""Pop the earliest message"""
message = self.messages.popleft() if self.messages else None
return message
def clear(self) -> None:
"""Clear the queue"""
self.messages.clear()
def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
"""Get all pending messages"""
if group_id is None:
return list(self.messages)
return [msg for msg in self.messages if msg.group_id == group_id]
def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
"""Peek at the next message without removing it"""
return self.messages[0] if self.messages else None
def has_messages(self) -> bool:
"""Whether there are pending messages"""
return bool(self.messages)
def count(self, group_id: Optional[int] = None) -> int:
"""Count the pending messages"""
if group_id is None:
return len(self.messages)
return len([msg for msg in self.messages if msg.group_id == group_id])
def get_last_send_time(self) -> float:
"""Get the last send time"""
return self.last_send_time
def update_send_time(self):
"""Update the last send time"""
self.last_send_time = time.time()
class SendTempContainer:
"""Container managing every group's message cache"""
def __init__(self):
self.temp_queues: Dict[int, SendTemp] = {}
def get_queue(self, group_id: int) -> SendTemp:
"""Get or create a group's message queue"""
if group_id not in self.temp_queues:
self.temp_queues[group_id] = SendTemp(group_id)
return self.temp_queues[group_id]
def add_message(self, message: Message) -> None:
"""Add a message to its group's queue"""
queue = self.get_queue(message.group_id)
queue.add(message)
def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
"""Get all pending messages for the given group"""
queue = self.get_queue(group_id)
return queue.get_all()
def has_messages(self, group_id: int) -> bool:
"""Whether the given group has pending messages"""
queue = self.get_queue(group_id)
return queue.has_messages()
def get_all_groups(self) -> List[int]:
"""Get all group ids that have pending messages"""
return list(self.temp_queues.keys())
def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
queue = self.get_queue(message_obj.group_id)
# find the matching message indices with a list comprehension
matching_indices = [
i for i, msg in enumerate(queue.messages)
if msg.message_id == message_obj.message_id
]
if not matching_indices:
return False
index = matching_indices[0] # first matching index
# convert the deque to a list for editing
messages = list(queue.messages)
# handle by message type
if isinstance(message_obj, MessageSet):
messages.pop(index)
# insert the new message set at the original position
for i, single_message in enumerate(message_obj.messages):
messages.insert(index + i, single_message)
# print(f"\033[1;34m[Debug]\033[0m added message {i+1} of the set: {single_message}")
else:
# replace the original message directly
messages[index] = message_obj
# print(f"\033[1;34m[Debug]\033[0m message updated: {message_obj}")
# rebuild the queue
queue.messages.clear()
for msg in messages:
queue.messages.append(msg)
return True
class MessageSendControl:
"""Message send controller"""
def __init__(self):
self.typing_speed = (0.1, 0.3) # per-character typing-time range (seconds)
self.message_interval = (0.5, 1) # interval range between consecutive messages (seconds)
self.max_retry = 3 # max retries
self.send_temp_container = SendTempContainer()
self._running = True
self._paused = False
self._current_bot = None
self.storage = MessageStorage() # storage instance
try:
message_visualizer.start()
except(NameError):
pass
def set_bot(self, bot: Bot):
"""Set the current bot instance"""
self._current_bot = bot
async def process_group_messages(self, group_id: int):
queue = self.send_temp_container.get_queue(group_id)
if queue.has_messages():
message = queue.peek_next()
# message-handling logic
if isinstance(message, Message_Thinking):
message.update_thinking_time()
thinking_time = message.thinking_time
if thinking_time < 90: # keep waiting while under the 90-second thinking limit
if int(thinking_time) % 15 == 0:
print(f"\033[1;34m[Debug]\033[0m message still thinking, {thinking_time:.1f}s so far")
return
else:
print(f"\033[1;34m[Debug]\033[0m thinking message timed out, removing")
queue.get_earliest_message() # remove the timed-out thinking message
return
elif isinstance(message, Message):
message = queue.get_earliest_message()
if message and message.processed_plain_text:
print(f"- group: {group_id} - content: {message.processed_plain_text}")
cost_time = round(time.time(), 2) - message.time
if cost_time > 40:
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
cur_time = time.time()
await self._current_bot.send_group_msg(
group_id=group_id,
message=str(message.processed_plain_text),
auto_escape=False
)
cost_time = round(time.time(), 2) - cur_time
print(f"\033[1;34m[Debug]\033[0m message send time: {cost_time}")
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
print(f"\033[1;32mgroup {group_id} message, user {global_config.BOT_NICKNAME}, time {current_time}:\033[0m {str(message.processed_plain_text)}")
if message.is_emoji:
message.processed_plain_text = "[表情包]"
await self.storage.store_message(message, None)
else:
await self.storage.store_message(message, None)
queue.update_send_time()
if queue.has_messages():
await asyncio.sleep(
random.uniform(
self.message_interval[0],
self.message_interval[1]
)
)
async def start_processor(self, bot: Bot):
"""Start the message processor"""
self._current_bot = bot
while self._running:
await asyncio.sleep(1.5)
tasks = []
for group_id in self.send_temp_container.get_all_groups():
tasks.append(self.process_group_messages(group_id))
# process every group's messages in parallel
await asyncio.gather(*tasks)
try:
message_visualizer.update_content(self.send_temp_container)
except(NameError):
pass
def set_typing_speed(self, min_speed: float, max_speed: float):
"""Set the typing-speed range"""
self.typing_speed = (min_speed, max_speed)
# create the global instance
message_sender = MessageSendControl()
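
The hand-rolled binary search in SendTemp.add duplicates what the standard library already provides; a minimal equivalent sketch using bisect on a plain list (a deque has no efficient mid-insert anyway):

```python
import bisect

class TimeOrderedQueue:
    """Keep messages sorted by .time using the stdlib instead of a manual binary search."""
    def __init__(self) -> None:
        self._items: list = []

    def add(self, message) -> None:
        # bisect.insort with key keeps the list sorted on message.time (Python 3.10+)
        bisect.insort(self._items, message, key=lambda m: m.time)
```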


@@ -0,0 +1,225 @@
from typing import Union, List, Optional, Dict
from collections import deque
from .message import Message, Message_Thinking, MessageSet, Message_Sending
import time
import asyncio
from nonebot.adapters.onebot.v11 import Bot
from .config import global_config
from .storage import MessageStorage
from .cq_code import cq_code_tool
import random
from .utils import calculate_typing_time
class Message_Sender:
"""Message sender"""
def __init__(self):
self.message_interval = (0.5, 1) # interval range between messages (seconds)
self.last_send_time = 0
self._current_bot = None
def set_bot(self, bot: Bot):
"""Set the current bot instance"""
self._current_bot = bot
async def send_group_message(
self,
group_id: int,
send_text: str,
auto_escape: bool = False,
reply_message_id: int = None,
at_user_id: int = None
) -> None:
if not self._current_bot:
raise RuntimeError("Bot未设置,请先调用set_bot方法设置bot实例")
message = send_text
# 如果需要回复
if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message
# 如果需要at
# if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message
typing_time = calculate_typing_time(message)
if typing_time > 10:
typing_time = 10
await asyncio.sleep(typing_time)
# 发送消息
try:
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer:
"""单个群的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id
self.max_size = max_size
self.messages = []
self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time()
timeout_messages = []
for msg in self.messages:
if isinstance(msg, Message_Sending):
if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg)
# 按thinking_start_time排序时间早的在前面
timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
"""获取thinking_start_time最早的消息对象"""
if not self.messages:
return None
earliest_time = float('inf')
earliest_message = None
for msg in self.messages:
msg_time = msg.thinking_start_time
if msg_time < earliest_time:
earliest_time = msg_time
earliest_message = msg
return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
"""添加消息到队列"""
print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
if isinstance(message, MessageSet):
for single_message in message.messages:
self.messages.append(single_message)
else:
self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False"""
try:
if message in self.messages:
self.messages.remove(message)
return True
return False
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}")
return False
def has_messages(self) -> bool:
"""检查是否有待发送的消息"""
return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
"""获取所有消息"""
return list(self.messages)
class MessageManager:
"""管理所有群的消息容器"""
def __init__(self):
self.containers: Dict[int, MessageContainer] = {}
self.storage = MessageStorage()
self._running = True
def get_container(self, group_id: int) -> MessageContainer:
"""获取或创建群的消息容器"""
if group_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id)
return self.containers[group_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
container = self.get_container(message.group_id)
container.add_message(message)
async def process_group_messages(self, group_id: int):
"""处理群消息"""
# if int(time.time() / 3) == time.time() / 3:
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id)
if container.has_messages():
#最早的对象,可能是思考消息,也可能是发送消息
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
#一个月后删了
if not message_earliest:
print(f"\033[1;34m[BUG如果出现这个说明有BUG3月4日留]\033[0m ")
return
#如果是思考消息
if isinstance(message_earliest, Message_Thinking):
#优先等待这条消息
message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}")
else:# 如果不是message_thinking就只能是message_sending
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
#直接发,等什么呢
if message_earliest.update_thinking_time() < 30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
#移除消息
if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None)
container.remove_message(message_earliest)
#获取并处理超时消息
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
if message_timeout:
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
for msg in message_timeout:
if msg == message_earliest:
continue # 跳过已经处理过的消息
try:
#发送
if msg.update_thinking_time() < 30:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
#如果是表情包,则替换为"[表情包]"
if msg.is_emoji:
msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None)
# 安全地移除消息
if not container.remove_message(msg):
print(f"\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}")
continue
async def start_processor(self):
"""启动消息处理器"""
while self._running:
await asyncio.sleep(1)
tasks = []
for group_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id))
await asyncio.gather(*tasks)
# 创建全局消息管理器实例
message_manager = MessageManager()
# 创建全局发送器实例
message_sender = Message_Sender()

View File

@@ -1,264 +0,0 @@
from typing import List, Optional, Dict
from .message import Message
import time
from collections import deque
from datetime import datetime, timedelta
import os
import json
import asyncio
class MessageStream:
"""单个群组的消息流容器"""
def __init__(self, group_id: int, max_size: int = 1000):
self.group_id = group_id
self.messages = deque(maxlen=max_size)
self.max_size = max_size
self.last_save_time = time.time()
# 确保日志目录存在
self.log_dir = os.path.join("log", str(self.group_id))
os.makedirs(self.log_dir, exist_ok=True)
# 启动自动保存任务
asyncio.create_task(self._auto_save())
async def _auto_save(self):
"""每30秒自动保存一次消息记录"""
while True:
await asyncio.sleep(30) # 等待30秒
await self.save_to_log()
async def save_to_log(self):
"""将消息保存到日志文件"""
try:
current_time = time.time()
# 只有有新消息时才保存
if not self.messages or self.last_save_time == current_time:
return
# 生成日志文件名 (使用当前日期)
date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")
# 获取需要保存的新消息
new_messages = [
msg for msg in self.messages
if msg.time > self.last_save_time
]
if not new_messages:
return
# 将消息转换为可序列化的格式
message_logs = []
for msg in new_messages:
message_logs.append({
"time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
"user_id": msg.user_id,
"user_nickname": msg.user_nickname,
"message_id": msg.message_id,
"raw_message": msg.raw_message,
"processed_text": msg.processed_plain_text
})
# 追加写入日志文件
with open(log_file, "a", encoding="utf-8") as f:
for log in message_logs:
f.write(json.dumps(log, ensure_ascii=False) + "\n")
self.last_save_time = current_time
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}")
def add_message(self, message: Message) -> None:
"""按时间顺序添加新消息到队列
使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。
Args:
message: Message对象,要添加的新消息
"""
# 空队列或消息应该添加到末尾的情况
if (not self.messages or
message.time >= self.messages[-1].time):
self.messages.append(message)
return
# 消息应该添加到开头的情况
if message.time <= self.messages[0].time:
self.messages.appendleft(message)
return
# 使用二分查找在现有队列中找到合适的插入位置
left, right = 0, len(self.messages) - 1
while left <= right:
mid = (left + right) // 2
if self.messages[mid].time < message.time:
left = mid + 1
else:
right = mid - 1
temp = list(self.messages)
temp.insert(left, message)
# 如果超出最大长度,移除多余的消息
if len(temp) > self.max_size:
temp = temp[-self.max_size:]
# 重建队列
self.messages = deque(temp, maxlen=self.max_size)
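# 补充示例: 上面的手写二分插入与标准库 bisect 的效果一致,
# 下面是可独立运行的最小演示(纯示意,用浮点时间戳代替 Message 对象):
import bisect

def demo_ordered_insert():
    times = [100.0, 200.0, 300.0]  # 已按时间升序
    bisect.insort(times, 250.0)  # 二分定位插入点并插入,保持有序
    assert times == [100.0, 200.0, 250.0, 300.0]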
async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
"""从数据库中获取最近的消息记录
Args:
count: 需要获取的消息数量
Returns:
List[Message]: 最近的消息列表
"""
try:
from ...common.database import Database
db = Database.get_instance()
# 从数据库中查询最近的消息
recent_messages = list(db.db.messages.find(
{"group_id": self.group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
).sort("time", -1).limit(count))
if not recent_messages:
return []
# 转换为 Message 对象
from .message import Message
messages = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
return list(reversed(messages)) # 返回按时间正序的消息
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}")
return []
def get_recent_messages(self, count: int = 10) -> List[Message]:
"""获取最近的n条消息从内存队列"""
print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录")
return list(self.messages)[-count:]
def get_messages_in_timerange(self,
start_time: Optional[float] = None,
end_time: Optional[float] = None) -> List[Message]:
"""获取时间范围内的消息"""
if start_time is None:
start_time = time.time() - 3600
if end_time is None:
end_time = time.time()
return [
msg for msg in self.messages
if start_time <= msg.time <= end_time
]
def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
"""获取特定用户的最近消息"""
user_messages = [msg for msg in self.messages if msg.user_id == user_id]
return user_messages[-count:]
def clear_old_messages(self, hours: int = 24) -> None:
"""清理旧消息"""
cutoff_time = time.time() - (hours * 3600)
self.messages = deque(
[msg for msg in self.messages if msg.time > cutoff_time],
maxlen=self.max_size
)
class MessageStreamContainer:
"""管理所有群组的消息流容器"""
def __init__(self, max_size: int = 1000):
self.streams: Dict[int, MessageStream] = {}
self.max_size = max_size
async def save_all_logs(self):
"""保存所有群组的消息日志"""
for stream in self.streams.values():
await stream.save_to_log()
def add_message(self, message: Message) -> None:
"""添加消息到对应群组的消息流"""
if not message.group_id:
return
if message.group_id not in self.streams:
self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)
self.streams[message.group_id].add_message(message)
def get_stream(self, group_id: int) -> Optional[MessageStream]:
"""获取特定群组的消息流"""
return self.streams.get(group_id)
def get_all_streams(self) -> Dict[int, MessageStream]:
"""获取所有群组的消息流"""
return self.streams
def clear_old_messages(self, hours: int = 24) -> None:
"""清理所有群组的旧消息"""
for stream in self.streams.values():
stream.clear_old_messages(hours)
def get_group_stats(self, group_id: int) -> Dict:
"""获取群组的消息统计信息"""
stream = self.streams.get(group_id)
if not stream:
return {
"total_messages": 0,
"unique_users": 0,
"active_hours": [],
"most_active_user": None
}
messages = stream.messages
user_counts = {}
hour_counts = {}
for msg in messages:
user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
hour = datetime.fromtimestamp(msg.time).hour
hour_counts[hour] = hour_counts.get(hour, 0) + 1
most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
active_hours = sorted(
hour_counts.items(),
key=lambda x: x[1],
reverse=True
)[:5]
return {
"total_messages": len(messages),
"unique_users": len(user_counts),
"active_hours": active_hours,
"most_active_user": most_active_user
}
# 创建全局实例
message_stream_container = MessageStreamContainer()
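# 补充示例: get_group_stats 中小时活跃度统计的核心逻辑,可独立运行验证
# (纯示意,用时间戳列表代替消息对象):
from collections import Counter
from datetime import datetime

def demo_active_hours(timestamps):
    hour_counts = Counter(datetime.fromtimestamp(t).hour for t in timestamps)
    # 按消息数降序取前5个小时,返回形如 [(21, 35), (20, 28), ...]
    return sorted(hour_counts.items(), key=lambda x: x[1], reverse=True)[:5]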

View File

@@ -1,138 +0,0 @@
import subprocess
import threading
import queue
import os
import time
from typing import Dict
from .message import Message_Thinking
class MessageVisualizer:
def __init__(self):
self.process = None
self.message_queue = queue.Queue()
self.is_running = False
self.content_file = "message_queue_content.txt"
def start(self):
if self.process is None:
# 创建用于显示的批处理文件
with open("message_queue_window.bat", "w", encoding="utf-8") as f:
f.write('@echo off\n')
f.write('chcp 65001\n') # 设置UTF-8编码
f.write('title Message Queue Visualizer\n')
f.write('echo Waiting for message queue updates...\n')
f.write(':loop\n')
f.write('if exist "queue_update.txt" (\n')
f.write(' type "queue_update.txt" > "message_queue_content.txt"\n')
f.write(' del "queue_update.txt"\n')
f.write(' cls\n')
f.write(' type "message_queue_content.txt"\n')
f.write(')\n')
f.write('timeout /t 1 /nobreak >nul\n')
f.write('goto loop\n')
# 清空内容文件
with open(self.content_file, "w", encoding="utf-8") as f:
f.write("")
# 启动新窗口
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
self.process = subprocess.Popen(
['cmd', '/c', 'start', 'message_queue_window.bat'],
shell=True,
startupinfo=startupinfo
)
self.is_running = True
# 启动处理线程
threading.Thread(target=self._process_messages, daemon=True).start()
def _process_messages(self):
while self.is_running:
try:
# 获取新消息
text = self.message_queue.get(timeout=1)
# 写入更新文件
with open("queue_update.txt", "w", encoding="utf-8") as f:
f.write(text)
except queue.Empty:
continue
except Exception as e:
print(f"处理队列可视化内容时出错: {e}")
def update_content(self, send_temp_container):
"""更新显示内容"""
if not self.is_running:
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
display_text = f"Message Queue Status - {current_time}\n"
display_text += "=" * 50 + "\n\n"
# 遍历所有群组的队列
for group_id, queue in send_temp_container.temp_queues.items():
display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n"
display_text += f"消息队列长度: {len(queue.messages)}\n"
display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
display_text += "\n消息队列内容:\n"
# 显示队列中的消息
if not queue.messages:
display_text += " [空队列]\n"
else:
for i, msg in enumerate(queue.messages):
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
display_text += f"\n--- 消息 {i+1} ---\n"
if isinstance(msg, Message_Thinking):
display_text += f"类型: \033[1;33m思考中消息\033[0m\n"
display_text += f"时间: {msg_time}\n"
display_text += f"消息ID: {msg.message_id}\n"
display_text += f"群组: {msg.group_id}\n"
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
display_text += f"内容: {msg.thinking_text}\n"
display_text += f"思考时间: {int(msg.thinking_time)}\n"
else:
display_text += f"类型: 普通消息\n"
display_text += f"时间: {msg_time}\n"
display_text += f"消息ID: {msg.message_id}\n"
display_text += f"群组: {msg.group_id}\n"
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
if hasattr(msg, 'is_emoji') and msg.is_emoji:
display_text += f"内容: [表情包消息]\n"
else:
# 显示原始消息和处理后的消息
display_text += f"原始内容: {msg.raw_message[:50]}...\n"
display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n"
if msg.reply_message:
display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n"
display_text += f"\n{'-' * 50}\n"
# 添加统计信息
display_text += "\n总体统计:\n"
display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n"
total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
display_text += f"总消息数: {total_messages}\n"
thinking_messages = sum(
sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
for q in send_temp_container.temp_queues.values()
)
display_text += f"思考中消息数: {thinking_messages}\n"
self.message_queue.put(display_text)
def stop(self):
self.is_running = False
if self.process:
self.process.terminate()
self.process = None
# 清理文件
for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
if os.path.exists(file):
os.remove(file)
# 创建全局单例
message_visualizer = MessageVisualizer()

View File

@@ -36,7 +36,9 @@ class PromptBuilder:
memory_prompt = ''
start_time = time.time() # 记录开始时间
topic = topic_identifier.identify_topic_jieba(message_txt)
# topic = await topic_identifier.identify_topic_llm(message_txt)
topic = topic_identifier.identify_topic_snownlp(message_txt)
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
all_first_layer_items = [] # 存储所有第一层记忆
@@ -64,15 +66,7 @@ class PromptBuilder:
if overlap:
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
overlapping_second_layer.update(overlap)
# 合并所有需要的记忆
# if all_first_layer_items:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
# if overlapping_second_layer:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
# 使用集合去重
# 从每个来源随机选择2条记忆(如果有的话)
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
@@ -147,15 +141,15 @@ class PromptBuilder:
is_bot_prompt = ''
#人格选择
personality=global_config.PROMPT_PERSONALITY
prompt_personality = ''
personality_choice = random.random()
if personality_choice < 4/6: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < 1: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一味迎合,尽量简短一些。{is_bot_prompt}
请你表达自己的见解和观点。可以有个性。'''
@@ -170,7 +164,7 @@ class PromptBuilder:
#额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
@@ -195,25 +189,69 @@ class PromptBuilder:
prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < 4/6: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < 1: # 第二种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response
def _build_initiative_prompt_select(self,group_id):
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题
all_nodes=memory_graph.dots
all_nodes=list(filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes)) # random.sample 需要序列,先物化为列表
nodes_for_select=random.sample(all_nodes,5)
topics=[info[0] for info in nodes_for_select]
infos=[info[1] for info in nodes_for_select]
#激活prompt构建
activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天。"
personality=global_config.PROMPT_PERSONALITY
prompt_personality = ''
personality_choice = random.random()
if personality_choice < 4/6: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}'''
elif personality_choice < 1: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}'''
topics_str=','.join(f'"{topic}"' for topic in topics)
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular=f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select,nodes_for_select,prompt_regular
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
memory=random.sample(selected_node['memory_items'],3)
memory='\n'.join(memory)
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,综合群内的氛围,如果认为应该发言请输出yes,否则输出no,请注意是决定是否需要发言,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
return prompt_for_check,memory
def _build_initiative_prompt(self,selected_node,prompt_regular,memory):
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
return prompt_for_initiative
def get_prompt_info(self,message:str,threshold:float):
related_info = ''
if len(message) > 10:
message_segments = [message[i:i+10] for i in range(0, len(message), 10)]
for segment in message_segments:
embedding = get_embedding(segment)
related_info += self.get_info_from_db(embedding,threshold=threshold)
else:
embedding = get_embedding(message)
related_info += self.get_info_from_db(embedding,threshold=threshold)
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = get_embedding(message)
related_info += self.get_info_from_db(embedding,threshold=threshold)
return related_info
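# 补充示例: 上面按每10个字符切段、逐段取 embedding 检索,切分逻辑可独立验证:
def demo_segments(message: str, size: int = 10):
    return [message[i:i + size] for i in range(0, len(message), size)]

assert demo_segments("abcdefghijklmnop") == ["abcdefghij", "klmnop"]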

View File

@@ -23,6 +23,7 @@ class MessageStorage:
"processed_plain_text": message.processed_plain_text,
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
@@ -37,6 +38,7 @@ class MessageStorage:
"processed_plain_text": '[表情包]',
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,

View File

@@ -0,0 +1,14 @@
#Broca's Area
# 功能:语言产生、语法处理和言语运动控制。
# 损伤后果:布洛卡失语症(表达困难,但理解保留)。
import time
class Thinking_Idea:
def __init__(self, message_id: str):
self.messages = [] # 消息列表集合
self.current_thoughts = [] # 当前思考内容列表
self.time = time.time() # 创建时间
self.id = str(int(time.time() * 1000)) # 使用时间戳生成唯一标识ID

View File

@@ -4,19 +4,20 @@ from .message import Message
import jieba
from nonebot import get_driver
from .config import global_config
from snownlp import SnowNLP
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=config.siliconflow_key,
base_url=config.siliconflow_base_url
)
self.llm_client = LLM_request(model=global_config.llm_topic_extract)
self.select=global_config.topic_extract
def identify_topic_llm(self, text: str) -> Optional[str]:
"""识别消息主题"""
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。
@@ -24,77 +25,42 @@ class TopicIdentifier:
消息内容:{text}"""
response = self.client.chat.completions.create(
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10
)
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_client.generate_response(prompt)
if not response or not response.choices:
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
if not topic:
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
return None
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
topic = response.choices[0].message.content.strip() if response.choices[0].message.content else None
if topic == "无主题":
return None
else:
# print(f"[主题分析结果]{text[:20]}... : {topic}")
split_topic = self.parse_topic(topic)
return split_topic
def parse_topic(self, topic: str) -> List[str]:
"""解析主题,返回主题列表"""
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return []
return [t.strip() for t in topic.split(",") if t.strip()]
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""使用jieba识别主题"""
words = jieba.lcut(text)
# 去除停用词和标点符号
stop_words = {
'的', '了', '是', '在', '我', '你', '他', '她', '它', '这', '那', '就', '都', '而',
'因为', '所以', '如果', '虽然', '一个', '没', '有', '和', '也', '又', '我们', '你们',
'他们', '着', '过', '吗', '呢', '吧', '啊', '呀', '嘛', '哦', '很', '太', '能', '会',
'要', '去', '来', '说', '给', '让', '被', '把', '还', '再', '才', '更', '最', '只',
'得', '地', '等', '什么', '怎么', '为什么', '怎样', '如何', '什么样', '这样', '那样', '这么',
'那么', '多少', '几', '谁', '哪里', '哪儿', '什么时候', '何时', '为何', '怎么办',
'怎么样', '这些', '那些', '一些', '一点', '一下', '一直', '一定', '一般', '一样',
'一会儿', '一边', '一起',
# 添加更多量词
'个', '只', '条', '张', '片', '块', '根', '支', '台', '辆', '艘', '架', '件',
'位', '名', '本', '颗', '粒', '座', '间', '层', '排', '组', '对', '双', '副',
'套', '次', '回', '遍', '趟', '场', '阵', '番', '顿', '群', '批', '些', '份',
# 添加更多介词
'按', '按照', '比', '朝', '趁', '比如', '除', '除了', '当', '对', '对于',
'根据', '关于', '跟', '将', '经', '经过', '靠', '离', '连', '通过',
'同', '往', '为', '为了', '围绕', '向', '以', '由于', '由', '自', '沿', '沿着',
'用', '依照', '因', '至', '因为', '与', '于', '从', '直到', '自从'
}
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
"""使用 SnowNLP 进行主题识别
# 过滤掉停用词和标点符号,只保留名词和动词
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
',', '。', '!', '?', '、', ';', ':', '“', '”', '‘', '’',
'(', ')', '【', '】', '《', '》', '…', '—', '·', '¥', '~',
'×', '+', '=', '-'
}:
filtered_words.append(word)
# 统计词频
word_freq = {}
for word in filtered_words:
word_freq[word] = word_freq.get(word, 0) + 1
# 按词频排序取前3个
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
top_words = [word for word, freq in sorted_words[:3]]
return top_words if top_words else None
Args:
text (str): 需要识别主题的文本
Returns:
Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None
"""
if not text or len(text.strip()) == 0:
return None
try:
s = SnowNLP(text)
# 提取前5个关键词作为主题
keywords = s.keywords(5)
return keywords if keywords else None
except Exception as e:
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
return None
topic_identifier = TopicIdentifier()
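# 用法示意(假设已安装 snownlp,其 keywords 基于 TextRank 抽取关键词):
from snownlp import SnowNLP

def demo_keywords(text: str):
    return SnowNLP(text).keywords(5)  # 与 identify_topic_snownlp 的核心调用一致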

View File

@@ -10,6 +10,7 @@ from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -37,7 +38,10 @@ def combine_messages(messages: List[Message]) -> str:
def db_message_to_str (message_dict: Dict) -> str:
print(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
try:
name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", ""))
except KeyError:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
print(f"result: {result}")
@@ -61,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return False
def get_embedding(text):
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
response = requests.request("POST", url, json=payload, headers=headers)
if response.status_code != 200:
print(f"API请求失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
return response.json()['data'][0]['embedding']
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
return llm.get_embedding_sync(text)
def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
@@ -87,13 +75,11 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2)
def calculate_information_content(text):
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
@@ -102,23 +88,37 @@ def calculate_information_content(text):
return entropy
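# 数值示例: 字符串 "aabb" 中 'a'、'b' 各占 1/2,
# 熵为 -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1.0 bit:
assert abs(calculate_information_content("aabb") - 1.0) < 1e-9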
def get_cloest_chat_from_db(db, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in chat_records:
# 检查当前记录的memorized值
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
# print(f"消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
return [] # 如果没有找到记录,返回空列表
print(f"消息已读取3次跳过")
return ''
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
@@ -135,14 +135,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit))
if not recent_messages:
@@ -152,16 +152,20 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
from .message import Message
message_objects = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
# 按时间正序排列
message_objects.reverse()

View File

@@ -6,32 +6,27 @@ import os
from ...common.database import Database
import zlib # 用于 CRC32
import base64
from .config import global_config
from nonebot import get_driver
from loguru import logger
driver = get_driver()
config = driver.config
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
if type == 'image':
return storage_compress_image(image_data, max_size)
elif type == 'emoji':
return storage_emoji(image_data)
else:
raise ValueError(f"不支持的图片类型: {type}")
def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩图片到指定大小单位KB并在数据库中记录图片信息
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args:
image_data: 图片字节数据
group_id: 群组ID
user_id: 用户ID
base64_data: base64编码的图片数据
max_size: 最大文件大小KB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
@@ -41,11 +36,11 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# 连接数据库
db = Database(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
@@ -55,14 +50,14 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return image_data
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return image_data
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
@@ -127,14 +122,16 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
return compressed_data
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
return image_data
return base64_data
def storage_emoji(image_data: bytes) -> bytes:
"""
@@ -215,4 +212,78 @@ def storage_image(image_data: bytes) -> bytes:
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
return image_data
def compress_base64_image_by_scale(base64_data: str, target_size: int = int(0.8 * 1024 * 1024)) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= target_size:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
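# 缩放比例说明: 文件体积大致与像素面积成正比,面积又与边长平方成正比,
# 因此边长因子取 sqrt(target_size / current_size),数值示意:
def demo_scale(current_size: int, target_size: int) -> float:
    return min(1.0, (target_size / current_size) ** 0.5)

# 2MB 压到 0.8MB 时缩放因子约 0.63,即宽高各缩到约 63%
assert abs(demo_scale(2 * 1024 * 1024, int(0.8 * 1024 * 1024)) - 0.632) < 0.01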

View File

@@ -5,4 +5,13 @@ def get_user_nickname(user_id: int) -> str:
if int(user_id) == int(global_config.BOT_QQ):
return global_config.BOT_NICKNAME
# print(user_id)
return relationship_manager.get_name(user_id)
return relationship_manager.get_name(user_id)
def get_user_cardname(user_id: int) -> str:
if int(user_id) == int(global_config.BOT_QQ):
return global_config.BOT_NICKNAME
# print(user_id)
return ''
def get_groupname(group_id: int) -> str:
return f"{group_id}"

View File

@@ -9,9 +9,8 @@ class WillingManager:
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(3)
await asyncio.sleep(5)
for group_id in self.group_reply_willing:
# 每5秒将回复意愿衰减40%(乘以0.6)
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
def get_willing(self, group_id: int) -> float:
@@ -26,13 +25,7 @@ class WillingManager:
"""改变指定群组的回复意愿并返回回复概率"""
current_willing = self.group_reply_willing.get(group_id, 0)
print(f"初始意愿: {current_willing}")
# if topic and current_willing < 1:
# current_willing += 0.2
# elif topic:
# current_willing += 0.05
# print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}")
@@ -44,23 +37,23 @@ class WillingManager:
current_willing *= 0.15
print(f"表情包, 当前意愿: {current_willing}")
if interested_rate > 0.6:
if interested_rate > 0.65:
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.45
current_willing += interested_rate-0.6
self.group_reply_willing[group_id] = min(current_willing, 3.0)
reply_probability = (current_willing - 0.5) * 2
reply_probability = max((current_willing - 0.55) * 1.9, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / 3.5
# if is_mentioned_bot and user_id == int(1026294844):
# reply_probability = 1
reply_probability = min(reply_probability, 1)
if reply_probability < 0:
reply_probability = 0
return reply_probability
def change_reply_willing_sent(self, group_id: int):
@@ -72,7 +65,7 @@ class WillingManager:
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""
@@ -82,4 +75,4 @@ class WillingManager:
self._started = True
# 创建全局实例
willing_manager = WillingManager()
willing_manager = WillingManager()
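# 数值示例: 按上面的公式,意愿为 1.0 时回复概率为 (1.0-0.55)*1.9 = 0.855,
# 意愿约达 1.08 时概率封顶为 1,降频群再除以 3.5(独立演示):
def demo_reply_probability(willing: float, in_down_group: bool = False) -> float:
    p = max((willing - 0.55) * 1.9, 0)
    if in_down_group:
        p = p / 3.5
    return min(p, 1)

assert abs(demo_reply_probability(1.0) - 0.855) < 1e-9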

View File

@@ -3,26 +3,28 @@ import sys
import numpy as np
import requests
import time
from nonebot import get_driver
driver = get_driver()
config = driver.config
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database
from src.plugins.chat.config import llm_config
# 加载根目录下的.env.dev文件
env_path = os.path.join(root_path, ".env.dev")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
# 直接配置数据库连接信息
from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host=os.getenv("MONGODB_HOST", "localhost"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
)
class KnowledgeLibrary:
@@ -30,6 +32,9 @@ class KnowledgeLibrary:
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
def _ensure_dirs(self):
"""确保必要的目录存在"""
@@ -44,7 +49,7 @@ class KnowledgeLibrary:
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
@@ -74,7 +79,7 @@ class KnowledgeLibrary:
content = f.read()
# 按600字符分段
segments = [content[i:i+300] for i in range(0, len(content), 300)]
segments = [content[i:i+600] for i in range(0, len(content), 600)]
# 处理每个分段
for segment in segments:

View File

@@ -2,7 +2,6 @@
import os
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
@@ -10,10 +9,19 @@ from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
from dotenv import load_dotenv
import sys
import asyncio
import aiohttp
from typing import Tuple
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class Memory_graph:
def __init__(self):
@@ -112,7 +120,11 @@ class Memory_graph:
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
except KeyError:
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
@@ -154,38 +166,32 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
# 创建LLM模型实例
memory_graph.load_graph_from_db()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
# visualize_graph(memory_graph, color_by_memory=False)
visualize_graph_lite(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
# visualize_graph(memory_graph, color_by_memory=True)
visualize_graph_lite(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
# 只显示一次优化后的图形
visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
for item in second_layer_items:
print(item)
else:
print("未找到相关记忆。")
@@ -255,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
node_size=200,
font_size=10,
font_family='SimHei',
font_weight='bold')
@@ -281,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 2 or degree <= 2:
if memory_count < 5 or degree < 2: # 过滤掉记忆少于5条或连接少于2条的节点
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
@@ -294,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
# 计算节点大小和颜色
node_colors = []
nodes = list(H.nodes()) # 获取图中实际的节点列表
node_sizes = []
nodes = list(H.nodes())
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
# 获取最大记忆数和最大度数用于归一化
max_memories = 1
max_degree = 1
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显
node_sizes.append(size)
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
# 红色分量随着度数增加而增加
red = min(1.0, degree / max_degree)
# 蓝色分量随着度数减少而增加
blue = 1.0 - red
color = (red, 0, blue)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold')
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.7)
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()

View File

@@ -1,74 +0,0 @@
import os
import requests
from typing import Tuple, Union
import time
from nonebot import get_driver
import aiohttp
import asyncio
from src.plugins.chat.config import BotConfig, global_config
driver = get_driver()
config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15 # 基础等待时间(秒)
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""

View File

@@ -9,15 +9,26 @@ import random
import time
from ..chat.config import global_config
from ...common.database import Database # 使用正确的导入语法
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
from ..models.utils_model import LLM_request
import math
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
@@ -38,9 +49,7 @@ class Memory_graph:
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
@@ -52,7 +61,6 @@ class Memory_graph:
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
@@ -69,7 +77,6 @@ class Memory_graph:
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
@@ -87,87 +94,59 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def save_graph_to_db(self):
# 保存节点
for node in self.G.nodes(data=True):
concept = node[0]
memory_items = node[1].get('memory_items', [])
def forget_topic(self, topic):
"""随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点"""
if topic not in self.G:
return None
# 查找是否存在同名节点
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# 如果存在,合并memory_items并去重
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# 合并并去重
all_items = list(set(existing_items + memory_items))
# 更新节点
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 获取话题节点数据
node_data = self.G.nodes[topic]
# 保存边
for edge in self.G.edges():
source, target = edge
# 如果节点存在memory_items
if 'memory_items' in node_data:
memory_items = node_data['memory_items']
# 查找是否存在同样的边
existing_edge = self.db.db.graph_data.edges.find_one({
'source': source,
'target': target
})
if existing_edge:
# 如果存在,增加num属性
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# 如果不存在,创建新边
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
# 如果有记忆项可以删除
if memory_items:
# 随机选择一个记忆项删除
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 更新节点的记忆项
if memory_items:
self.G.nodes[topic]['memory_items'] = memory_items
else:
# 如果没有记忆项了,删除整个节点
self.G.remove_node(topic)
return removed_item
return None
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5)
def calculate_node_hash(self, concept, memory_items):
"""计算节点的特征值"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
sorted_items = sorted(memory_items)
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""计算边的特征值"""
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
@@ -175,82 +154,340 @@ class Hippocampus:
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 按近期频率采样
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 按中期频率采样
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 按长期频率采样
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
return [text for text in chat_text if text]
async def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = await self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
compressed_memory = set()
async def memory_compress(self, input_text, compress_rate=0.1):
print(input_text)
#获取topics
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
# 修改话题处理逻辑
print(f"话题: {topics_response[0]}")
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
print(f"话题: {topics}")
# 创建所有话题的请求任务
tasks = []
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
topic_what_prompt = self.topic_what(input_text, topic)
# 创建异步任务
task = self.llm_model_summary.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# 等待所有任务完成
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
async def build_memory(self,chat_size=12):
#最近消息获取频率
time_frequency = {'near':1,'mid':2,'far':2}
def calculate_topic_num(self,text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count('\n')*compress_rate
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content)/2)
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
return topic_num
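For intuition, here is a small worked example of the heuristic above (a minimal sketch; the numbers are made up, and the entropy value stands in for what calculate_information_content would return):

# Hypothetical inputs: a 20-line chat sample with entropy 4.2 bits, compress_rate 0.1
information_content = 4.2
line_count = 20
compress_rate = 0.1

topic_by_length = line_count * compress_rate                              # 2.0
topic_by_information = max(1, min(5, int((information_content - 3) * 2))) # 2
topic_num = int((topic_by_length + topic_by_information) / 2)             # 2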
async def operation_build_memory(self,chat_size=20):
# 最近消息获取频率
time_frequency = {'near':2,'mid':4,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
for i, input_text in enumerate(memory_sample, 1):
#加载进度可视化
# 加载进度可视化
all_topics = []
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = await self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set()
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_to_db(self):
"""检查并同步内存中的图结构与数据库"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node['concept']: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
node_data = {
'concept': concept,
'memory_items': memory_items,
'hash': memory_hash
}
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None)
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
'hash': memory_hash
}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
db_edge_dict[(edge['source'], edge['target'])] = {
'hash': edge_hash,
'strength': edge.get('strength', 1)
}
# 检查并更新边
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1)
if edge_key not in db_edge_dict:
# 添加新边
edge_data = {
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash
}
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
'strength': strength
}}
)
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
self.memory_graph.db.db.graph_data.edges.delete_one({
'source': source,
'target': target
})
def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
print(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
print("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(merged_text, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
def find_topic_llm(self,text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt
def topic_what(self,text, topic):
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
from nonebot import get_driver
driver = get_driver()
@@ -259,19 +496,19 @@ config = driver.config
start_time = time.time()
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
#创建记忆图
memory_graph = Memory_graph()
#加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
#创建海马体
hippocampus = Hippocampus(memory_graph)
#从数据库加载记忆图
hippocampus.sync_memory_from_db()
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")

View File

@@ -1,459 +0,0 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: float):
"""从数据库中获取最接近指定时间戳的聊天记录"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
return chat_text
return ''
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
# print(f"第二层: {neighbor}")
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: float):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
if record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return '' # 如果没有找到记录,返回空字符串,与上面的返回类型保持一致
def save_graph_to_db(self):
# 保存节点
for node in self.G.nodes(data=True):
concept = node[0]
memory_items = node[1].get('memory_items', [])
# 查找是否存在同名节点
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
if existing_node:
# 如果存在,合并memory_items并去重
existing_items = existing_node.get('memory_items', [])
if not isinstance(existing_items, list):
existing_items = [existing_items] if existing_items else []
# 合并并去重
all_items = list(set(existing_items + memory_items))
# 更新节点
self.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {'memory_items': all_items}}
)
else:
# 如果不存在,创建新节点
node_data = {
'concept': concept,
'memory_items': memory_items
}
self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边
for edge in self.G.edges():
source, target = edge
# 查找是否存在同样的边
existing_edge = self.db.db.graph_data.edges.find_one({
'source': source,
'target': target
})
if existing_edge:
# 如果存在,增加num属性
num = existing_edge.get('num', 1) + 1
self.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'num': num}}
)
else:
# 如果不存在,创建新边
edge_data = {
'source': source,
'target': target,
'num': 1
}
self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 按 near 配置的次数采样
random_time = current_timestamp - random.randint(1, 3600) # 1小时内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 按 mid 配置的次数采样
random_time = current_timestamp - random.randint(3600, 3600*4) # 1~4小时内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 按 far 配置的次数采样
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 4~24小时内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
def build_memory(self,chat_size=12):
#最近消息获取频率
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
#加载进度可视化
for i, input_text in enumerate(memory_sample, 1):
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# print(f"第{i}条消息: {input_text}")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
else:
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 1 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
start_time = time.time()
# 创建记忆图
memory_graph = Memory_graph()
# 加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
hippocampus.build_memory(chat_size=25)
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# 交互式查询
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
while True:
query = input("请输入问题:")
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,786 @@
# -*- coding: utf-8 -*-
import sys
import jieba
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
import os
from dotenv import load_dotenv
import pymongo
from loguru import logger
from pathlib import Path
from snownlp import SnowNLP
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
# 获取当前文件的目录
current_dir = Path(__file__).resolve().parent
# 获取项目根目录(上三层目录)
project_root = current_dir.parent.parent.parent
# env.dev文件路径
env_path = project_root / ".env.dev"
# 加载环境变量
if env_path.exists():
logger.info(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
total_chars = len(text)
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
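For reference, a self-contained sketch of the character-level Shannon entropy this function computes (the sample strings are illustrative):

from collections import Counter
import math

def entropy(text: str) -> float:
    counts = Counter(text)
    total = len(text)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

print(entropy("aaaa"))  # 0.0 bits: one repeated character carries no information
print(entropy("abcd"))  # 2.0 bits: four equally likely characters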
def get_cloest_chat_from_db(db, length: int, timestamp: float):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in chat_records:
# 检查当前记录的memorized值
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
print(f"消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
print(f"消息已读取3次跳过")
return ''
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
else:
# 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list):
# 如果当前不是列表,将其转换为列表
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
self.G.nodes[concept]['memory_items'].append(memory)
else:
self.G.nodes[concept]['memory_items'] = [memory]
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
concept, data = node_data
if 'memory_items' in data:
memory_items = data['memory_items']
if isinstance(memory_items, list):
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期:4小时内 中期:4~24小时 长期:1~7天
for _ in range(time_frequency.get('near')): # 按 near 配置的次数采样
random_time = current_timestamp - random.randint(1, 3600*4) # 4小时内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 按 mid 配置的次数采样
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 4~24小时内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 按 far 配置的次数采样
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) # 1~7天内的随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return [chat for chat in chat_text if chat]
def calculate_topic_num(self,text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count('\n')*compress_rate
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content)/2)
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
return topic_num
async def memory_compress(self, input_text, compress_rate=0.1):
print(input_text)
#获取topics
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num))
# 修改话题处理逻辑
topics = [topic.strip() for topic in topics_response[0].replace("、", ",").replace(",", ",").replace(" ", ",").split(",") if topic.strip()]
print(f"话题: {topics}")
# 创建所有话题的请求任务
tasks = []
for topic in topics:
topic_what_prompt = self.topic_what(input_text, topic)
# 创建异步任务
task = self.llm_model_small.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
# 等待所有任务完成
compressed_memory = set()
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
return compressed_memory
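Note that the coroutines collected above are only scheduled when awaited, so the summary requests effectively run one after another. If real concurrency is wanted, asyncio.gather could fan them out; a hypothetical refactor (not the author's code) might look like:

import asyncio

async def summarize_topics(self, input_text, topics):
    # Fire all summary requests concurrently instead of awaiting them in order.
    coros = [
        self.llm_model_small.generate_response_async(self.topic_what(input_text, t))
        for t in topics
    ]
    responses = await asyncio.gather(*coros)
    return {(t, r[0]) for t, r in zip(topics, responses) if r}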
async def operation_build_memory(self, chat_size=12):
# 最近消息获取频率
time_frequency = {'near': 3, 'mid': 8, 'far': 5}
memory_sample = self.get_memory_sample(chat_size, time_frequency)
for i, input_text in enumerate(memory_sample, 1):
# 加载进度可视化
all_topics = []
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set()
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
def sync_memory_from_db(self):
"""
从数据库同步数据到内存中的图结构
将清空当前内存中的图,并从数据库重新加载所有节点和边
"""
# 清空当前图
self.memory_graph.G.clear()
# 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
strength = edge.get('strength', 1) # 获取 strength默认为 1
# 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength)
logger.success("从数据库同步记忆图谱完成")
def calculate_node_hash(self, concept, memory_items):
"""
计算节点的特征值
"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 将记忆项排序以确保相同内容生成相同的哈希值
sorted_items = sorted(memory_items)
# 组合概念和记忆项生成特征值
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target):
"""
计算边的特征值
"""
# 对源节点和目标节点排序以确保相同的边生成相同的哈希值
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def sync_memory_to_db(self):
"""
检查并同步内存中的图结构与数据库
使用特征值(哈希值)快速判断是否需要更新
"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node['concept']: node for node in db_nodes}
# 检查并更新节点
for concept, data in memory_nodes:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
logger.info(f"添加新节点: {concept}")
node_data = {
'concept': concept,
'memory_items': memory_items,
'hash': memory_hash
}
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
else:
# 获取数据库中节点的特征值
db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None)
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
logger.info(f"更新节点内容: {concept}")
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
'hash': memory_hash
}}
)
# 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
logger.info(f"删除多余节点: {db_node['concept']}")
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典
db_edge_dict = {}
for edge in db_edges:
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
db_edge_dict[(edge['source'], edge['target'])] = {
'hash': edge_hash,
'strength': edge.get('strength', 1)
}
# 检查并更新边
for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1)
if edge_key not in db_edge_dict:
# 添加新边
logger.info(f"添加新边: {source} - {target}")
edge_data = {
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash
}
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash:
logger.info(f"更新边: {source} - {target}")
self.memory_graph.db.db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
'strength': strength
}}
)
# 删除多余的边
memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
logger.info(f"删除多余边: {source} - {target}")
self.memory_graph.db.db.graph_data.edges.delete_one({
'source': source,
'target': target
})
logger.success("完成记忆图谱与数据库的差异同步")
def find_topic_llm(self,text, topic_num):
# prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt
def topic_what(self,text, topic):
# prompt = f'这是一段文字:{text}。我想知道这段文字里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt
def remove_node_from_db(self, topic):
"""
从数据库中删除指定节点及其相关的边
Args:
topic: 要删除的节点概念
"""
# 删除节点
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
# 删除所有涉及该节点的边
self.memory_graph.db.db.graph_data.edges.delete_many({
'$or': [
{'source': topic},
{'target': topic}
]
})
def forget_topic(self, topic):
"""
随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点
只在内存中的图上操作,不直接与数据库交互
Args:
topic: 要删除记忆的话题
Returns:
removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None
"""
if topic not in self.memory_graph.G:
return None
# 获取话题节点数据
node_data = self.memory_graph.G.nodes[topic]
# 如果节点存在memory_items
if 'memory_items' in node_data:
memory_items = node_data['memory_items']
# 确保memory_items是列表
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果有记忆项可以删除
if memory_items:
# 随机选择一个记忆项删除
removed_item = random.choice(memory_items)
memory_items.remove(removed_item)
# 更新节点的记忆项
if memory_items:
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
else:
# 如果没有记忆项了,删除整个节点
self.memory_graph.G.remove_node(topic)
return removed_item
return None
async def operation_forget_topic(self, percentage=0.1):
"""
随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = []
for node in nodes_to_check:
# 获取节点的连接数
connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 检查连接强度
weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度
for neighbor in self.memory_graph.G.neighbors(node):
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
if strength > 2:
weak_connections = False
break
# 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.forget_topic(node)
if removed_item:
forgotten_nodes.append((node, removed_item))
logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库
if forgotten_nodes:
self.sync_memory_to_db()
logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else:
logger.info("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic):
"""
对指定话题的记忆进行合并压缩
Args:
topic: 要合并的话题节点
"""
# 获取节点的记忆项
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回
if len(memory_items) < 10:
return
# 随机选择10条记忆
selected_memories = random.sample(memory_items, 10)
# 拼接成文本
merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(merged_text, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
memory_items.remove(memory)
# 添加新的压缩记忆
for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1):
"""
随机检查一定比例的节点对内容数量超过100的节点进行记忆合并
Args:
percentage: 要检查的节点比例默认为0.110%
"""
# 获取所有节点
all_nodes = list(self.memory_graph.G.nodes())
# 计算要检查的节点数量
check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = []
for node in nodes_to_check:
# 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
# 如果内容数量超过100进行合并
if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node)
merged_nodes.append(node)
# 同步到数据库
if merged_nodes:
self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else:
print("\n本次检查没有需要合并的节点")
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数用于归一化节点大小
max_memories = 1
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
max_memories = max(max_memories, memory_count)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 400 + 2000 * (ratio ** 2) # 增大节点大小
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
if degree >= 30:
node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000)
else:
# 将1-10映射到0-1的范围
color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
# 使用蓝到红的渐变
red = min(0.9, color_ratio)
blue = max(0.0, 1.0 - color_ratio)
node_colors.append((red, 0, blue))
# 绘制图形
plt.figure(figsize=(16, 12)) # 减小图形尺寸
pos = nx.spring_layout(H,
k=1, # 调整节点间斥力
iterations=100, # 增加迭代次数
scale=1.5, # 减小布局尺寸
weight='strength') # 使用边的strength属性作为权重
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=12, # 保持增大的字体大小
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=1.5) # 统一的边宽度
title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变\n连接强度越大的节点距离越近'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare['do_build_memory']:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare['do_query']:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
if first_layer:
print("\n直接相关的记忆:")
for item in first_layer:
print(f"- {item}")
if second_layer:
print("\n间接相关的记忆:")
for item in second_layer:
print(f"- {item}")
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -0,0 +1,125 @@
import os
import requests
from typing import Tuple, Union
import time
import aiohttp
import asyncio
from loguru import logger
class LLMModel:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15 # 基础等待时间(秒)
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15
async with aiohttp.ClientSession() as session:
for retry in range(max_retries):
try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""

View File

@@ -2,20 +2,26 @@ import aiohttp
import asyncio
import requests
import time
import re
from typing import Tuple, Union
from nonebot import get_driver
from loguru import logger
from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
def __init__(self, model, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
self.base_url = getattr(config, model["base_url"])
except AttributeError as e:
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"]
self.params = kwargs
@@ -25,48 +31,62 @@ class LLM_request:
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
if response.status in [500, 503]:
logger.error(f"服务器错误: {response.status}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
logger.critical(f"请求失败: {str(e)}", exc_info=True)
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应"""
@@ -74,43 +94,116 @@ class LLM_request:
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
def build_request_data(img_base64: str):
return {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64}"
}
}
}
]
}
],
**self.params
}
]
}
],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
max_retries = 3
base_wait_time = 15
current_image_base64 = image_base64
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
for retry in range(max_retries):
try:
data = build_request_data(current_image_base64)
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
elif response.status == 413:
logger.warning("图片太大(413),尝试压缩...")
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}", exc_info=True)
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
base_wait_time = 15
async with aiohttp.ClientSession() as session:
for retry in range(max_retries):
try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
@@ -122,16 +215,20 @@ class LLM_request:
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""同步方法:根据输入的提示和图片生成模型的响应"""
@@ -139,7 +236,9 @@ class LLM_request:
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
image_base64=compress_base64_image_by_scale(image_base64)
# 构建请求体
data = {
"model": self.model_name,
@@ -162,38 +261,160 @@ class LLM_request:
],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
logger.critical(f"请求失败: {str(e)}", exc_info=True)
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""同步方法:获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量,如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status()
result = response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
time.sleep(wait_time)
else:
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
return None
logger.error("达到最大重试次数embedding请求仍然失败")
return None
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""异步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status()
result = await response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
return None
logger.error("达到最大重试次数embedding请求仍然失败")
return None
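The hunk above introduces two behaviours worth isolating: exponential backoff on HTTP 429 responses, and a fallback that recovers the model's reasoning from a <think> tag whenever the API omits reasoning_content. A minimal standalone sketch of both, assuming an OpenAI-style response shape (helper names are mine, not the project's):

import re
import time

import requests

def post_with_backoff(url: str, headers: dict, payload: dict,
                      max_retries: int = 2, base_wait: int = 6) -> dict:
    """POST with exponential backoff on 429, mirroring the retry loop above."""
    for attempt in range(max_retries):
        response = requests.post(url, headers=headers, json=payload, timeout=30)
        if response.status_code == 429 and attempt < max_retries - 1:
            time.sleep(base_wait * (2 ** attempt))  # waits 6s, then 12s, ...
            continue
        response.raise_for_status()  # a final 429 (or any other error) raises here
        return response.json()

def split_reasoning(message: dict) -> tuple[str, str]:
    """Prefer the reasoning_content field; otherwise pull it out of <think>...</think>."""
    content = message.get("content", "")
    reasoning = message.get("reasoning_content", "")
    if not reasoning:
        match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        if match:
            reasoning = match.group(1).strip()
            content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
    return content, reasoning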

View File

@@ -1,22 +1,23 @@
import datetime
import os
from typing import List, Dict
from typing import List, Dict, Union
from ...common.database import Database # 使用正确的导入语法
from src.plugins.chat.config import global_config
from nonebot import get_driver
from ..models.utils_model import LLM_request
from loguru import logger
import json
driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator:
@@ -42,8 +43,6 @@ class ScheduleGenerator:
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
if target_date is None:
target_date = datetime.datetime.now()
date_str = target_date.strftime("%Y-%m-%d")
weekday = target_date.strftime("%A")
@@ -59,15 +58,20 @@ class ScheduleGenerator:
elif read_only == False:
print(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书,请为我生成{date_str}{weekday})的日程安排,包括:
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:"""+\
"""
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床""""
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用JSON格式返回日程表仅返回内容不要返回注释时间采用24小时制,格式为{"时间": "活动","时间": "活动",...}"""
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
# print(self.schedule_text)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
else:
print(f"{date_str}的日程不存在。")
schedule_text = "忘了"
@@ -77,20 +81,15 @@ class ScheduleGenerator:
schedule_form = self._parse_schedule(schedule_text)
return schedule_text,schedule_form
def _parse_schedule(self, schedule_text: str) -> Dict[str, str]:
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
"""解析日程文本,转换为时间和活动的字典"""
schedule_dict = {}
# 按行分割日程文本
lines = schedule_text.strip().split('\n')
for line in lines:
# print(line)
if ',' in line:
# 假设格式为 "时间: 活动"
time_str, activity = line.split(',', 1)
# print(time_str)
# print(activity)
schedule_dict[time_str.strip()] = activity.strip()
return schedule_dict
try:
schedule_dict = json.loads(schedule_text)
return schedule_dict
except json.JSONDecodeError as e:
logger.error(f"解析日程失败: {str(e)},原始文本: {schedule_text}")
return False
def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间"""
@@ -105,6 +104,8 @@ class ScheduleGenerator:
min_diff = float('inf')
# 检查今天的日程
if not self.today_schedule:  # _parse_schedule may return False, so don't call .keys() on it
return "摸鱼"
for time_str in self.today_schedule.keys():
diff = abs(self._time_diff(current_time, time_str))
if closest_time is None or diff < min_diff:
@@ -128,6 +129,10 @@ class ScheduleGenerator:
def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差"""
if time1=="24:00":
time1="23:59"
if time2=="24:00":
time2="23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
@@ -141,11 +146,14 @@ class ScheduleGenerator:
def print_schedule(self):
"""打印完整的日程安排"""
print("\n=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n")
if not self._parse_schedule(self.today_schedule_text):
print("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
print("\n=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n")
# def main():
# # 使用示例
@@ -165,4 +173,4 @@ class ScheduleGenerator:
# if __name__ == "__main__":
# main()
bot_schedule = ScheduleGenerator()
bot_schedule = ScheduleGenerator()
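The schedule rework above replaces the old comma-separated line format with a JSON object and clamps the "24:00" timestamps the model occasionally emits. A condensed sketch of the new parse-and-lookup path, with hypothetical helper names (simplified relative to the module's _time_diff):

import datetime
import json

def parse_schedule(schedule_text: str):
    """json.loads the model output; False signals a malformed schedule."""
    try:
        return json.loads(schedule_text)
    except json.JSONDecodeError:
        return False

def clamp_24h(time_str: str) -> str:
    # strptime("%H:%M") rejects "24:00", so fold it to the last valid minute.
    return "23:59" if time_str == "24:00" else time_str

def minutes_between(t1: str, t2: str) -> int:
    fmt = "%H:%M"
    a = datetime.datetime.strptime(clamp_24h(t1), fmt)
    b = datetime.datetime.strptime(clamp_24h(t2), fmt)
    return int((b - a).total_seconds() / 60)

schedule = parse_schedule('{"08:00": "起床", "12:00": "吃饭", "23:00": "睡觉"}')
if schedule:
    now = datetime.datetime.now().strftime("%H:%M")
    nearest = min(schedule, key=lambda t: abs(minutes_between(now, t)))
    print(f"当前活动: {schedule[nearest]}")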

View File

@@ -1,70 +0,0 @@
from textblob import TextBlob
import jieba
from translate import Translator
def analyze_emotion(text):
"""
分析文本的情感,返回情感极性和主观性得分
:param text: 输入文本
:return: (情感极性, 主观性) 元组
情感极性: -1(非常消极) 到 1(非常积极)
主观性: 0(客观) 到 1(主观)
"""
try:
# 创建翻译器
translator = Translator(to_lang="en", from_lang="zh")
# 如果是中文文本,先翻译成英文
# 因为TextBlob的情感分析主要基于英文
translated_text = translator.translate(text)
# 创建TextBlob对象
blob = TextBlob(translated_text)
# 获取情感极性和主观性
polarity = blob.sentiment.polarity
subjectivity = blob.sentiment.subjectivity
return polarity, subjectivity
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None, None
def get_emotion_description(polarity, subjectivity):
"""
根据情感极性和主观性生成描述性文字
"""
if polarity is None or subjectivity is None:
return "无法分析情感"
# 情感极性描述
if polarity > 0.5:
emotion = "非常积极"
elif polarity > 0:
emotion = "较为积极"
elif polarity == 0:
emotion = "中性"
elif polarity > -0.5:
emotion = "较为消极"
else:
emotion = "非常消极"
# 主观性描述
if subjectivity > 0.7:
subj = "非常主观"
elif subjectivity > 0.3:
subj = "较为主观"
else:
subj = "较为客观"
return f"情感倾向: {emotion}, 表达方式: {subj}"
if __name__ == "__main__":
# 测试样例
test_text = "今天天气真好,我感到非常开心!"
polarity, subjectivity = analyze_emotion(test_text)
print(f"测试文本: {test_text}")
print(f"情感极性: {polarity:.2f}")
print(f"主观性得分: {subjectivity:.2f}")
print(get_emotion_description(polarity, subjectivity))

View File

@@ -1,74 +0,0 @@
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
def setup_bert_analyzer():
"""
设置中文BERT情感分析器
"""
# 使用专门针对中文情感分析的模型
model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
try:
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 创建情感分析pipeline
analyzer = pipeline("sentiment-analysis",
model=model,
tokenizer=tokenizer)
return analyzer
except Exception as e:
print(f"模型加载错误: {str(e)}")
return None
def analyze_emotion_bert(text, analyzer):
"""
使用BERT模型进行中文情感分析
"""
try:
if not analyzer:
return None
# 进行情感分析
result = analyzer(text)[0]
return {
'label': result['label'],
'score': result['score']
}
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_bert(result):
"""
将BERT的情感分析结果转换为描述性文字
"""
if not result:
return "无法分析情感"
label = "积极" if result['label'] == 'positive' else "消极"
confidence = result['score']
if confidence > 0.9:
strength = "强烈"
elif confidence > 0.7:
strength = "明显"
else:
strength = "轻微"
return f"{strength}{label}"
if __name__ == "__main__":
# 初始化分析器
analyzer = setup_bert_analyzer()
# 测试样例
test_text = "这个产品质量很好,使用起来非常方便,推荐购买!"
result = analyze_emotion_bert(test_text, analyzer)
print(f"测试文本: {test_text}")
if result:
print(f"情感倾向: {get_emotion_description_bert(result)}")
print(f"置信度: {result['score']:.2f}")

View File

@@ -1,62 +0,0 @@
import hanlp
def analyze_emotion_hanlp(text):
"""
使用HanLP进行中文情感分析
"""
try:
# 使用更基础的模型
tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
# 分词
words = tokenizer(text)
# 简单的情感词典方法
positive_words = {'', '', '优秀', '喜欢', '开心', '快乐', '美味', '推荐', '优质', '满意'}
negative_words = {'', '', '', '讨厌', '失望', '难受', '恶心', '不满', '差劲', '垃圾'}
# 计算情感得分
score = 0
for word in words:
if word in positive_words:
score += 1
elif word in negative_words:
score -= 1
# 归一化得分
if score > 0:
return 1
elif score < 0:
return 0
else:
return 0.5
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_hanlp(score):
"""
将HanLP的情感分析结果转换为描述性文字
"""
if score is None:
return "无法分析情感"
elif score == 1:
return "积极"
elif score == 0:
return "消极"
else:
return "中性"
if __name__ == "__main__":
# 测试样例
test_texts = [
"这家餐厅的服务态度很好,菜品也很美味!",
"这个产品质量太差了,一点都不值这个价",
"今天天气不错,但是工作很累"
]
for test_text in test_texts:
result = analyze_emotion_hanlp(test_text)
print(f"\n测试文本: {test_text}")
print(f"情感倾向: {get_emotion_description_hanlp(result)}")

488
src/test/typo_creator.py Normal file
View File

@@ -0,0 +1,488 @@
"""
错别字生成器 - 流程说明
整体替换逻辑:
1. 数据准备
- 加载字频词典使用jieba词典计算汉字使用频率
- 创建拼音映射:建立拼音到汉字的映射关系
- 加载词频信息从jieba词典获取词语使用频率
2. 分词处理
- 使用jieba将输入句子分词
- 区分单字词和多字词
- 保留标点符号和空格
3. 词语级别替换(针对多字词)
- 触发条件:词长>1 且 随机概率<0.3
- 替换流程:
a. 获取词语拼音
b. 生成所有可能的同音字组合
c. 过滤条件:
- 必须是jieba词典中的有效词
- 词频必须达到原词频的10%以上
- 综合评分(词频70%+字频30%)必须达到阈值
d. 按综合评分排序,选择最合适的替换词
4. 字级别替换(针对单字词或未进行整词替换的多字词)
- 单字替换概率0.3
- 多字词中的单字替换概率0.3 * (0.7 ^ (词长-1))
- 替换流程:
a. 获取字的拼音
b. 声调错误处理20%概率)
c. 获取同音字列表
d. 过滤条件:
- 字频必须达到最小阈值
- 频率差异不能过大(指数衰减计算)
e. 按频率排序选择替换字
5. 频率控制机制
- 字频控制使用归一化的字频0-1000范围
- 词频控制使用jieba词典中的词频
- 频率差异计算:使用指数衰减函数
- 最小频率阈值:确保替换字/词不会太生僻
6. 输出信息
- 原文和错字版本的对照
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
- 替换类型说明(整词替换/声调错误/同音字替换)
- 词语分析和完整拼音
注意事项:
1. 所有替换都必须使用有意义的词语
2. 替换词的使用频率不能过低
3. 多字词优先考虑整词替换
4. 考虑声调变化的情况
5. 保持标点符号和空格不变
"""
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
import time
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except TypeError:  # non-string input
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = list(pinyin_dict[py])  # copy first: remove() below must not mutate the shared mapping
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2''ni4'
处理特殊情况:
1. 轻声(如 'de5''le'
2. 非数字结尾的拼音
"""
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + '1'
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
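# Illustrative behaviour of the special cases above (not in the original file):
#   get_similar_tone_pinyin("ni3")  -> one of "ni1" / "ni2" / "ni4"
#   get_similar_tone_pinyin("de")   -> "de1"  (toneless pinyin gets tone 1 appended)
#   get_similar_tone_pinyin("le5")  -> a random tone, "le1" through "le4"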
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
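# Worked check of the decay curve above (illustrative, not in the original file):
# a gap of 0 gives 1.0; half of max_freq_diff gives exp(-1.5) ≈ 0.22; anything
# past max_freq_diff short-circuits to 0.0.
# assert calculate_replacement_probability(100, 100) == 1.0
# assert abs(calculate_replacement_probability(200, 100) - math.exp(-1.5)) < 1e-9
# assert calculate_replacement_probability(300, 50) == 0.0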
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def get_word_pinyin(word):
"""
获取词语的拼音列表
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def segment_sentence(sentence):
"""
使用jieba分词返回词语列表
"""
return list(jieba.cut(sentence))
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
"""
获取整个词的同音词,只返回高频的有意义词语
:param word: 输入词语
:param pinyin_dict: 拼音字典
:param char_frequency: 字频字典
:param min_freq: 最小频率阈值
:return: 同音词列表
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = get_word_pinyin(word)
word_pinyin_str = ''.join(word_pinyin)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
chars = pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
new_word = ''.join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# 只保留词频达到阈值的词
if new_word_freq >= min_word_freq:
# 计算词的平均字频(考虑字频和词频)
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
# 综合评分:结合词频和字频
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
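# Worked example for the scoring above (illustrative numbers, not measured):
# a candidate with jieba word frequency 3000 and average char frequency 400
# scores 3000 * 0.7 + 400 * 0.3 = 2220, far above min_freq=5. Because raw
# jieba counts dwarf the 0-1000 normalized char frequencies, the word-frequency
# term dominates the ranking; the effective gate is the 10% min_word_freq filter.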
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
"""
创建包含同音字错误的句子,支持词语级别和字级别的替换
只使用高频的有意义词语进行替换
"""
result = []
typo_info = []
# 分词
words = segment_sentence(sentence)
for word in words:
# 如果是标点符号或空格,直接添加
if all(not is_chinese_char(c) for c in word):
result.append(word)
continue
# 获取词语的拼音
word_pinyin = get_word_pinyin(word)
# 尝试整词替换
if len(word) > 1 and random.random() < word_replace_rate:
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
if word_homophones:
typo_word = random.choice(word_homophones)
# 计算词的平均频率
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# 如果不进行整词替换,则进行单字替换
if len(word) == 1:
char = word
py = word_pinyin[0]
if random.random() < error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
result.append(char)
else:
# 处理多字词的单字替换
word_result = []
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
word_result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 记录开始时间
start_time = time.time()
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
error_rate=0.3, min_freq=5,
tone_error_rate=0.2, word_replace_rate=0.3)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换
is_word = ' ' in orig_py
if is_word:
error_type = "整词替换"
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印词语分析
print("\n词语分析:")
words = segment_sentence(sentence)
for word in words:
if any(is_chinese_char(c) for c in word):
word_pinyin = get_word_pinyin(word)
print(f"词语:{word}")
print(f"拼音:{' '.join(word_pinyin)}")
print("---")
# 计算并打印总耗时
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()
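Beyond the interactive main() above, the same functions can be driven as a library. A minimal sketch, assuming the module is importable under the hypothetical path src.test.typo_creator:

from src.test.typo_creator import (create_pinyin_dict,
                                   load_or_create_char_frequency,
                                   create_typo_sentence)

pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
typo_sentence, typo_info = create_typo_sentence(
    "今天天气真好", pinyin_dict, char_frequency,
    error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3)
print(typo_sentence)  # e.g. "今天天汽真好" (output is random)
for orig, typo, *_ in typo_info:
    print(f"{orig} -> {typo}")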

View File

@@ -1,301 +0,0 @@
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = pinyin_dict[py]
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2''ni4'
"""
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2):
"""
创建包含同音字错误的句子,保留原文标点符号
"""
result = []
typo_info = []
# 获取每个字的拼音
chars_with_pinyin = get_pinyin(sentence)
# 创建原字到拼音的映射,用于跟踪已处理的字符
processed_chars = {char: py for char, py in chars_with_pinyin}
# 遍历原句中的每个字符
char_index = 0
for i, char in enumerate(sentence):
if char.isspace():
# 保留空格
result.append(char)
elif char in processed_chars:
# 处理汉字
py = processed_chars[char]
# 基础错误率
if random.random() < error_rate:
# 获取频率相近的同音字(可能包含声调错误)
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
# 随机选择一个替换字
typo_char = random.choice(similar_chars)
# 获取替换字的频率
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
# 计算实际替换概率
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
# 根据频率差进行概率替换
if random.random() < replace_prob:
result.append(typo_char)
# 获取替换字的实际拼音
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
else:
result.append(char)
else:
result.append(char)
else:
result.append(char)
char_index += 1
else:
# 保留非汉字字符(标点符号等)
result.append(char)
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
min_freq=5, tone_error_rate=0.2)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原字:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"错字:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印所有可能的同音字
print("\n每个字的所有同音字(按频率排序,仅显示频率>=5的字")
for char, py in result:
homophones = get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5)
char_freq = char_frequency.get(char, 0)
print(f"{char}: {py} [频率:{format_frequency(char_freq)}]")
if homophones:
homophone_info = []
for h in homophones:
h_freq = char_frequency.get(h, 0)
homophone_info.append(f"{h}[{format_frequency(h_freq)}]")
print(f"同音字: {''.join(homophone_info)}")
else:
print("没有找到频率>=5的同音字")
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,6 @@
HOST=127.0.0.1
PORT=8080
COMMAND_START=["/"]
# 插件配置
PLUGINS=["src2.plugins.chat"]
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值
MONGODB_AUTH_SOURCE = "" # 默认空值
#key and url
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
#定义你要用的api的base_url
DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
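For reference, the schedule module in this commit shows how these entries are consumed: nonebot's driver config exposes each .env value as an attribute, and the commit standardizes on the upper-case names. A minimal sketch:

from nonebot import get_driver

config = get_driver().config
# .env.prod entries surface as attributes on the driver config; the code in
# this commit reads the upper-case spellings, e.g.:
deepseek_url = config.DEEP_SEEK_BASE_URL
deepseek_key = config.DEEP_SEEK_KEY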