Merge branch 'debug' into main

SengokuCola authored 2025-03-12 00:20:54 +08:00 · committed by GitHub
64 changed files with 6255 additions and 2846 deletions


@@ -5,6 +5,7 @@ on:
branches:
- main
- debug # 新增 debug 分支触发
- stable-dev
tags:
- 'v*'
workflow_dispatch:
@@ -34,6 +35,8 @@ jobs:
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
elif [ "${{ github.ref }}" == "refs/heads/stable-dev" ]; then
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:stable-dev" >> $GITHUB_OUTPUT
fi
- name: Build and Push Docker Image

6
.gitignore vendored

@@ -1,4 +1,5 @@
data/
data1/
mongodb/
NapCat.Framework.Windows.Once/
log/
@@ -193,9 +194,8 @@ cython_debug/
# jieba
jieba.cache
# vscode
/.vscode
# .vscode
!.vscode/settings.json
# direnv
/.direnv

48
CLAUDE.md Normal file

@@ -0,0 +1,48 @@
# MaiMBot 开发指南
## 🛠️ 常用命令
- **运行机器人**: `python run.py` 或 `python bot.py`
- **安装依赖**: `pip install --upgrade -r requirements.txt`
- **Docker 部署**: `docker-compose up`
- **代码检查**: `ruff check .`
- **代码格式化**: `ruff format .`
- **内存可视化**: `run_memory_vis.bat` 或 `python -m src.plugins.memory_system.draw_memory`
- **推理过程可视化**: `script/run_thingking.bat`
## 🔧 脚本工具
- **运行MongoDB**: `script/run_db.bat` - 在端口27017启动MongoDB
- **Windows完整启动**: `script/run_windows.bat` - 检查Python版本、设置虚拟环境、安装依赖并运行机器人
- **快速启动**: `script/run_maimai.bat` - 设置UTF-8编码并执行"nb run"命令
## 📝 代码风格
- **Python版本**: 3.9+
- **行长度限制**: 88字符
- **命名规范**:
- `snake_case` 用于函数和变量
- `PascalCase` 用于类
- `_prefix` 用于私有成员
- **导入顺序**: 标准库 → 第三方库 → 本地模块
- **异步编程**: 对I/O操作使用async/await
- **日志记录**: 使用loguru进行一致的日志记录
- **错误处理**: 使用带有具体异常的try/except
- **文档**: 为类和公共函数编写docstrings
## 🧩 系统架构
- **框架**: NoneBot2框架与插件架构
- **数据库**: MongoDB持久化存储
- **设计模式**: 工厂模式和单例管理器(单例写法见下方示例)
- **配置管理**: 使用环境变量和TOML文件
- **内存系统**: 基于图的记忆结构,支持记忆构建、压缩、检索和遗忘
- **情绪系统**: 情绪模拟与概率权重
- **LLM集成**: 支持多个LLM服务提供商(ChatAnywhere, SiliconFlow, DeepSeek)
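
A minimal sketch of the singleton-manager shape referenced above. The class name is illustrative; in the project, managers such as `MoodManager` and `Database` expose a similar `get_instance()` on top of their own state:

```python
from typing import Optional


class SingletonManager:
    """Toy singleton manager; real managers add domain-specific state and methods."""

    _instance: Optional["SingletonManager"] = None

    @classmethod
    def get_instance(cls) -> "SingletonManager":
        # Create the shared instance on first access, then always return it.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
```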
## ⚙️ 环境配置
- 使用`template.env`作为环境变量模板
- 使用`template/bot_config_template.toml`作为机器人配置模板
- MongoDB配置: 主机、端口、数据库名
- API密钥配置: 各LLM提供商的API密钥
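
A hedged sketch of how these two configuration sources might be loaded together (the file names follow the templates above; this is not the project's actual loader):

```python
import os

import tomlkit
from dotenv import load_dotenv

# Load API keys and MongoDB settings from the environment file.
load_dotenv(".env.prod")

# Load bot behaviour settings from the TOML config.
with open("config/bot_config.toml", "r", encoding="utf-8") as f:
    bot_config = tomlkit.load(f)

mongo_host = os.getenv("MONGODB_HOST", "127.0.0.1")
mongo_port = int(os.getenv("MONGODB_PORT", "27017"))
siliconflow_key = os.getenv("SILICONFLOW_KEY")  # keep API keys out of version control
```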


@@ -1,5 +1,4 @@
# 麦麦MaiMBot (编辑中)
# 麦麦MaiMBot (编辑中)
<div align="center">
@@ -29,37 +28,43 @@
</a>
</div>
> ⚠️ **注意事项**
> [!WARNING]
> - 项目处于活跃开发阶段,代码可能随时更改
> - 文档未完善,有问题可以提交 Issue 或者 Discussion
> - QQ机器人存在被限制风险请自行了解谨慎使用
> - 由于持续迭代可能存在一些已知或未知的bug
> - 由于开发中可能消耗较多token
**交流群**: 766798517 一群人较多,建议加下面的(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
**交流群**: 571780722 另一个群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
**交流群**: 1035228475 另一个群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
## 💬交流群
- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 ,建议加下面的(开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722 (开发和建议相关讨论)不一定有空回复,会优先写文档和代码
- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475 (开发和建议相关讨论)不一定有空回复,会优先写文档和代码
**其他平台版本**
- (由 [CabLate](https://github.com/cablate) 贡献) [Telegram 与其他平台(未来可能会有)的版本](https://github.com/cablate/MaiMBot/tree/telegram) - [集中讨论串](https://github.com/SengokuCola/MaiMBot/discussions/149)
##
<div align="left">
<h2>📚 文档 ⬇️ 快速开始使用麦麦 ⬇️</h2>
</div>
### 部署方式
如果你不知道Docker是什么建议寻找相关教程或使用手动部署现在不建议使用docker更新慢可能不适配
- 📦 **Windows 一键傻瓜式部署**:请运行项目根目录中的 `run.bat`,部署完成后请参照后续配置指南进行配置
- [📦 Windows 手动部署指南 ](docs/manual_deploy_windows.md)
- [📦 Linux 手动部署指南 ](docs/manual_deploy_linux.md)
如果你不知道Docker是什么建议寻找相关教程或使用手动部署 **现在不建议使用docker更新慢可能不适配**
- [🐳 Docker部署指南](docs/docker_deploy.md)
- [📦 手动部署指南 Windows](docs/manual_deploy_windows.md)
- [📦 手动部署指南 Linux](docs/manual_deploy_linux.md)
- 📦 Windows 一键傻瓜式部署,请运行项目根目录中的 ```run.bat```,部署完成后请参照后续配置指南进行配置
### 配置说明
- [🎀 新手配置指南](docs/installation_cute.md) - 通俗易懂的配置教程,适合初次使用的猫娘
- [⚙️ 标准配置指南](docs/installation_standard.md) - 简明专业的配置说明,适合有经验的用户
@@ -72,6 +77,7 @@
## 🎯 功能介绍
### 💬 聊天功能
- 支持关键词检索主动发言对消息的话题topic进行识别如果检测到麦麦存储过的话题就会主动进行发言
- 支持bot名字呼唤发言检测到"麦麦"会主动发言,可配置
- 支持多模型,多厂商自定义配置
@@ -80,31 +86,33 @@
- 错别字和多条回复功能麦麦可以随机生成错别字会多条发送回复以及对消息进行reply
### 😊 表情包功能
- 支持根据发言内容发送对应情绪的表情包
- 会自动偷群友的表情包
### 📅 日程功能
- 麦麦会自动生成一天的日程,实现更拟人的回复
### 🧠 记忆功能
- 对聊天记录进行概括存储,在需要时调用,待完善
### 📚 知识库功能
- 基于embedding模型的知识库手动放入txt会自动识别写完了暂时禁用
### 👥 关系功能
- 针对每个用户创建"关系"可以对不同用户进行个性化回复目前只有极其简单的好感度WIP
- 针对每个群创建"群印象"可以对不同群进行个性化回复WIP
## 开发计划TODOLIST
规划主线
0.6.0:记忆系统更新
0.7.0: 麦麦RunTime
- 人格功能WIP
- 群氛围功能WIP
- 图片发送转发功能WIP
@@ -124,7 +132,6 @@
- 采用截断生成加快麦麦的反应速度
- 改进发送消息的触发
## 设计理念
- **千石可乐说:**
@@ -134,13 +141,14 @@
- 如果人类真的需要一个AI来陪伴自己并不是所有人都需要一个完美的能解决所有问题的helpful assistant而是一个会犯错的拥有自己感知和想法的"生命形式"。
- 代码会保持开源和开放但个人希望MaiMbot的运行时数据保持封闭尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性而不是将其视为一个对话机器。
## 📌 注意事项
SengokuCola纯编程外行面向cursor编程很多代码史一样多多包涵
> ⚠️ **警告**:本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI生成内容不代表本人观点和立场。
SengokuCola~~纯编程外行面向cursor编程很多代码写得不好多多包涵~~已得到大脑升级
> [!WARNING]
> 本应用生成内容来自人工智能模型,由 AI 生成请仔细甄别请勿用于违反法律的用途AI生成内容不代表本人观点和立场。
## 致谢
[nonebot2](https://github.com/nonebot/nonebot2): 跨平台 Python 异步聊天机器人框架
[NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现
@@ -149,9 +157,9 @@ SengokuCola纯编程外行面向cursor编程很多代码史一样多多包
感谢各位大佬!
<a href="https://github.com/SengokuCola/MaiMBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=SengokuCola/MaiMBot&time=true" />
<img src="https://contrib.rocks/image?repo=SengokuCola/MaiMBot" />
</a>
## Stargazers over time
[![Stargazers over time](https://starchart.cc/SengokuCola/MaiMBot.svg?variant=adaptive)](https://starchart.cc/SengokuCola/MaiMBot)

104
bot.py

@@ -1,7 +1,12 @@
import asyncio
import os
import shutil
import sys
import nonebot
import time
import uvicorn
from dotenv import load_dotenv
from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter
@@ -10,6 +15,9 @@ import platform
# 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ}
uvicorn_server = None
def easter_egg():
# 彩蛋
from colorama import init, Fore
@@ -22,11 +30,12 @@ def easter_egg():
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
print(rainbow_text)
def init_config():
# 初次启动检测
if not os.path.exists("config/bot_config.toml"):
logger.warning("检测到bot_config.toml不存在正在从模板复制")
# 检查config目录是否存在
if not os.path.exists("config"):
os.makedirs("config")
@@ -35,6 +44,7 @@ def init_config():
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
logger.info("复制完成请修改config/bot_config.toml和.env.prod中的配置后重新启动")
def init_env():
# 初始化.env 默认ENVIRONMENT=prod
if not os.path.exists(".env"):
@@ -46,11 +56,17 @@ def init_env():
logger.error("检测到.env.prod文件不存在")
shutil.copy("template.env", "./.env.prod")
# 检测.env.dev文件是否存在不存在的话直接复制生产环境配置
if not os.path.exists(".env.dev"):
logger.error("检测到.env.dev文件不存在")
shutil.copy(".env.prod", "./.env.dev")
# 首先加载基础环境变量.env
if os.path.exists(".env"):
load_dotenv(".env")
load_dotenv(".env",override=True)
logger.success("成功加载基础环境变量配置")
def load_env():
# 使用闭包实现对加载器的横向扩展,避免大量重复判断
def prod():
@@ -70,7 +86,7 @@ def load_env():
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
if env in fn_map:
fn_map[env]() # 根据映射执行闭包函数
fn_map[env]() # 根据映射执行闭包函数
elif os.path.exists(f".env.{env}"):
logger.success(f"加载{env}环境变量配置")
@@ -81,6 +97,29 @@ def load_env():
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def load_logger():
logger.remove() # 移除默认配置
if os.getenv("ENVIRONMENT") == "dev":
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别默认为DEBUG
)
else:
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "INFO"), # 根据环境设置日志级别默认为INFO
filter=lambda record: "nonebot" not in record["name"]
)
def scan_provider(env_config: dict):
provider = {}
@@ -115,16 +154,51 @@ def scan_provider(env_config: dict):
)
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
if __name__ == "__main__":
async def graceful_shutdown():
try:
global uvicorn_server
if uvicorn_server:
uvicorn_server.force_exit = True # 强制退出
await uvicorn_server.shutdown()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
logger.error(f"麦麦关闭失败: {e}")
async def uvicorn_main():
global uvicorn_server
config = uvicorn.Config(
app="__main__:app",
host=os.getenv("HOST", "127.0.0.1"),
port=int(os.getenv("PORT", 8080)),
reload=os.getenv("ENVIRONMENT") == "dev",
timeout_graceful_shutdown=5,
log_config=None,
access_log=False
)
server = uvicorn.Server(config)
uvicorn_server = server
await server.serve()
def raw_main():
# 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != 'windows':
time.tzset()
easter_egg()
load_logger()
init_config()
init_env()
load_env()
load_logger()
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
@@ -140,10 +214,30 @@ if __name__ == "__main__":
nonebot.init(**base_config, **env_config)
# 注册适配器
global driver
driver = nonebot.get_driver()
driver.register_adapter(Adapter)
# 加载插件
nonebot.load_plugins("src/plugins")
nonebot.run()
if __name__ == "__main__":
try:
raw_main()
global app
app = nonebot.get_asgi()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(uvicorn_main())
except KeyboardInterrupt:
logger.warning("麦麦会努力做的更好的!正在停止中......")
except Exception as e:
logger.error(f"主程序异常: {e}")
finally:
loop.run_until_complete(graceful_shutdown())
loop.close()
logger.info("进程终止完毕,麦麦开始休眠......下次再见哦!")

6
changelog.md Normal file

@@ -0,0 +1,6 @@
# Changelog
## [0.5.12] - 2025-3-9
### Added
- 新增了 我是测试

12
changelog_config.md Normal file

@@ -0,0 +1,12 @@
# Changelog
## [0.0.5] - 2025-3-11
### Added
- 新增了 `alias_names` 配置项,用于指定麦麦的别名。
## [0.0.4] - 2025-3-9
### Added
- 新增了 `memory_ban_words` 配置项,用于指定不希望记忆的词汇。

59
config/auto_update.py Normal file

@@ -0,0 +1,59 @@
import os
import shutil
import tomlkit
from pathlib import Path
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
# 读取旧配置文件
old_config = {}
if old_config_path.exists():
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 删除旧的配置文件
if old_config_path.exists():
os.remove(old_config_path)
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
# 读取新配置文件
with open(new_config_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 递归更新配置
def update_dict(target, source):
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
update_dict(target[key], value)
else:
try:
# 直接使用tomlkit的item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
if __name__ == "__main__":
update_config()


@@ -6,8 +6,6 @@ services:
- NAPCAT_UID=${NAPCAT_UID}
- NAPCAT_GID=${NAPCAT_GID} # 让 NapCat 获取当前用户 GID,UID防止权限问题
ports:
- 3000:3000
- 3001:3001
- 6099:6099
restart: unless-stopped
volumes:
@@ -19,7 +17,7 @@ services:
mongodb:
container_name: mongodb
environment:
- tz=Asia/Shanghai
- TZ=Asia/Shanghai
# - MONGO_INITDB_ROOT_USERNAME=your_username
# - MONGO_INITDB_ROOT_PASSWORD=your_password
expose:

20
docs/Jonathan R.md Normal file

@@ -0,0 +1,20 @@
Jonathan R. Wolpaw 在 “Memory in neuroscience: rhetoric versus reality.” 一文中提到,从神经科学的感觉运动假设出发,整个神经系统的功能是将经验与适当的行为联系起来,而不是单纯的信息存储。
Jonathan R. Wolpaw. (2019). Memory in neuroscience: rhetoric versus reality. Behavioral and Cognitive Neuroscience Reviews, (2).
1. **单一过程理论**
- 单一过程理论认为,识别记忆主要是基于熟悉性这一单一因素的影响。熟悉性是指对刺激的一种自动的、无意识的感知,它可以使我们在没有回忆起具体细节的情况下,判断一个刺激是否曾经出现过。
- 例如在一些实验中研究者发现被试可以在没有回忆起具体学习情境的情况下对曾经出现过的刺激做出正确的判断这被认为是熟悉性在起作用1。
2. **双重过程理论**
- 双重过程理论则认为,识别记忆是基于两个过程:回忆和熟悉性。回忆是指对过去经验的有意识的回忆,它可以使我们回忆起具体的细节和情境;熟悉性则是一种自动的、无意识的感知。
- 该理论认为,在识别记忆中,回忆和熟悉性共同作用,使我们能够判断一个刺激是否曾经出现过。例如,在 “记得 / 知道” 范式中被试被要求判断他们对一个刺激的记忆是基于回忆还是熟悉性。研究发现被试可以区分这两种不同的记忆过程这为双重过程理论提供了支持1。
1. **神经元节点与连接**:借鉴神经网络原理,将每个记忆单元视为一个神经元节点。节点之间通过连接相互关联,连接的强度代表记忆之间的关联程度。在形态学联想记忆中,具有相似形态特征的记忆节点连接强度较高。例如,苹果和橘子的记忆节点,由于在形状、都是水果等形态语义特征上相似,它们之间的连接强度大于苹果与汽车记忆节点间的连接强度。
2. **记忆聚类与层次结构**:依据形态特征的相似性对记忆进行聚类,形成不同的记忆簇。每个记忆簇内部的记忆具有较高的相似性,而不同记忆簇之间的记忆相似性较低。同时,构建记忆的层次结构,高层次的记忆节点代表更抽象、概括的概念,低层次的记忆节点对应具体的实例。比如,“水果” 作为高层次记忆节点,连接着 “苹果”“橘子”“香蕉” 等低层次具体水果的记忆节点。
3. **网络的动态更新**:随着新记忆的不断加入,记忆网络动态调整。新记忆节点根据其形态特征与现有网络中的节点建立连接,同时影响相关连接的强度。若新记忆与某个记忆簇的特征高度相似,则被纳入该记忆簇;若具有独特特征,则可能引发新的记忆簇的形成。例如,当系统学习到一种新的水果 “番石榴”,它会根据番石榴的形态、语义等特征,在记忆网络中找到与之最相似的区域(如水果记忆簇),并建立相应连接,同时调整周围节点连接强度以适应这一新记忆。
- **相似性联想**:该理论认为,当两个或多个事物在形态上具有相似性时,它们在记忆中会形成关联。例如,梨和苹果在形状和都是水果这一属性上有相似性,所以当我们看到梨时,很容易通过形态学联想记忆联想到苹果。这种相似性联想有助于我们对新事物进行分类和理解,当遇到一个新的类似水果时,我们可以通过与已有的水果记忆进行相似性匹配,来推测它的一些特征。
- **时空关联性联想**除了相似性联想MAM 还强调时空关联性联想。如果两个事物在时间或空间上经常同时出现,它们也会在记忆中形成关联。比如,每次在公园里看到花的时候,都能听到鸟儿的叫声,那么花和鸟儿叫声的形态特征(花的视觉形态和鸟叫的听觉形态)就会在记忆中形成关联,以后听到鸟叫可能就会联想到公园里的花。
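
A toy sketch of the weighted memory graph described above, using networkx. The node names and weights are illustrative, not taken from the project's data:

```python
import networkx as nx

# Nodes are concepts; edge weights stand for association strength.
memory = nx.Graph()
memory.add_edge("fruit", "apple", weight=0.9)   # abstract node linked to a concrete one
memory.add_edge("fruit", "orange", weight=0.9)
memory.add_edge("apple", "orange", weight=0.8)  # similar shape, both fruit
memory.add_edge("apple", "car", weight=0.1)     # weak, dissimilar association


def add_memory(graph: nx.Graph, concept: str, neighbours: dict) -> None:
    """Attach a new concept and strengthen any existing links it touches."""
    for other, similarity in neighbours.items():
        old = graph.get_edge_data(concept, other, default={"weight": 0.0})["weight"]
        graph.add_edge(concept, other, weight=max(old, similarity))


# A new fruit joins the existing "fruit" cluster instead of forming a new one.
add_memory(memory, "guava", {"fruit": 0.85, "apple": 0.6})
print(sorted(memory["guava"].items()))
```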


@@ -1,6 +1,7 @@
# 📂 文件及功能介绍 (2025年更新)
## 根目录
- **README.md**: 项目的概述和使用说明。
- **requirements.txt**: 项目所需的Python依赖包列表。
- **bot.py**: 主启动文件负责环境配置加载和NoneBot初始化。
@@ -10,6 +11,7 @@
- **run_*.bat**: 各种启动脚本包括数据库、maimai和thinking功能。
## `src/` 目录结构
- **`plugins/` 目录**: 存放不同功能模块的插件。
- **chat/**: 处理聊天相关的功能,如消息发送和接收。
- **memory_system/**: 处理机器人的记忆功能。
@@ -22,94 +24,96 @@
- **`common/` 目录**: 存放通用的工具和库。
- **database.py**: 处理与数据库的交互,负责数据的存储和检索。
- **__init__.py**: 初始化模块。
- ****init**.py**: 初始化模块。
## `config/` 目录
- **bot_config_template.toml**: 机器人配置模板。
- **auto_format.py**: 自动格式化工具。
### `src/plugins/chat/` 目录文件详细介绍
1. **`__init__.py`**:
1. **`__init__.py`**:
- 初始化 `chat` 模块,使其可以作为一个包被导入。
2. **`bot.py`**:
2. **`bot.py`**:
- 主要的聊天机器人逻辑实现,处理消息的接收、思考和回复。
- 包含 `ChatBot` 类,负责消息处理流程控制。
- 集成记忆系统和意愿管理。
3. **`config.py`**:
3. **`config.py`**:
- 配置文件,定义了聊天机器人的各种参数和设置。
- 包含 `BotConfig` 和全局配置对象 `global_config`
4. **`cq_code.py`**:
4. **`cq_code.py`**:
- 处理 CQ 码CoolQ 码),用于发送和接收特定格式的消息。
5. **`emoji_manager.py`**:
5. **`emoji_manager.py`**:
- 管理表情包的发送和接收,根据情感选择合适的表情。
- 提供根据情绪获取表情的方法。
6. **`llm_generator.py`**:
6. **`llm_generator.py`**:
- 生成基于大语言模型的回复,处理用户输入并生成相应的文本。
- 通过 `ResponseGenerator` 类实现回复生成。
7. **`message.py`**:
7. **`message.py`**:
- 定义消息的结构和处理逻辑,包含多种消息类型:
- `Message`: 基础消息类
- `MessageSet`: 消息集合
- `Message_Sending`: 发送中的消息
- `Message_Thinking`: 思考状态的消息
8. **`message_sender.py`**:
8. **`message_sender.py`**:
- 控制消息的发送逻辑,确保消息按照特定规则发送。
- 包含 `message_manager` 对象,用于管理消息队列。
9. **`prompt_builder.py`**:
9. **`prompt_builder.py`**:
- 构建用于生成回复的提示,优化机器人的响应质量。
10. **`relationship_manager.py`**:
10. **`relationship_manager.py`**:
- 管理用户之间的关系,记录用户的互动和偏好。
- 提供更新关系和关系值的方法。
11. **`Segment_builder.py`**:
11. **`Segment_builder.py`**:
- 构建消息片段的工具。
12. **`storage.py`**:
12. **`storage.py`**:
- 处理数据存储,负责将聊天记录和用户信息保存到数据库。
- 实现 `MessageStorage` 类管理消息存储。
13. **`thinking_idea.py`**:
13. **`thinking_idea.py`**:
- 实现机器人的思考机制。
14. **`topic_identifier.py`**:
14. **`topic_identifier.py`**:
- 识别消息中的主题,帮助机器人理解用户的意图。
15. **`utils.py`** 和 **`utils_*.py`** 系列文件:
15. **`utils.py`** 和 **`utils_*.py`** 系列文件:
- 存放各种工具函数,提供辅助功能以支持其他模块。
- 包括 `utils_cq.py``utils_image.py``utils_user.py` 等专门工具。
16. **`willing_manager.py`**:
16. **`willing_manager.py`**:
- 管理机器人的回复意愿,动态调整回复概率。
- 通过多种因素(如被提及、话题兴趣度)影响回复决策。
### `src/plugins/memory_system/` 目录文件介绍
1. **`memory.py`**:
1. **`memory.py`**:
- 实现记忆管理核心功能,包含 `memory_graph` 对象。
- 提供相关项目检索,支持多层次记忆关联。
2. **`draw_memory.py`**:
2. **`draw_memory.py`**:
- 记忆可视化工具。
3. **`memory_manual_build.py`**:
3. **`memory_manual_build.py`**:
- 手动构建记忆的工具。
4. **`offline_llm.py`**:
4. **`offline_llm.py`**:
- 离线大语言模型处理功能。
## 消息处理流程
### 1. 消息接收与预处理
- 通过 `ChatBot.handle_message()` 接收群消息。
- 进行用户和群组的权限检查。
- 更新用户关系信息。
@@ -117,12 +121,14 @@
- 对消息进行过滤和敏感词检测。
### 2. 主题识别与决策
- 使用 `topic_identifier` 识别消息主题。
- 通过记忆系统检查对主题的兴趣度。
- `willing_manager` 动态计算回复概率。
- 根据概率决定是否回复消息。
### 3. 回复生成与发送
- 如需回复,首先创建 `Message_Thinking` 对象表示思考状态。
- 调用 `ResponseGenerator.generate_response()` 生成回复内容和情感状态。
- 删除思考消息,创建 `MessageSet` 准备发送回复。
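
A simplified sketch of the probability gate described in step 2. The weights here are made up; the real `willing_manager` combines more signals (mentions, topic interest, emoji status, per-user state):

```python
import random


def should_reply(base_willing: float, is_mentioned: bool, interested_rate: float) -> bool:
    """Toy reply decision: turn willingness signals into a single probability."""
    probability = base_willing
    if is_mentioned:
        probability += 0.5                 # being addressed strongly raises the chance
    probability += 0.3 * interested_rate   # memory activation nudges it up
    return random.random() < min(probability, 1.0)


# Example: low base willingness, not mentioned, mildly interesting topic.
print(should_reply(0.1, False, 0.4))
```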


@@ -1,67 +1,93 @@
# 🐳 Docker 部署指南
## 部署步骤推荐,但不一定是最新
## 部署步骤 (推荐,但不一定是最新)
**"更新镜像与容器"部分在本文档 [Part 6](#6-更新镜像与容器)**
### 0. 前提说明
**本文假设读者已具备一定的 Docker 基础知识。若您对 Docker 不熟悉,建议先参考相关教程或文档进行学习,或选择使用 [📦Linux手动部署指南](./manual_deploy_linux.md) 或 [📦Windows手动部署指南](./manual_deploy_windows.md) 。**
### 1. 获取Docker配置文件:
### 1. 获取Docker配置文件
- 建议先单独创建好一个文件夹并进入,作为工作目录
```bash
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml -O docker-compose.yml
```
- 若需要启用MongoDB数据库的用户名和密码可进入docker-compose.yml取消MongoDB处的注释并修改变量`=`后方的值为你的用户名和密码\
修改后请注意在之后配置`.env.prod`文件时指定MongoDB数据库的用户名密码
- 若需要启用MongoDB数据库的用户名和密码可进入docker-compose.yml取消MongoDB处的注释并修改变量`=` 后方的值为你的用户名和密码\
修改后请注意在之后配置 `.env.prod` 文件时指定MongoDB数据库的用户名密码
### 2. 启动服务
### 2. 启动服务:
- **!!! 请在第一次启动前确保当前工作目录下`.env.prod``bot_config.toml`文件存在 !!!**\
- **!!! 请在第一次启动前确保当前工作目录下 `.env.prod``bot_config.toml` 文件存在 !!!**\
由于Docker文件映射行为的特殊性若宿主机的映射路径不存在可能导致意外的目录创建而不会创建文件由于此处需要文件映射到文件需提前确保文件存在且路径正确可使用如下命令:
```bash
touch .env.prod
touch bot_config.toml
```
- 启动Docker容器:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose up -d
```
- 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
### 3. 修改配置并重启Docker:
### 3. 修改配置并重启Docker
- 请前往 [🎀 新手配置指南](docs/installation_cute.md) 或 [⚙️ 标准配置指南](docs/installation_standard.md) 完成`.env.prod``bot_config.toml`配置文件的编写\
**需要注意`.env.prod`中HOST处IP的填写Docker中部署和系统中直接安装的配置会有所不同**
- 重启Docker容器:
```bash
docker restart maimbot # 若修改过容器名称则替换maimbot为你自定的名
docker restart maimbot # 若修改过容器名称则替换maimbot为你自定的名
```
- 下方命令可以但不推荐只是同时重启NapCat、MongoDB、MaiMBot三个服务
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose restart
```
- 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
### 4. 登入NapCat管理页添加反向WebSocket
- 在浏览器地址栏输入`http://<宿主机IP>:6099/`进入NapCat的管理Web页添加一个Websocket客户端
- 在浏览器地址栏输入 `http://<宿主机IP>:6099/` 进入NapCat的管理Web页添加一个Websocket客户端
> 网络配置 -> 新建 -> Websocket客户端
- Websocket客户端的名称自定URL栏填入`ws://maimbot:8080/onebot/v11/ws`,启用并保存即可\
- Websocket客户端的名称自定URL栏填入 `ws://maimbot:8080/onebot/v11/ws`,启用并保存即可\
(若修改过容器名称则替换maimbot为你自定的名称)
### 5. 部署完成,愉快地和麦麦对话吧!
### 5. 愉快地和麦麦对话吧!
### 6. 更新镜像与容器
- 拉取最新镜像
```bash
docker-compose pull
```
- 执行启动容器指令,该指令会自动重建镜像有更新的容器并启动
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
# 旧版Docker中可能找不到docker compose请使用docker-compose工具替代
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker-compose up -d
```
## ⚠️ 注意事项
- 目前部署方案仍在测试中,可能存在未知问题
- 配置文件中的API密钥请妥善保管不要泄露
- 建议先在测试环境中运行,确认无误后再部署到生产环境
- 建议先在测试环境中运行,确认无误后再部署到生产环境


@@ -1,8 +1,9 @@
# 🔧 配置指南 喵~
## 👋 你好呀
## 👋 你好呀
让咱来告诉你我们要做什么喵:
1. 我们要一起设置一个可爱的AI机器人
2. 这个机器人可以在QQ上陪你聊天玩耍哦
3. 需要设置两个文件才能让机器人工作呢
@@ -10,16 +11,19 @@
## 📝 需要设置的文件喵
要设置这两个文件才能让机器人跑起来哦:
1. `.env.prod` - 这个文件告诉机器人要用哪些AI服务呢
2. `bot_config.toml` - 这个文件教机器人怎么和你聊天喵
## 🔑 密钥和域名的对应关系
想象一下,你要进入一个游乐园,需要:
1. 知道游乐园的地址(这就是域名 base_url
2. 有入场的门票(这就是密钥 key
`.env.prod` 文件里,我们定义了三个游乐园的地址和门票喵:
```ini
# 硅基流动游乐园
SILICONFLOW_KEY=your_key # 硅基流动的门票
@@ -35,6 +39,7 @@ CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere的地
```
然后在 `bot_config.toml` 里,机器人会用这些门票和地址去游乐园玩耍:
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
@@ -47,22 +52,24 @@ base_url = "SILICONFLOW_BASE_URL" # 还是去硅基流动游乐园
key = "SILICONFLOW_KEY" # 用同一张门票就可以啦
```
### 🎪 举个例子喵
### 🎪 举个例子喵
如果你想用DeepSeek官方的服务就要这样改
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
name = "deepseek-reasoner" # 改成对应的模型名称,这里为DeepseekR1
base_url = "DEEP_SEEK_BASE_URL" # 改成去DeepSeek游乐园
key = "DEEP_SEEK_KEY" # 用DeepSeek的门票
[model.llm_normal]
name = "Pro/deepseek-ai/DeepSeek-V3"
name = "deepseek-chat" # 改成对应的模型名称,这里为DeepseekV3
base_url = "DEEP_SEEK_BASE_URL" # 也去DeepSeek游乐园
key = "DEEP_SEEK_KEY" # 用同一张DeepSeek门票
```
### 🎯 简单来说
### 🎯 简单来说
- `.env.prod` 文件就像是你的票夹,存放着各个游乐园的门票和地址
- `bot_config.toml` 就是告诉机器人:用哪张票去哪个游乐园玩
- 所有模型都可以用同一个游乐园的票,也可以去不同的游乐园玩耍
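
In plain code, the "wallet" is just environment variables; a hedged Python illustration (the variable names follow the file above, the logic is illustrative):

```python
import os

from dotenv import load_dotenv

load_dotenv(".env.prod")                 # open the ticket wallet
ticket = os.getenv("SILICONFLOW_KEY")    # the ticket (API key)
park = os.getenv("SILICONFLOW_BASE_URL") # the park address (API endpoint)
# Several [model.*] sections in bot_config.toml can reuse this same ticket/address pair.
```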
@@ -88,19 +95,25 @@ CHAT_ANY_WHERE_KEY=your_key
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# 如果你不知道这是什么,那么下面这些不用改,保持原样就好啦
HOST=127.0.0.1 # 如果使用Docker部署需要改成0.0.0.0喵,不然听不见群友讲话了喵
# 如果使用Docker部署需要改成0.0.0.0喵,不然听不见群友讲话了喵
HOST=127.0.0.1
PORT=8080
# 这些是数据库设置,一般也不用改呢
MONGODB_HOST=127.0.0.1 # 如果使用Docker部署需要改成数据库容器的名字喵默认是mongodb喵
# 如果使用Docker部署需要把MONGODB_HOST改成数据库容器的名字喵默认是mongodb喵
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 如果数据库需要用户名,就在这里填写
MONGODB_PASSWORD = "" # 如果数据库需要密码,就在这里填写呢
MONGODB_AUTH_SOURCE = "" # 数据库认证源,一般不用改哦
# 数据库认证信息,如果需要认证就取消注释并填写下面三行
# MONGODB_USERNAME = ""
# MONGODB_PASSWORD = ""
# MONGODB_AUTH_SOURCE = ""
# 插件设置喵
PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢
# 也可以使用URI连接数据库取消注释填写在下面这行喵URI的优先级比上面的高
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# 这里是机器人的插件列表呢
PLUGINS=["src2.plugins.chat"]
```
### 第二个文件:机器人配置 (bot_config.toml)
@@ -110,7 +123,8 @@ PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢
```toml
[bot]
qq = "把这里改成你的机器人QQ号喵" # 填写你的机器人QQ号
nickname = "麦麦" # 机器人的名字,你可以改成你喜欢的任何名字哦
nickname = "麦麦" # 机器人的名字,你可以改成你喜欢的任何名字哦建议和机器人QQ名称/群昵称一样哦
alias_names = ["小麦", "阿麦"] # 也可以用这个招呼机器人,可以不设置呢
[personality]
# 这里可以设置机器人的性格呢,让它更有趣一些喵
@@ -198,10 +212,12 @@ key = "SILICONFLOW_KEY"
- `topic`: 负责理解对话主题的能力呢
## 🌟 小提示
- 如果你刚开始使用,建议保持默认配置呢
- 不同的模型有不同的特长,可以根据需要调整它们的使用比例哦
## 🌟 小贴士喵
- 记得要好好保管密钥key不要告诉别人呢
- 配置文件要小心修改,改错了机器人可能就不能和你玩了喵
- 如果想让机器人更聪明,可以调整 personality 里的设置呢
@@ -209,7 +225,8 @@ key = "SILICONFLOW_KEY"
- QQ群号和QQ号都要用数字填写不要加引号哦除了机器人自己的QQ号
## ⚠️ 注意事项
- 这个机器人还在测试中呢,可能会有一些小问题喵
- 如果不知道怎么改某个设置,就保持原样不要动它哦~
- 记得要先有AI服务的密钥不然机器人就不能和你说话了呢
- 修改完配置后要重启机器人才能生效喵~
- 修改完配置后要重启机器人才能生效喵~


@@ -3,14 +3,16 @@
## 简介
本项目需要配置两个主要文件:
1. `.env.prod` - 配置API服务和系统环境
2. `bot_config.toml` - 配置机器人行为和模型
## API配置说明
`.env.prod``bot_config.toml`中的API配置关系如下
`.env.prod``bot_config.toml` 中的API配置关系如下
### 在.env.prod中定义API凭证
### 在.env.prod中定义API凭证
```ini
# API凭证配置
SILICONFLOW_KEY=your_key # 硅基流动API密钥
@@ -23,7 +25,8 @@ CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere API密钥
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere API地址
```
### 在bot_config.toml中引用API凭证
### 在bot_config.toml中引用API凭证
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
@@ -32,9 +35,10 @@ key = "SILICONFLOW_KEY" # 引用.env.prod中定义的密钥
```
如需切换到其他API服务只需修改引用
```toml
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
name = "deepseek-reasoner" # 改成对应的模型名称,这里为DeepseekR1
base_url = "DEEP_SEEK_BASE_URL" # 切换为DeepSeek服务
key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥
```
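
One way the indirection above could be resolved at runtime is sketched below; this is illustrative, not necessarily how the project's loader is implemented:

```python
import os


def resolve_model_config(model_cfg: dict) -> dict:
    """Turn env-var references like 'DEEP_SEEK_KEY' into concrete values."""
    return {
        "name": model_cfg["name"],
        "base_url": os.environ[model_cfg["base_url"]],  # e.g. DEEP_SEEK_BASE_URL
        "api_key": os.environ[model_cfg["key"]],        # e.g. DEEP_SEEK_KEY
    }


cfg = {"name": "deepseek-chat", "base_url": "DEEP_SEEK_BASE_URL", "key": "DEEP_SEEK_KEY"}
# resolve_model_config(cfg)  # requires the env vars from .env.prod to be set
```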
@@ -42,6 +46,7 @@ key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥
## 配置文件详解
### 环境配置文件 (.env.prod)
```ini
# API配置
SILICONFLOW_KEY=your_key
@@ -52,26 +57,36 @@ CHAT_ANY_WHERE_KEY=your_key
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# 服务配置
HOST=127.0.0.1 # 如果使用Docker部署需要改成0.0.0.0否则QQ消息无法传入
PORT=8080
PORT=8080 # 与反向端口相同
# 数据库配置
MONGODB_HOST=127.0.0.1 # 如果使用Docker部署需要改成数据库容器的名字默认是mongodb
MONGODB_PORT=27017
MONGODB_PORT=27017 # MongoDB端口
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 数据库用户名
MONGODB_PASSWORD = "" # 数据库密码
MONGODB_AUTH_SOURCE = "" # 认证数据库
# 数据库认证信息,如果需要认证就取消注释并填写下面三行
# MONGODB_USERNAME = ""
# MONGODB_PASSWORD = ""
# MONGODB_AUTH_SOURCE = ""
# 也可以使用URI连接数据库取消注释填写在下面这行URI的优先级比上面的高
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# 插件配置
PLUGINS=["src2.plugins.chat"]
```
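
The URI-over-host precedence noted above mirrors the refactored `Database` class later in this commit; a short sketch of the same selection logic:

```python
import os

from pymongo import MongoClient

uri = os.getenv("MONGODB_URI")
if uri and uri.startswith("mongodb://"):
    client = MongoClient(uri)  # a configured URI takes precedence
else:
    client = MongoClient(
        os.getenv("MONGODB_HOST", "127.0.0.1"),
        int(os.getenv("MONGODB_PORT", "27017")),
    )
db = client[os.getenv("DATABASE_NAME", "MegBot")]
```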
### 机器人配置文件 (bot_config.toml)
```toml
[bot]
qq = "机器人QQ号" # 必填
nickname = "麦麦" # 机器人昵称
# alias_names: 配置机器人可使用的别名。当机器人在群聊或对话中被调用时,别名可以作为直接命令或提及机器人的关键字使用。
# 该配置项为字符串数组。例如: ["小麦", "阿麦"]
alias_names = ["小麦", "阿麦"] # 机器人别名
[personality]
prompt_personality = [
@@ -151,4 +166,4 @@ key = "SILICONFLOW_KEY"
3. 其他说明:
- 项目处于测试阶段,可能存在未知问题
- 建议初次使用保持默认配置
- 建议初次使用保持默认配置


@@ -1,6 +1,7 @@
# 📦 Linux系统如何手动部署MaiMbot麦麦
## 准备工作
- 一台联网的Linux设备本教程以Ubuntu/Debian系为例
- QQ小号QQ框架的使用可能导致qq被风控严重小概率可能会导致账号封禁强烈不推荐使用大号
- 可用的大模型API
@@ -20,6 +21,7 @@
- 数据库是什么如何安装并启动MongoDB
- 如何运行一个QQ机器人以及NapCat框架是什么
---
## 环境配置
@@ -33,7 +35,9 @@ python --version
# 或
python3 --version
```
如果版本低于3.9请更新Python版本。
```bash
# Ubuntu/Debian
sudo apt update
@@ -45,6 +49,7 @@ sudo update-alternatives --config python3
```
### 2⃣ **创建虚拟环境**
```bash
# 方法1使用venv(推荐)
python3 -m venv maimbot
@@ -65,32 +70,37 @@ pip install -r requirements.txt
---
## 数据库配置
### 3⃣ **安装并启动MongoDB**
- 安装与启动Debian参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-debian/)Ubuntu参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-ubuntu/)
### 3⃣ **安装并启动MongoDB**
- 安装与启动Debian参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-debian/)Ubuntu参考[官方文档](https://docs.mongodb.com/manual/tutorial/install-mongodb-on-ubuntu/)
- 默认连接本地27017端口
---
## NapCat配置
### 4⃣ **安装NapCat框架**
- 参考[NapCat官方文档](https://www.napcat.wiki/guide/boot/Shell#napcat-installer-linux%E4%B8%80%E9%94%AE%E4%BD%BF%E7%94%A8%E8%84%9A%E6%9C%AC-%E6%94%AF%E6%8C%81ubuntu-20-debian-10-centos9)安装
- 使用QQ小号登录添加反向WS地址
`ws://127.0.0.1:8080/onebot/v11/ws`
- 使用QQ小号登录添加反向WS地址: `ws://127.0.0.1:8080/onebot/v11/ws`
---
## 配置文件设置
### 5⃣ **配置文件设置让麦麦Bot正常工作**
- 修改环境配置文件:`.env.prod`
- 修改机器人配置文件:`bot_config.toml`
---
## 启动机器人
### 6⃣ **启动麦麦机器人**
```bash
# 在项目目录下操作
nb run
@@ -100,17 +110,70 @@ python3 bot.py
---
## **其他组件(可选)**
- 直接运行 knowledge.py生成知识库
### 7⃣ **使用systemctl管理maimbot**
使用以下命令添加服务文件:
```bash
sudo nano /etc/systemd/system/maimbot.service
```
输入以下内容:
`<maimbot_directory>` 为你的 maimbot 项目目录
`<venv_directory>` 为你的 venv 环境 `bin` 目录的绝对路径,即上文创建环境后执行 `source maimbot/bin/activate` 时 activate 脚本所在目录的绝对路径
```ini
[Unit]
Description=MaiMbot 麦麦
After=network.target mongod.service
[Service]
Type=simple
WorkingDirectory=<maimbot_directory>
ExecStart=<venv_directory>/python3 bot.py
ExecStop=/bin/kill -2 $MAINPID
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
```
输入以下命令重新加载systemd
```bash
sudo systemctl daemon-reload
```
启动并设置开机自启:
```bash
sudo systemctl start maimbot
sudo systemctl enable maimbot
```
输入以下命令查看日志:
```bash
sudo journalctl -xeu maimbot
```
---
## **其他组件(可选)**
- 直接运行 knowledge.py生成知识库
---
## 常见问题
🔧 权限问题:在命令前加`sudo`
🔌 端口占用:使用`sudo lsof -i :8080`查看端口占用
🛡️ 防火墙确保8080/27017端口开放
```bash
sudo ufw allow 8080/tcp
sudo ufw allow 27017/tcp
```


@@ -30,12 +30,13 @@
在创建虚拟环境之前请确保你的电脑上安装了Python 3.9及以上版本。如果没有,可以按以下步骤安装:
1. 访问Python官网下载页面https://www.python.org/downloads/release/python-3913/
1. 访问Python官网下载页面<https://www.python.org/downloads/release/python-3913/>
2. 下载Windows安装程序 (64-bit): `python-3.9.13-amd64.exe`
3. 运行安装程序,并确保勾选"Add Python 3.9 to PATH"选项
4. 点击"Install Now"开始安装
或者使用PowerShell自动下载安装需要管理员权限
```powershell
# 下载并安装Python 3.9.13
$pythonUrl = "https://www.python.org/ftp/python/3.9.13/python-3.9.13-amd64.exe"
@@ -46,7 +47,7 @@ Start-Process -Wait -FilePath $pythonInstaller -ArgumentList "/quiet", "InstallA
### 2⃣ **创建Python虚拟环境来运行程序**
你可以选择使用以下两种方法之一来创建Python环境
> 你可以选择使用以下两种方法之一来创建Python环境
```bash
# ---方法1使用venvPython自带
@@ -60,6 +61,7 @@ maimbot\\Scripts\\activate
# 安装依赖
pip install -r requirements.txt
```
```bash
# ---方法2使用conda
# 创建一个新的conda环境环境名为maimbot
@@ -74,27 +76,35 @@ pip install -r requirements.txt
```
### 2⃣ **然后你需要启动MongoDB数据库来存储信息**
- 安装并启动MongoDB服务
- 默认连接本地27017端口
### 3⃣ **配置NapCat让麦麦bot与qq取得联系**
- 安装并登录NapCat用你的qq小号
- 添加反向WS`ws://127.0.0.1:8080/onebot/v11/ws`
- 添加反向WS: `ws://127.0.0.1:8080/onebot/v11/ws`
### 4⃣ **配置文件设置让麦麦Bot正常工作**
- 修改环境配置文件:`.env.prod`
- 修改机器人配置文件:`bot_config.toml`
### 5⃣ **启动麦麦机器人**
- 打开命令行cd到对应路径
```bash
nb run
```
- 或者cd到对应路径后
```bash
python bot.py
```
### 6⃣ **其他组件(可选)**
- `run_thingking.bat`: 启动可视化推理界面(未完善)
- 直接运行 knowledge.py生成知识库

56
flake.lock generated

@@ -1,43 +1,21 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1741196730,
"narHash": "sha256-0Sj6ZKjCpQMfWnN0NURqRCQn2ob7YtXTAOTwCuz7fkA=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "48913d8f9127ea6530a2a2f1bd4daa1b8685d8a3",
"type": "github"
"lastModified": 0,
"narHash": "sha256-nJj8f78AYAxl/zqLiFGXn5Im1qjFKU8yBPKoWEeZN5M=",
"path": "/nix/store/f30jn7l0bf7a01qj029fq55i466vmnkh-source",
"type": "path"
},
"original": {
"owner": "NixOS",
"ref": "nixos-24.11",
"repo": "nixpkgs",
"type": "github"
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
"nixpkgs": "nixpkgs",
"utils": "utils"
}
},
"systems": {
@@ -54,6 +32,24 @@
"repo": "default",
"type": "github"
}
},
"utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
}
},
"root": "root",


@@ -1,61 +1,38 @@
{
description = "MaiMBot Nix Dev Env";
# 本配置仅方便用于开发,但是因为 nb-cli 上游打包中并未包含 nonebot2因此目前本配置并不能用于运行和调试
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
flake-utils.url = "github:numtide/flake-utils";
utils.url = "github:numtide/flake-utils";
};
outputs =
{
self,
nixpkgs,
flake-utils,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
};
outputs = {
self,
nixpkgs,
utils,
...
}:
utils.lib.eachDefaultSystem (system: let
pkgs = import nixpkgs {inherit system;};
pythonPackages = pkgs.python3Packages;
in {
devShells.default = pkgs.mkShell {
name = "python-venv";
venvDir = "./.venv";
buildInputs = [
pythonPackages.python
pythonPackages.venvShellHook
pythonPackages.numpy
];
pythonEnv = pkgs.python3.withPackages (
ps: with ps; [
pymongo
python-dotenv
pydantic
jieba
openai
aiohttp
requests
urllib3
numpy
pandas
matplotlib
networkx
python-dateutil
APScheduler
loguru
tomli
customtkinter
colorama
pypinyin
pillow
setuptools
]
);
in
{
devShell = pkgs.mkShell {
buildInputs = [
pythonEnv
pkgs.nb-cli
];
postVenvCreation = ''
unset SOURCE_DATE_EPOCH
pip install -r requirements.txt
'';
shellHook = ''
'';
};
}
);
}
postShellHook = ''
# allow pip to install wheels
unset SOURCE_DATE_EPOCH
'';
};
});
}


@@ -1,23 +1,51 @@
[project]
name = "Megbot"
name = "MaiMaiBot"
version = "0.1.0"
description = "New Bot Project"
description = "MaiMaiBot"
[tool.nonebot]
plugins = ["src.plugins.chat"]
plugin_dirs = ["src/plugins"]
plugin_dirs = ["src/plugins"]
[tool.ruff]
# 设置 Python 版本
target-version = "py39"
include = ["*.py"]
# 行长度设置
line-length = 120
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
# 行长度设置
line-length = 88
ignore = ["E711"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# 使用双引号表示字符串
quote-style = "double"
# 尊重魔法尾随逗号
# 例如:
# items = [
# "apple",
# "banana",
# "cherry",
# ]
skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"


10
run.bat

@@ -1,6 +1,10 @@
@ECHO OFF
chcp 65001
python -m venv venv
call venv\Scripts\activate.bat
pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade -r requirements.txt
if not exist "venv" (
python -m venv venv
call venv\Scripts\activate.bat
pip install -i https://mirrors.aliyun.com/pypi/simple --upgrade -r requirements.txt
) else (
call venv\Scripts\activate.bat
)
python run.py

65
run.py

@@ -37,7 +37,7 @@ def extract_files(zip_path, target_dir):
f.write(zip_ref.read(file))
def run_cmd(command: str, open_new_window: bool = False):
def run_cmd(command: str, open_new_window: bool = True):
"""
运行 cmd 命令
@@ -45,26 +45,19 @@ def run_cmd(command: str, open_new_window: bool = False):
command (str): 指定要运行的命令
open_new_window (bool): 指定是否新建一个 cmd 窗口运行
"""
creationflags = 0
if open_new_window:
creationflags = subprocess.CREATE_NEW_CONSOLE
subprocess.Popen(
[
"cmd.exe",
"/c",
command,
],
creationflags=creationflags,
)
command = "start " + command
subprocess.Popen(command, shell=True)
def run_maimbot():
run_cmd(r"napcat\NapCatWinBootMain.exe 10001", False)
if not os.path.exists(r"mongodb\db"):
os.makedirs(r"mongodb\db")
run_cmd(
r"mongodb\bin\mongod.exe --dbpath=" + os.getcwd() + r"\mongodb\db --port 27017",
True,
r"mongodb\bin\mongod.exe --dbpath=" + os.getcwd() + r"\mongodb\db --port 27017"
)
run_cmd("nb run", True)
run_cmd("nb run")
def install_mongodb():
@@ -87,17 +80,35 @@ def install_mongodb():
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
extract_files("mongodb.zip", "mongodb")
print("MongoDB 下载完成")
os.remove("mongodb.zip")
extract_files("mongodb.zip", "mongodb")
print("MongoDB 下载完成")
os.remove("mongodb.zip")
choice = input(
"是否安装 MongoDB Compass此软件可以以可视化的方式修改数据库建议安装Y/n"
).upper()
if choice == "Y" or choice == "":
install_mongodb_compass()
def install_mongodb_compass():
run_cmd(
r"powershell Start-Process powershell -Verb runAs 'Set-ExecutionPolicy RemoteSigned'"
)
input("请在弹出的用户账户控制中点击“是”后按任意键继续安装")
run_cmd(r"powershell mongodb\bin\Install-Compass.ps1")
input("按任意键启动麦麦")
input("如不需要启动此窗口可直接关闭,无需等待 Compass 安装完成")
run_maimbot()
def install_napcat():
run_cmd("start https://github.com/NapNeko/NapCatQQ/releases", True)
run_cmd("start https://github.com/NapNeko/NapCatQQ/releases", False)
print("请检查弹出的浏览器窗口,点击**第一个**蓝色的“Win64无头” 下载 napcat")
napcat_filename = input(
"下载完成后请把文件复制到此文件夹,并将**不包含后缀的文件名**输入至此窗口,如 NapCat.32793.Shell"
)
if(napcat_filename[-4:] == ".zip"):
napcat_filename = napcat_filename[:-4]
extract_files(napcat_filename + ".zip", "napcat")
print("NapCat 安装完成")
os.remove(napcat_filename + ".zip")
@@ -114,14 +125,20 @@ if __name__ == "__main__":
"请输入要进行的操作:\n"
"1.首次安装\n"
"2.运行麦麦\n"
"3.运行麦麦并启动可视化推理界面\n"
)
os.system("cls")
if choice == "1":
install_napcat()
install_mongodb()
confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n")
if confirm == "1":
install_napcat()
install_mongodb()
else:
print("已取消安装")
elif choice == "2":
run_maimbot()
elif choice == "3":
run_maimbot()
run_cmd("python src/gui/reasoning_gui.py", True)
choice = input("是否启动推理可视化未完善y/N").upper()
if choice == "Y":
run_cmd(r"python src\gui\reasoning_gui.py")
choice = input("是否启动记忆可视化未完善y/N").upper()
if choice == "Y":
run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")

29
run_memory_vis.bat Normal file

@@ -0,0 +1,29 @@
@echo on
chcp 65001 > nul
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
call conda activate %CONDA_ENV%
if errorlevel 1 (
echo 激活 conda 环境失败
pause
exit /b 1
)
echo Conda 环境 "%CONDA_ENV%" 激活成功
set /p OPTION="请选择运行选项 (1: 运行全部绘制, 2: 运行简单绘制): "
if "%OPTION%"=="1" (
python src/plugins/memory_system/memory_manual_build.py
) else if "%OPTION%"=="2" (
python src/plugins/memory_system/draw_memory.py
) else (
echo 无效的选项
pause
exit /b 1
)
if errorlevel 1 (
echo 命令执行失败,错误代码 %errorlevel%
pause
exit /b 1
)
echo 脚本成功完成
pause


@@ -1,50 +1,51 @@
from typing import Optional
from pymongo import MongoClient
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None):
if username and password:
def __init__(
self,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
):
if uri and uri.startswith("mongodb://"):
# 优先使用URI连接
self.client = MongoClient(uri)
elif username and password:
# 如果有用户名和密码,使用认证连接
# TODO: 复杂情况直接支持URI吧
self.client = MongoClient(host, port, username=username, password=password, authSource=auth_source)
self.client = MongoClient(
host, port, username=username, password=password, authSource=auth_source
)
else:
# 否则使用无认证连接
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: Optional[str] = None, password: Optional[str] = None, auth_source: Optional[str] = None) -> "Database":
def initialize(
cls,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance
@classmethod
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
#测试用
def get_random_group_messages(self, group_id: str, limit: int = 5):
# 先随机获取一条消息
random_message = list(self.db.messages.aggregate([
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# 获取该消息之后的消息
subsequent_messages = list(self.db.messages.find({
"group_id": group_id,
"time": {"$gt": random_message["time"]}
}).sort("time", 1).limit(limit))
# 将随机消息和后续消息合并
messages = [random_message] + subsequent_messages
return messages
return cls._instance


@@ -5,6 +5,9 @@ import threading
import time
from datetime import datetime
from typing import Dict, List
from loguru import logger
from typing import Optional
from ..common.database import Database
import customtkinter as ctk
from dotenv import load_dotenv
@@ -17,124 +20,97 @@ root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
# 加载环境变量
if os.path.exists(os.path.join(root_dir, '.env.dev')):
load_dotenv(os.path.join(root_dir, '.env.dev'))
print("成功加载开发环境配置")
logger.info("成功加载开发环境配置")
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
load_dotenv(os.path.join(root_dir, '.env.prod'))
print("成功加载生产环境配置")
logger.info("成功加载生产环境配置")
else:
print("未找到环境配置文件")
logger.error("未找到环境配置文件")
sys.exit(1)
from typing import Optional
from pymongo import MongoClient
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
if username and password:
self.client = MongoClient(
host=host,
port=port,
username=username,
password=password,
authSource=auth_source or 'admin'
)
else:
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance
@classmethod
def get_instance(cls) -> "Database":
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
class ReasoningGUI:
def __init__(self):
# 记录启动时间戳转换为Unix时间戳
self.start_timestamp = datetime.now().timestamp()
print(f"程序启动时间戳: {self.start_timestamp}")
logger.info(f"程序启动时间戳: {self.start_timestamp}")
# 设置主题
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")
# 创建主窗口
self.root = ctk.CTk()
self.root.title('麦麦推理')
self.root.geometry('800x600')
self.root.protocol("WM_DELETE_WINDOW", self._on_closing)
# 初始化数据库连接
try:
self.db = Database.get_instance().db
print("数据库连接成功")
logger.success("数据库连接成功")
except RuntimeError:
print("数据库未初始化,正在尝试初始化...")
logger.warning("数据库未初始化,正在尝试初始化...")
try:
Database.initialize("127.0.0.1", 27017, "maimai_bot")
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db
print("数据库初始化成功")
except Exception as e:
print(f"数据库初始化失败: {e}")
logger.success("数据库初始化成功")
except Exception:
logger.exception("数据库初始化失败")
sys.exit(1)
# 存储群组数据
self.group_data: Dict[str, List[dict]] = {}
# 创建更新队列
self.update_queue = queue.Queue()
# 创建主框架
self.frame = ctk.CTkFrame(self.root)
self.frame.pack(pady=20, padx=20, fill="both", expand=True)
# 添加标题
self.title = ctk.CTkLabel(self.frame, text="麦麦的脑内所想", font=("Arial", 24))
self.title.pack(pady=10, padx=10)
# 创建左右分栏
self.paned = ctk.CTkFrame(self.frame)
self.paned.pack(fill="both", expand=True, padx=10, pady=10)
# 左侧群组列表
self.left_frame = ctk.CTkFrame(self.paned, width=200)
self.left_frame.pack(side="left", fill="y", padx=5, pady=5)
self.group_label = ctk.CTkLabel(self.left_frame, text="群组列表", font=("Arial", 16))
self.group_label.pack(pady=5)
# 创建可滚动框架来容纳群组按钮
self.group_scroll_frame = ctk.CTkScrollableFrame(self.left_frame, width=180, height=400)
self.group_scroll_frame.pack(pady=5, padx=5, fill="both", expand=True)
# 存储群组按钮的字典
self.group_buttons: Dict[str, ctk.CTkButton] = {}
# 当前选中的群组ID
self.selected_group_id: Optional[str] = None
# 右侧内容显示
self.right_frame = ctk.CTkFrame(self.paned)
self.right_frame.pack(side="right", fill="both", expand=True, padx=5, pady=5)
self.content_label = ctk.CTkLabel(self.right_frame, text="推理内容", font=("Arial", 16))
self.content_label.pack(pady=5)
# 创建富文本显示框
self.content_text = ctk.CTkTextbox(self.right_frame, width=500, height=400)
self.content_text.pack(pady=5, padx=5, fill="both", expand=True)
# 配置文本标签 - 只使用颜色
self.content_text.tag_config("timestamp", foreground="#888888") # 时间戳使用灰色
self.content_text.tag_config("user", foreground="#4CAF50") # 用户名使用绿色
@@ -144,11 +120,11 @@ class ReasoningGUI:
self.content_text.tag_config("reasoning", foreground="#FF9800") # 推理过程使用橙色
self.content_text.tag_config("response", foreground="#E91E63") # 回复使用粉色
self.content_text.tag_config("separator", foreground="#666666") # 分隔符使用深灰色
# 底部控制栏
self.control_frame = ctk.CTkFrame(self.frame)
self.control_frame.pack(fill="x", padx=10, pady=5)
self.clear_button = ctk.CTkButton(
self.control_frame,
text="清除显示",
@@ -156,19 +132,19 @@ class ReasoningGUI:
width=120
)
self.clear_button.pack(side="left", padx=5)
# 启动自动更新线程
self.update_thread = threading.Thread(target=self._auto_update, daemon=True)
self.update_thread.start()
# 启动GUI更新检查
self.root.after(100, self._process_queue)
def _on_closing(self):
"""处理窗口关闭事件"""
self.root.quit()
sys.exit(0)
def _process_queue(self):
"""处理更新队列中的任务"""
try:
@@ -183,14 +159,14 @@ class ReasoningGUI:
finally:
# 继续检查队列
self.root.after(100, self._process_queue)
def _update_group_list_gui(self):
"""在主线程中更新群组列表"""
# 清除现有按钮
for button in self.group_buttons.values():
button.destroy()
self.group_buttons.clear()
# 创建新的群组按钮
for group_id in self.group_data.keys():
button = ctk.CTkButton(
@@ -203,16 +179,16 @@ class ReasoningGUI:
)
button.pack(pady=2, padx=5)
self.group_buttons[group_id] = button
# 如果有选中的群组,保持其高亮状态
if self.selected_group_id and self.selected_group_id in self.group_buttons:
self._highlight_selected_group(self.selected_group_id)
def _on_group_select(self, group_id: str):
"""处理群组选择事件"""
self._highlight_selected_group(group_id)
self._update_display_gui(group_id)
def _highlight_selected_group(self, group_id: str):
"""高亮显示选中的群组按钮"""
# 重置所有按钮的颜色
@@ -223,9 +199,9 @@ class ReasoningGUI:
else:
# 恢复其他按钮的默认颜色
button.configure(fg_color="#2B2B2B", hover_color="#404040")
self.selected_group_id = group_id
def _update_display_gui(self, group_id: str):
"""在主线程中更新显示内容"""
if group_id in self.group_data:
@@ -234,19 +210,19 @@ class ReasoningGUI:
# 时间戳
time_str = item['time'].strftime("%Y-%m-%d %H:%M:%S")
self.content_text.insert("end", f"[{time_str}]\n", "timestamp")
# 用户信息
self.content_text.insert("end", "用户: ", "timestamp")
self.content_text.insert("end", f"{item.get('user', '未知')}\n", "user")
# 消息内容
self.content_text.insert("end", "消息: ", "timestamp")
self.content_text.insert("end", f"{item.get('message', '')}\n", "message")
# 模型信息
self.content_text.insert("end", "模型: ", "timestamp")
self.content_text.insert("end", f"{item.get('model', '')}\n", "model")
# Prompt内容
self.content_text.insert("end", "Prompt内容:\n", "timestamp")
prompt_text = item.get('prompt', '')
@@ -257,7 +233,7 @@ class ReasoningGUI:
self.content_text.insert("end", " " + line + "\n", "prompt")
else:
self.content_text.insert("end", " 无Prompt内容\n", "prompt")
# 推理过程
self.content_text.insert("end", "推理过程:\n", "timestamp")
reasoning_text = item.get('reasoning', '')
@@ -268,53 +244,53 @@ class ReasoningGUI:
self.content_text.insert("end", " " + line + "\n", "reasoning")
else:
self.content_text.insert("end", " 无推理过程\n", "reasoning")
# 回复内容
self.content_text.insert("end", "回复: ", "timestamp")
self.content_text.insert("end", f"{item.get('response', '')}\n", "response")
# 分隔符
self.content_text.insert("end", f"\n{'='*50}\n\n", "separator")
self.content_text.insert("end", f"\n{'=' * 50}\n\n", "separator")
# 滚动到顶部
self.content_text.see("1.0")
def _auto_update(self):
"""自动更新函数"""
while True:
try:
# 从数据库获取最新数据,只获取启动时间之后的记录
query = {"time": {"$gt": self.start_timestamp}}
print(f"查询条件: {query}")
logger.debug(f"查询条件: {query}")
# 先获取一条记录检查时间格式
sample = self.db.reasoning_logs.find_one()
if sample:
print(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
cursor = self.db.reasoning_logs.find(query).sort("time", -1)
new_data = {}
total_count = 0
for item in cursor:
# 调试输出
if total_count == 0:
print(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
logger.debug(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
total_count += 1
group_id = str(item.get('group_id', 'unknown'))
if group_id not in new_data:
new_data[group_id] = []
# 转换时间戳为datetime对象
if isinstance(item['time'], (int, float)):
time_obj = datetime.fromtimestamp(item['time'])
elif isinstance(item['time'], datetime):
time_obj = item['time']
else:
print(f"未知的时间格式: {type(item['time'])}")
logger.warning(f"未知的时间格式: {type(item['time'])}")
time_obj = datetime.now() # 使用当前时间作为后备
new_data[group_id].append({
'time': time_obj,
'user': item.get('user', '未知'),
@@ -324,13 +300,13 @@ class ReasoningGUI:
'response': item.get('response', ''),
'prompt': item.get('prompt', '') # 添加prompt字段
})
print(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
logger.info(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
# 更新数据
if new_data != self.group_data:
self.group_data = new_data
print("数据已更新,正在刷新显示...")
logger.info("数据已更新,正在刷新显示...")
# 将更新任务添加到队列
self.update_queue.put({'type': 'update_group_list'})
if self.group_data:
@@ -341,16 +317,16 @@ class ReasoningGUI:
'type': 'update_display',
'group_id': self.selected_group_id
})
except Exception as e:
print(f"自动更新出错: {e}")
except Exception:
logger.exception("自动更新出错")
# 每5秒更新一次
time.sleep(5)
def clear_display(self):
"""清除显示内容"""
self.content_text.delete("1.0", "end")
def run(self):
"""运行GUI"""
self.root.mainloop()
@@ -359,18 +335,18 @@ class ReasoningGUI:
def main():
"""主函数"""
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
app = ReasoningGUI()
app.run()
if __name__ == "__main__":
main()


@@ -1,12 +1,10 @@
import asyncio
import os
import random
import time
import os
from loguru import logger
from nonebot import get_driver, on_command, on_message, require
from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.rule import to_me
from nonebot.typing import T_State
from ...common.database import Database
@@ -18,6 +16,11 @@ from .config import global_config
from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager
from .willing_manager import willing_manager
from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
from .message_sender import message_manager, message_sender
# 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt")
@@ -30,27 +33,21 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
print("\033[1;32m[初始化数据库完成]\033[0m")
logger.success("初始化数据库成功")
# 导入其他模块
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
# from .message_send_control import message_sender
from .message_sender import message_manager, message_sender
# 初始化表情管理器
emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 创建机器人实例
chat_bot = ChatBot()
# 注册群消息处理器
@@ -59,69 +56,80 @@ group_msg = on_message(priority=5)
scheduler = require("nonebot_plugin_apscheduler").scheduler
@driver.on_startup
async def start_background_tasks():
"""启动后台任务"""
# 启动LLM统计
llm_stats.start()
print("\033[1;32m[初始化]\033[0m LLM统计功能启动")
logger.success("LLM统计功能启动成功")
# 初始化并启动情绪管理器
mood_manager = MoodManager.get_instance()
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
print("\033[1;32m[初始化]\033[0m 情绪管理器启动")
logger.success("情绪管理器启动成功")
# 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
await bot_schedule.initialize()
bot_schedule.print_schedule()
@driver.on_startup
async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器"""
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
logger.debug("正在加载用户关系数据...")
await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect
async def _(bot: Bot):
"""Bot连接成功时的处理"""
global _message_manager_started
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
await willing_manager.ensure_started()
message_sender.set_bot(bot)
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
logger.success("-----------消息发送器已启动!-----------")
if not _message_manager_started:
asyncio.create_task(message_manager.start_processor())
_message_manager_started = True
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
logger.success("-----------消息处理器已启动!-----------")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
logger.success("-----------开始偷表情包!-----------")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
# 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建"""
print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
logger.debug(
"[记忆构建]"
"------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time()
print(f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
logger.success(
f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
"秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""每30秒执行一次记忆构建"""
# print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
# await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task():
@@ -130,9 +138,9 @@ async def merge_memory_task():
# await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
@scheduler.scheduled_job("interval", seconds=30, id="print_mood")
async def print_mood_task():
"""每30秒打印一次情绪状态"""
mood_manager = MoodManager.get_instance()
mood_manager.print_mood_status()


@@ -1,26 +1,27 @@
import re
import time
from random import random
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
from .cq_code import CQCode # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import (
Message,
Message_Sending,
Message_Thinking, # 导入 Message_Thinking 类
MessageSet,
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import (
MessageRecvCQ,
)
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager
from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot:
@@ -31,10 +32,10 @@ class ChatBot:
self._started = False
self.mood_manager = MoodManager.get_instance() # 获取情绪管理器单例
self.mood_manager.start_mood_update() # 启动情绪更新
self.emoji_chance = 0.2 # 发送表情包的基础概率
# self.message_streams = MessageStreamContainer()
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
@@ -42,185 +43,232 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例
try:
group_info_api = await bot.get_group_info(group_id=event.group_id)
logger.info(f"成功获取群信息: {group_info_api}")
group_name = group_info_api["group_name"]
except Exception as e:
logger.error(f"获取群信息失败: {str(e)}")
group_name = None
# 白名单设定由nontbot侧完成
# 消息过滤涉及到config有待更新
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
if event.user_id in global_config.ban_user_id:
return
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
message = Message(
group_id=event.group_id,
user_info = UserInfo(
user_id=event.user_id,
message_id=event.message_id,
user_cardname=sender_info['card'],
raw_message=str(event.original_message),
plain_text=event.get_plaintext(),
reply_message=event.reply,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
await message.initialize()
group_info = GroupInfo(
group_id=event.group_id,
group_name=group_name, # 使用获取到的群名称或None
platform="qq",
)
message_cq = MessageRecvCQ(
message_id=event.message_id,
user_info=user_info,
raw_message=str(event.original_message),
group_info=group_info,
reply_message=event.reply,
platform="qq",
)
message_json = message_cq.to_dict()
# 进入maimbot
message = MessageRecv(message_json)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo
)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(
chat_stream=chat,
)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=0.5)
await message.process()
# 过滤词
for word in global_config.ban_words:
if word in message.detailed_plain_text:
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
if word in message.processed_plain_text:
logger.info(f"[群{groupinfo.group_id}]{userinfo.user_nickname}:{message.processed_plain_text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
# 正则表达式过滤
for pattern in global_config.ban_msgs_regex:
if re.search(pattern, message.raw_message):
logger.info(
f"[群{message.message_info.group_info.group_id}]{message.user_nickname}:{message.raw_message}"
)
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = ''
topic = ""
interested_rate = 0
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text)/100
print(f"\033[1;32m[记忆激活]\033[0m {message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
logger.debug(f"{message.processed_plain_text}的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
reply_probability = willing_manager.change_reply_willing_received(
event.group_id,
topic[0] if topic else None,
is_mentioned,
global_config,
event.user_id,
message.is_emoji,
interested_rate
await self.storage.store_message(message, chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_message(message)
reply_probability = await willing_manager.change_reply_willing_received(
chat_stream=chat,
topic=topic[0] if topic else None,
is_mentioned_bot=is_mentioned,
config=global_config,
is_emoji=message.is_emoji,
interested_rate=interested_rate,
)
current_willing = willing_manager.get_willing(event.group_id)
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info(
f"[{current_time}][{chat.group_info.group_id}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
)
response = None
response = ""
if random() < reply_probability:
tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message,message_id=think_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, platform=messageinfo.platform
)
thinking_time_point = round(time.time(), 2)
think_id = "mt" + str(thinking_time_point)
thinking_message = MessageThinking(
message_id=think_id, chat_stream=chat, bot_user_info=bot_user_info, reply=message
)
message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id)
response,raw_content = await self.gpt.generate_response(message)
willing_manager.change_reply_willing_sent(chat)
response, raw_content = await self.gpt.generate_response(message)
# print(f"response: {response}")
if response:
container = message_manager.get_container(event.group_id)
# print(f"有response: {response}")
container = message_manager.get_container(chat.stream_id)
thinking_message = None
# 找到message,删除
# print(f"开始找思考消息")
for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
# print(f"找到思考消息: {msg}")
thinking_message = msg
container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break
# 如果找不到思考消息,直接返回
if not thinking_message:
print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除")
logger.warning("未找到对应的思考消息,可能已超时被移除")
return
#记录开始思考的时间,避免从思考到回复的时间太久
# 记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
#计算打字时间1是为了模拟打字2是避免多条回复乱序
message_set = MessageSet(chat, think_id)
# 计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False
for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
#通过时间改变时间戳
# 通过时间改变时间戳
typing_time = calculate_typing_time(msg)
print(f"typing_time: {typing_time}")
accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
timepoint = thinking_time_point + accu_typing_time
message_segment = Seg(type="text", data=msg)
print(f"message_segment: {message_segment}")
bot_message = MessageSending(
message_id=think_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=timepoint, #记录了回复生成的时间
thinking_start_time=thinking_start_time, #记录了思考开始的时间
reply_message_id=message.message_id
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
is_emoji=False,
)
await bot_message.initialize()
print(f"bot_message: {bot_message}")
if not mark_head:
bot_message.is_head = True
mark_head = True
print(f"添加消息到message_set: {bot_message}")
message_set.add_message(bot_message)
#message_set 可以直接加入 message_manager
# message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
print("添加message_set到message_manager")
message_manager.add_message(message_set)
bot_response_time = tinking_time_point
bot_response_time = thinking_time_point
if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
# 检查是否 <没有找到> emoji
if emoji_raw != None:
emoji_path,discription = emoji_raw
emoji_path, description = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
emoji_cq = CQCode.create_emoji_cq(emoji_path)
if random() < 0.5:
bot_response_time = tinking_time_point - 1
bot_response_time = thinking_time_point - 1
else:
bot_response_time = bot_response_time + 1
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
processed_plain_text=emoji_cq,
detailed_plain_text=discription,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
message_segment = Seg(type="emoji", data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=False,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
)
await bot_message.initialize()
message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content)
print(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict={
'happy': 0.5,
'angry': -1,
'sad': -0.5,
'surprised': 0.2,
'disgusted': -1.5,
'fearful': -0.7,
'neutral': 0.1
logger.debug(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict = {
"happy": 0.5,
"angry": -1,
"sad": -0.5,
"surprised": 0.2,
"disgusted": -1.5,
"fearful": -0.7,
"neutral": 0.1,
}
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
await relationship_manager.update_relationship_value(
chat_stream=chat, relationship_value=valuedict[emotion[0]]
)
# 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id)
# willing_manager.change_reply_willing_after_sent(
# chat_stream=chat
# )
# 创建全局ChatBot实例
chat_bot = ChatBot()
chat_bot = ChatBot()
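The reply path above splits one LLM response into several `MessageSending` segments and staggers their timestamps by an estimated typing delay so multi-part replies leave in order. A minimal sketch of that scheduling idea (the 4-characters-per-second rate and the helper names are assumptions, not the project's `calculate_typing_time`):

```python
import time
from typing import List, Tuple

def estimate_typing_time(text: str, chars_per_second: float = 4.0) -> float:
    """Rough per-segment typing delay; the rate is a placeholder."""
    return len(text) / chars_per_second

def schedule_segments(segments: List[str], thinking_time_point: float) -> List[Tuple[float, str]]:
    """Assign each segment a send timestamp offset by the accumulated typing time."""
    accumulated = 0.0
    scheduled = []
    for seg in segments:
        accumulated += estimate_typing_time(seg)
        scheduled.append((thinking_time_point + accumulated, seg))
    return scheduled

if __name__ == "__main__":
    now = round(time.time(), 2)
    for ts, seg in schedule_segments(["想到一个事", "你们聊的这个我记得在贴吧看到过"], now):
        print(f"{ts:.2f} -> {seg}")
```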

View File

@@ -0,0 +1,226 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import GroupInfo, UserInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = (
data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = (
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
group_info=group_info,
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.saved = False
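`ChatStream` persists itself by converting to a plain dict and rebuilding from one, which is what allows the manager below to upsert it into MongoDB as-is. A simplified round-trip sketch of that pattern (without the project's `UserInfo`/`GroupInfo` types):

```python
import time
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class MiniStream:
    stream_id: str
    platform: str
    user_id: int
    group_id: Optional[int] = None
    create_time: int = 0

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "MiniStream":
        return cls(**data)

s = MiniStream("abc123", "qq", user_id=111, group_id=12345, create_time=int(time.time()))
assert MiniStream.from_dict(s.to_dict()) == s  # lossless round-trip, safe to upsert
```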
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
# # 启动自动保存任务
# asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
if group_info:
# 组合关键信息
components = [
platform,
str(group_info.group_id)
]
else:
components = [
platform,
str(user_info.user_id),
"private"
]
# 使用MD5生成唯一ID
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream=copy.deepcopy(stream)
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
# 创建全局单例
chat_manager = ChatManager()
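The stream ID scheme above keys group chats on `(platform, group_id)` and private chats on `(platform, user_id, "private")`, hashed with MD5, so every member of a group shares one stream while private chats stay per-user. A quick standalone check of that derivation (reimplemented here, not imported from the project):

```python
import hashlib
from typing import Optional

def stream_id(platform: str, user_id: int, group_id: Optional[int] = None) -> str:
    parts = [platform, str(group_id)] if group_id else [platform, str(user_id), "private"]
    return hashlib.md5("_".join(parts).encode()).hexdigest()

# Two users in the same group map to the same stream; their private chats do not.
assert stream_id("qq", 111, group_id=12345) == stream_id("qq", 222, group_id=12345)
assert stream_id("qq", 111) != stream_id("qq", 222)
print(stream_id("qq", 111, group_id=12345))
```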

View File

@@ -1,50 +1,55 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Dict, List, Optional
import tomli
from loguru import logger
from packaging import version
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet,InvalidSpecifier
from packaging.specifiers import SpecifierSet, InvalidSpecifier
@dataclass
class BotConfig:
"""机器人配置类"""
"""机器人配置类"""
INNER_VERSION: Version = None
BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
# 消息处理相关配置
MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
emoji_chance: float = 0.2 # 发送表情包的基础概率
ENABLE_PIC_TRANSLATE: bool = True # 是否启用图片翻译
talk_allowed_groups = set()
talk_frequency_down_groups = set()
thinking_timeout: int = 100 # 思考时间
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3.5 # 降低回复频率的群组回复意愿降低系数
ban_user_id = set()
build_memory_interval: int = 30 # 记忆构建间隔(秒)
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
EMOJI_SAVE: bool = True # 偷表情包
EMOJI_CHECK: bool = False #是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
ban_words = set()
ban_msgs_regex = set()
max_response_length: int = 1024 # 最大回复长度
# 模型配置
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
@@ -60,83 +65,86 @@ class BotConfig:
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
enable_advance_output: bool = False # 是否启用高级输出
enable_kuuki_read: bool = True # 是否启用读空气功能
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
enable_kuuki_read: bool = True # 是否启用读空气功能
enable_debug_output: bool = False # 是否启用调试输出
keywords_reaction_rules = [] # 关键词回复规则
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
chinese_typo_enable=True # 是否启用中文错别字生成器
chinese_typo_error_rate=0.03 # 单字替换概率
chinese_typo_min_freq=7 # 最小字频阈值
chinese_typo_tone_error_rate=0.2 # 声调错误概率
chinese_typo_word_replace_rate=0.02 # 整词替换概率
keywords_reaction_rules = [] # 关键词回复规则
chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# 默认人设
PROMPT_PERSONALITY=[
PROMPT_PERSONALITY = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书",
"是一个女大学生你会刷b站对ACG文化感兴趣"
"是一个女大学生你会刷b站对ACG文化感兴趣",
]
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第种人格概率
PERSONALITY_3: float = 0.1 # 第种人格概率
PROMPT_SCHEDULE_GEN = "一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第种人格概率
PERSONALITY_2: float = 0.3 # 第种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet
Args:
value[str]: 版本表达式(字符串)
Returns:
SpecifierSet
SpecifierSet
"""
try:
converted = SpecifierSet(value)
except InvalidSpecifier as e:
logger.error(
f"{value} 分类使用了错误的版本约束表达式\n",
"请阅读 https://semver.org/lang/zh-CN/ 修改代码"
)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@classmethod
def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
Version
"""
if 'inner' in toml:
if "inner" in toml:
try:
config_version : str = toml["inner"]["version"]
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = { "version": "0.0.0" }
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
try:
ver = version.parse(config_version)
except InvalidVersion as e:
@@ -145,41 +153,41 @@ class BotConfig:
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n")
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config=parent['personality']
personality=personality_config.get('prompt_personality')
personality_config = parent["personality"]
personality = personality_config.get("prompt_personality")
if len(personality) >= 2:
logger.info(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)
logger.debug(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY = personality_config.get("prompt_personality", config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN = personality_config.get("prompt_schedule", config.PROMPT_SCHEDULE_GEN)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.PERSONALITY_1=personality_config.get('personality_1_probability',config.PERSONALITY_1)
config.PERSONALITY_2=personality_config.get('personality_2_probability',config.PERSONALITY_2)
config.PERSONALITY_3=personality_config.get('personality_3_probability',config.PERSONALITY_3)
config.PERSONALITY_1 = personality_config.get("personality_1_probability", config.PERSONALITY_1)
config.PERSONALITY_2 = personality_config.get("personality_2_probability", config.PERSONALITY_2)
config.PERSONALITY_3 = personality_config.get("personality_3_probability", config.PERSONALITY_3)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get("auto_save", config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
def cq_code(parent: dict):
cq_code_config = parent["cq_code"]
config.ENABLE_PIC_TRANSLATE = cq_code_config.get("enable_pic_translate", config.ENABLE_PIC_TRANSLATE)
def bot(parent: dict):
# 机器人基础配置
bot_config = parent["bot"]
@@ -187,16 +195,21 @@ class BotConfig:
config.BOT_QQ = int(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
if config.INNER_VERSION in SpecifierSet(">=0.0.5"):
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def response(parent: dict):
response_config = parent["response"]
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get(
"model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY
)
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
def model(parent: dict):
# 加载模型配置
model_config:dict = parent["model"]
model_config: dict = parent["model"]
config_list = [
"llm_reasoning",
@@ -208,29 +221,23 @@ class BotConfig:
"llm_emotion_judge",
"vlm",
"embedding",
"moderation"
"moderation",
]
for item in config_list:
if item in model_config:
cfg_item:dict = model_config[item]
cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name" : "",
"base_url" : "",
"key" : "",
"pri_in" : 0,
"pri_out" : 0
}
cfg_target = {"name": "", "base_url": "", "key": "", "pri_in": 0, "pri_out": 0}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name","pri_in","pri_out"]
pricing_item = ["pri_in","pri_out"]
stable_item = ["name", "pri_in", "pri_out"]
pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段
for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0
@@ -241,21 +248,19 @@ class BotConfig:
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
provider = cfg_item.get("provider")
if provider == None:
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config,item,cfg_target)
setattr(config, item, cfg_target)
else:
logger.error(f"模型 {item} 在config中不存在请检查")
raise KeyError(f"模型 {item} 在config中不存在请检查")
@@ -265,19 +270,30 @@ class BotConfig:
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
config.ban_words=msg_config.get("ban_words",config.ban_words)
config.ban_words = msg_config.get("ban_words", config.ban_words)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout)
config.response_willing_amplifier = msg_config.get("response_willing_amplifier", config.response_willing_amplifier)
config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier", config.response_interested_rate_amplifier)
config.response_willing_amplifier = msg_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = msg_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.6"):
config.ban_msgs_regex = msg_config.get("ban_msgs_regex", config.ban_msgs_regex)
def memory(parent: dict):
memory_config = parent["memory"]
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
# 在版本 >= 0.0.4 时才处理新增的配置项
if config.INNER_VERSION in SpecifierSet(">=0.0.4"):
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
@@ -294,8 +310,12 @@ class BotConfig:
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get("tone_error_rate", config.chinese_typo_tone_error_rate)
config.chinese_typo_word_replace_rate = chinese_typo_config.get("word_replace_rate", config.chinese_typo_word_replace_rate)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def groups(parent: dict):
groups_config = parent["groups"]
@@ -307,6 +327,7 @@ class BotConfig:
others_config = parent["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
config.enable_debug_output = others_config.get("enable_debug_output", config.enable_debug_output)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
@@ -314,60 +335,19 @@ class BotConfig:
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
include_configs = {
"personality": {
"func": personality,
"support": ">=0.0.0"
},
"emoji": {
"func": emoji,
"support": ">=0.0.0"
},
"cq_code": {
"func": cq_code,
"support": ">=0.0.0"
},
"bot": {
"func": bot,
"support": ">=0.0.0"
},
"response": {
"func": response,
"support": ">=0.0.0"
},
"model": {
"func": model,
"support": ">=0.0.0"
},
"message": {
"func": message,
"support": ">=0.0.0"
},
"memory": {
"func": memory,
"support": ">=0.0.0"
},
"mood": {
"func": mood,
"support": ">=0.0.0"
},
"keywords_reaction": {
"func": keywords_reaction,
"support": ">=0.0.2",
"necessary": False
},
"chinese_typo": {
"func": chinese_typo,
"support": ">=0.0.3",
"necessary": False
},
"groups": {
"func": groups,
"support": ">=0.0.0"
},
"others": {
"func": others,
"support": ">=0.0.0"
}
"personality": {"func": personality, "support": ">=0.0.0"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"cq_code": {"func": cq_code, "support": ">=0.0.0"},
"bot": {"func": bot, "support": ">=0.0.0"},
"response": {"func": response, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"message": {"func": message, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"groups": {"func": groups, "support": ">=0.0.0"},
"others": {"func": others, "support": ">=0.0.0"},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
@@ -379,10 +359,10 @@ class BotConfig:
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except(tomli.TOMLDecodeError) as e:
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
# 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict)
@@ -394,7 +374,7 @@ class BotConfig:
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if 'notice' in include_configs[key]:
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
@@ -406,31 +386,32 @@ class BotConfig:
f"当前程序仅支持以下版本范围: {group_specifierset}"
)
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") == False:
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.success(f"成功加载配置文件: {config_path}")
return config
return config
# 获取配置文件路径
bot_config_floder_path = BotConfig.get_config_dir()
print(f"正在品鉴配置文件目录: {bot_config_floder_path}")
logger.debug(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
print(f"异常的新鲜,异常的美味: {bot_config_path}")
logger.debug(f"异常的新鲜,异常的美味: {bot_config_path}")
logger.info("使用bot配置文件")
else:
# 配置文件不存在
@@ -439,8 +420,10 @@ else:
global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
logger.remove()
pass
# 调试输出功能
if global_config.enable_debug_output:
logger.remove()
logger.add(sys.stdout, level="DEBUG")
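Section loading is gated on the config file's declared `inner.version` using `packaging`'s specifier containment check, which is why newer fields are only read when the template version is high enough. The mechanism in isolation:

```python
from packaging.version import Version
from packaging.specifiers import SpecifierSet

config_version = Version("0.0.5")

print(config_version in SpecifierSet(">=0.0.2"))         # True  -> thinking_timeout etc. are read
print(config_version in SpecifierSet(">=0.0.6"))         # False -> ban_msgs_regex is skipped
print(config_version in SpecifierSet(">=0.0.0,<1.0.0"))  # True  -> the section is supported at all
```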

View File

@@ -1,23 +1,24 @@
import base64
import html
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional, Union
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .utils_image import storage_emoji, storage_image
from .utils_user import get_user_nickname
from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver()
config = driver.config
@@ -35,65 +36,80 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context)
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass
class CQCode:
"""
CQ码数据类用于存储和处理CQ码
属性:
type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
# raw_code: str
group_id: int
user_id: int
group_name: str = ""
user_nickname: str = ""
translated_plain_text: Optional[str] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
pass
async def translate(self):
"""根据CQ码类型进行相应的翻译处理"""
if self.type == 'text':
self.translated_plain_text = self.params.get('text', '')
elif self.type == 'image':
if self.params.get('sub_type') == '0':
self.translated_plain_text = await self.translate_image()
def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text":
self.translated_segments = Seg(
type="text", data=self.params.get("text", "")
)
elif self.type == "image":
base64_data = self.translate_image()
if base64_data:
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
else:
self.translated_segments = Seg(type="emoji", data=base64_data)
else:
self.translated_plain_text = await self.translate_emoji()
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname:
self.translated_plain_text = f"[@{user_nickname}]"
self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == "at":
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(
type="text", data=f"[@{user_nickname or '某人'}]"
)
elif self.type == "reply":
reply_segments = self.translate_reply()
if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments)
else:
self.translated_plain_text = "@某人"
elif self.type == 'reply':
self.translated_plain_text = await self.translate_reply()
elif self.type == 'face':
face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]"
elif self.type == 'forward':
self.translated_plain_text = await self.translate_forward()
self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == "forward":
forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
self.translated_segments = Seg(type="text", data="[转发消息]")
else:
self.translated_plain_text = f"[{self.type}]"
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self):
'''
"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
@@ -102,18 +118,18 @@ class CQCode:
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
'''
"""
# 腾讯专用请求头配置
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
'Accept': 'text/html, application/xhtml xml, */*',
'Accept-Encoding': 'gbk, GB2312',
'Accept-Language': 'zh-cn',
'Content-Type': 'application/x-www-form-urlencoded',
'Cache-Control': 'no-cache'
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
"Accept-Encoding": "gbk, GB2312",
"Accept-Language": "zh-cn",
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params['url'])
if not url.startswith(('http://', 'https://')):
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
# 创建专用会话
@@ -129,247 +145,214 @@ class CQCode:
headers=headers,
timeout=15,
allow_redirects=True,
stream=True # 流式传输避免大内存问题
stream=True, # 流式传输避免大内存问题
)
# 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型
content_type = response.headers.get('Content-Type', '')
if not content_type.startswith('image/'):
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8')
image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避
logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5**retry) # 指数退避
except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
except Exception:
logger.exception("[未知错误]")
return None
return None
async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码"""
if 'url' not in self.params:
return '[表情包]'
base64_str = self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params:
return None
return self.get_img()
async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
return await self.get_image_description(base64_str)
else:
return '[图片]'
async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[表情包]"
if "content" not in self.params:
return None
async def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
async def translate_forward(self) -> str:
"""处理转发消息"""
try:
if 'content' not in self.params:
return '[转发消息]'
# 解析content内容需要先反转义
content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
# 将字符串形式的列表转换为Python对象
content = self.unescape(self.params["content"])
import ast
try:
messages = ast.literal_eval(content)
except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
return '[转发消息]'
logger.error(f"解析转发消息内容失败: {str(e)}")
return None
# 处理每条消息
formatted_messages = []
formatted_segments = []
for msg in messages:
sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '')
message_array = msg.get('message', [])
sender = msg.get("sender", {})
nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get("raw_message", "")
message_array = msg.get("message", [])
if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息
for message_part in message_array:
if message_part.get('type') == 'forward':
content = '[转发消息]'
if message_part.get("type") == "forward":
content_seg = Seg(type="text", data="[转发消息]")
break
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
if raw_message:
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
else:
content_seg = Seg(type="text", data="[空消息]")
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq',
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
group_info=group_info,
)
content_seg = Seg(
type="seglist", data=[message_obj.message_segment]
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
content_seg = Seg(type="text", data="[空消息]")
formatted_msg = f"{nickname}: {content}"
formatted_messages.append(formatted_msg)
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_segments.append(content_seg)
formatted_segments.append(Seg(type="text", data="\n"))
# 合并所有消息
combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]"
return formatted_segments
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
return '[转发消息]'
logger.error(f"处理转发消息失败: {str(e)}")
return None
async def translate_reply(self) -> str:
"""处理回复类型的CQ码"""
def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
# 创建Message对象
from .message import Message
if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
if self.reply_message is None:
return None
if self.reply_message.sender.user_id:
message_obj = Message(
user_id=self.reply_message.sender.user_id,
message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname),
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_id=self.group_id
group_info=GroupInfo(group_id=self.reply_message.group_id),
)
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else:
segments.append(
Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type="seglist", data=[message_obj.message_segment]))
segments.append(Seg(type="text", data="]"))
return segments
else:
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空")
return '[回复某人消息]'
return None
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \
.replace('&#91;', '[') \
.replace('&#93;', ']') \
.replace('&amp;', '&')
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
return (
text.replace("&#44;", ",")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&")
)
class CQCode_tool:
@staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
Args:
cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典(可选)
Returns:
CQCode对象
"""
# 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get('type', 'text')
cq_type = cq_code.get("type", "text")
params = {}
if cq_type == 'text':
params['text'] = cq_code.get('data', {}).get('text', '')
if cq_type == "text":
params["text"] = cq_code.get("data", {}).get("text", "")
else:
params = cq_code.get('data', {})
params = cq_code.get("data", {})
instance = CQCode(
type=cq_type,
params=params,
group_id=0,
user_id=0,
group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info,
reply_message=reply
)
# 进行翻译处理
await instance.translate()
instance.translate()
return instance
@staticmethod
@@ -383,5 +366,64 @@ class CQCode_tool:
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = (
abs_path.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool()
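CQ codes reserve `&`, `[`, `]` and `,`, so any payload embedded in one (file paths, base64 data) has to be escaped on the way in and unescaped in the opposite order on the way out. A round-trip sketch of the same mapping used above:

```python
def escape_cq(text: str) -> str:
    # '&' must be escaped first so the later entities are not double-escaped.
    return (text.replace("&", "&amp;")
                .replace("[", "&#91;")
                .replace("]", "&#93;")
                .replace(",", "&#44;"))

def unescape_cq(text: str) -> str:
    # Reverse order: numeric entities first, '&amp;' last.
    return (text.replace("&#44;", ",")
                .replace("&#91;", "[")
                .replace("&#93;", "]")
                .replace("&amp;", "&"))

sample = "C:/emoji/[cat],v1&test.png"
assert unescape_cq(escape_cq(sample)) == sample
print(f"[CQ:image,file=file:///{escape_cq(sample)},sub_type=1]")
```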

View File

@@ -1,9 +1,11 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional
from typing import Optional, Tuple
from loguru import logger
from nonebot import get_driver
@@ -11,34 +13,37 @@ from nonebot import get_driver
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
image_manager = ImageManager()
class EmojiManager:
_instance = None
EMOJI_DIR = "data/emoji" # 表情包存储目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,temperature=0.8) #更高的温度更少的token后续可以根据情绪来调整温度
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
os.makedirs(self.EMOJI_DIR, exist_ok=True)
def initialize(self):
"""初始化数据库连接和表情目录"""
if not self._initialized:
@@ -49,16 +54,16 @@ class EmojiManager:
self._initialized = True
# 启动时执行一次完整性检查
self.check_emoji_file_integrity()
except Exception as e:
logger.error(f"初始化表情管理器失败: {str(e)}")
except Exception:
logger.exception("初始化表情管理器失败")
def _ensure_db(self):
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self):
"""确保emoji集合存在并创建索引
@@ -74,9 +79,8 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
@@ -88,7 +92,7 @@ class EmojiManager:
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[str]:
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
"""根据文本内容获取相关表情包
Args:
text: 输入文本
@@ -102,9 +106,9 @@ class EmojiManager:
"""
try:
self._ensure_db()
# 获取文本的embedding
text_for_search= await self._get_kimoji_for_text(text)
text_for_search = await self._get_kimoji_for_text(text)
if not text_for_search:
logger.error("无法获取文本的情绪")
return None
@@ -112,15 +116,15 @@ class EmojiManager:
if not text_embedding:
logger.error("无法获取文本的embedding")
return None
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
return None
# 计算余弦相似度并排序
def cosine_similarity(v1, v2):
if not v1 or not v2:
@@ -131,25 +135,25 @@ class EmojiManager:
if norm_v1 == 0 or norm_v2 == 0:
return 0
return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度
emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis
]
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis:
if not top_10_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
@@ -157,57 +161,61 @@ class EmojiManager:
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success(f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'],"[ %s ]" % selected_emoji.get('description', '无描述')
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('description', '无描述')
except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}")
return None
return None
except Exception as e:
logger.error(f"获取表情包失败: {str(e)}")
return None
async def _get_emoji_description(self, image_base64: str) -> str:
"""获取表情包的标签"""
async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}")
return content
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip('[]').replace('表情包:', '')
return description
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
return None
async def _check_emoji(self, image_base64: str) -> str:
try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}")
return content
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
return None
async def _get_kimoji_for_text(self, text:str):
async def _get_kimoji_for_text(self, text: str):
try:
prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
content, _ = await self.llm_emotion_judge.generate_response_async(prompt)
content, _ = await self.llm_emotion_judge.generate_response_async(prompt,temperature=1.5)
logger.info(f"输出描述: {content}")
return content
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
return None
async def scan_new_emojis(self):
"""扫描新的表情包"""
try:
@@ -215,62 +223,122 @@ class EmojiManager:
os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
files_to_process = [f for f in os.listdir(emoji_dir) if
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
# 获取图片的base64编码和哈希值
image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# 获取表情包的描述
description = await self._get_emoji_description(image_base64)
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述
description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
if '是' not in check:
os.remove(image_path)
logger.info(f"描述: {description}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}")
continue
logger.info(f"check通过 {check}")
embedding = await get_embedding(description)
if description is not None:
embedding = await get_embedding(description)
if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录
emoji_record = {
'filename': filename,
'path': image_path,
'embedding':embedding,
'description': description,
'embedding': embedding,
'discription': description,
'hash': image_hash,
'timestamp': int(time.time())
}
# 保存到数据库
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else:
logger.warning(f"跳过表情包: {filename}")
except Exception as e:
logger.error(f"扫描表情包失败: {str(e)}")
logger.error(traceback.format_exc())
except Exception:
logger.exception("扫描表情包失败")
async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包"""
while True:
print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
logger.info("开始扫描新表情包...")
await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录
@@ -281,7 +349,7 @@ class EmojiManager:
all_emojis = list(self.db.db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
for emoji in all_emojis:
try:
if 'path' not in emoji:
@@ -289,27 +357,27 @@ class EmojiManager:
self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
# 检查文件是否存在
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.success(f"成功删除数据库记录: {emoji['_id']}")
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
else:
logger.error(f"删除数据库记录失败: {emoji['_id']}")
except Exception as item_error:
logger.error(f"处理表情包记录时出错: {str(item_error)}")
continue
# 验证清理结果
remaining_count = self.db.db.emoji.count_documents({})
if removed_count > 0:
@@ -317,7 +385,7 @@ class EmojiManager:
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
else:
logger.info(f"已检查 {total_count} 个表情包记录")
except Exception as e:
logger.error(f"检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
@@ -328,6 +396,8 @@ class EmojiManager:
await asyncio.sleep(interval_MINS * 60)
# 创建全局单例
emoji_manager = EmojiManager()
emoji_manager = EmojiManager()
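# --- Illustrative sketch (not part of this commit) ---
# One possible way to wire the emoji_manager singleton in at startup, assuming the Database
# singleton has already been initialized elsewhere: run the integrity check once, then keep
# the periodic scan alive as a background task. start_emoji_tasks() is an assumed helper,
# not an API of this repo.
import asyncio


async def start_emoji_tasks() -> None:
    emoji_manager.check_emoji_file_integrity()            # prune records whose files are gone
    asyncio.create_task(emoji_manager._periodic_scan())   # rescan data/emoji every 10 minutes by default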

View File

@@ -3,11 +3,12 @@ import time
from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from loguru import logger
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import Message
from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@@ -18,58 +19,89 @@ config = driver.config
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
self.model_r1 = LLM_request(
model=global_config.llm_reasoning,
temperature=0.7,
max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1'
self.current_model_type = "r1"
current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = 'v3'
elif (
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3
else:
self.current_model_type = 'r1_distill'
self.current_model_type = "r1_distill"
current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(
message, current_model
)
raw_content = model_response
# print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}")
if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response)
if model_response:
return model_response, raw_content
return None, raw_content
return model_response ,raw_content
return None,raw_content
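# --- Illustrative sketch (not part of this commit) ---
# generate_response above rolls once against cumulative probabilities:
# MODEL_R1_PROBABILITY first, then + MODEL_V3_PROBABILITY, remainder -> r1_distill.
# The same selection as a standalone helper, with made-up weights:
import random


def pick_model(weights: dict[str, float]) -> str:
    """weights e.g. {'r1': 0.7, 'v3': 0.2, 'r1_distill': 0.1}; assumed to sum to 1."""
    roll = random.random()
    cumulative = 0.0
    for name, weight in weights.items():
        cumulative += weight
        if roll < cumulative:
            return name
    return next(reversed(weights))  # guard against float rounding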
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
async def _generate_response_with_model(
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}"
if message.user_cardname:
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
sender_name = (
message.chat_stream.user_info.user_nickname
or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass
# 构建prompt
prompt, prompt_check = await prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
relationship_value=relationship_value,
group_id=message.group_id
stream_id=message.chat_stream.stream_id,
)
# 读空气模块 简化逻辑,先停用
@@ -92,10 +124,10 @@ class ResponseGenerator:
# 生成回复
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception as e:
print(f"生成回复时出错: {e}")
except Exception:
logger.exception("生成回复时出错")
return None
# 保存到数据库
self._save_to_db(
message=message,
@@ -107,54 +139,73 @@ class ResponseGenerator:
reasoning_content=reasoning_content,
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
)
return content
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, reasoning_content: str,):
def _save_to_db(
self,
message: MessageRecv,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
'prompt_check': prompt_check
})
self.db.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
"user": sender_name,
"message": message.processed_plain_text,
"model": self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
"reasoning": reasoning_content,
"response": content,
"prompt": prompt,
"prompt_check": prompt_check,
}
)
async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签"""
try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好,不要输出其他内容:
内容:{content}
输出:
'''
"""
content, _ = await self.model_v25.generate_response(prompt)
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
content = content.strip()
if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content]
else:
return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签"""
if not content:
return None, []
processed_response = process_llm_response(content)
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response
@@ -172,7 +223,7 @@ class InitiativeMessageGenerate:
prompt_builder._build_initiative_prompt_select(message.group_id)
)
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
print(f"[DEBUG] {content_select} {reasoning}")
logger.debug(f"{content_select} {reasoning}")
topics_list = [dot[0] for dot in dots_for_select]
if content_select:
if content_select in topics_list:
@@ -185,12 +236,12 @@ class InitiativeMessageGenerate:
select_dot[1], prompt_template
)
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
print(f"[DEBUG] {content_check} {reasoning_check}")
logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower():
return None
prompt = prompt_builder._build_initiative_prompt(
select_dot, prompt_template, memory
)
content, reasoning = self.model_r1.generate_response_async(prompt)
print(f"[DEBUG] {content} {reasoning}")
logger.debug(f"[DEBUG] {content} {reasoning}")
return content

View File

@@ -1,14 +1,16 @@
import time
import html
import re
import json
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional
from typing import Dict, List, Optional
import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
Message = ForwardRef('Message') # 添加这行
from .utils_image import image_manager
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -16,216 +18,383 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
chat_stream: ChatStream=None
reply: Optional['Message'] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
def __init__(
self,
message_id: str,
time: int,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=time,
group_info=chat_stream.group_info,
user_info=user_info
)
# 调用父类初始化
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=None
)
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
self.detailed_plain_text = detailed_plain_text
# 回复消息
self.reply = reply
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
)
# 构建详细文本
if self.time is None:
self.time = int(time.time())
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
)
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
async def parse_message_segments(self, message: str) -> List[CQCode]:
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典,包含:
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
start = 0
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
message_segment = message_dict.get('message_segment', {})
if message_segment.get('data','') == '[json]':
# 提取json消息中的展示信息
pattern = r'\[CQ:json,data=(?P<json_data>.+?)\]'
match = re.search(pattern, message_dict.get('raw_message',''))
raw_json = html.unescape(match.group('json_data'))
try:
json_message = json.loads(raw_json)
except json.JSONDecodeError:
json_message = {}
message_segment['data'] = json_message.get('prompt','')
self.message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
self.raw_message = message_dict.get('raw_message')
self.message_id = message_id
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
# 思考状态相关属性
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
self.is_emoji=True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.interupt=False
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
@dataclass
class Message_Sending(Message):
"""发送中的消息类"""
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
reply_message_id: int = None # 存储 回复的 源消息ID
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段
reply=reply
)
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False,
is_emoji: bool = False
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
def set_reply(self, reply: Optional['MessageRecv']) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False
) -> 'MessageSending':
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji
)
def to_dict(self):
ret= super().to_dict()
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id
self.user_id = user_id
def __init__(self, chat_stream: ChatStream, message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注
self.messages: List[MessageSending] = []
self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None:
"""添加消息到集合只接受Message_Sending类型"""
if not isinstance(message, Message_Sending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
# 按时间排序
self.messages.sort(key=lambda x: x.time)
self.messages.sort(key=lambda x: x.message_info.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].time < target_time:
if self.messages[mid].message_info.time < target_time:
left = mid + 1
else:
right = mid
return self.messages[left]
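# --- Illustrative sketch (not part of this commit) ---
# get_message_by_time above is a lower-bound binary search over messages sorted by
# message_info.time: it returns the first message at or after target_time, or the last
# message when every entry is earlier. The same lookup expressed with bisect:
import bisect


def lower_bound_index(times: list[float], target_time: float) -> int:
    """Index of the first timestamp >= target_time, clamped to the last element."""
    index = bisect.bisect_left(times, target_time)
    return min(index, len(times) - 1)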
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool:
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
translated_data: 经过翻译处理的数据(可选)
"""
type: str
data: Union[str, List['Seg']]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> 'Seg':
"""从字典创建Seg实例"""
type=data.get('type')
data=data.get('data')
if type == 'seglist':
data = [Seg.from_dict(seg) for seg in data]
return cls(
type=type,
data=data
)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {'type': self.type}
if self.type == 'seglist':
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
return result
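# --- Illustrative sketch (not part of this commit) ---
# Seg serializes recursively: a 'seglist' nests child Seg dicts, any other type keeps its
# data as-is (text string, base64 string, ...). Round trip with made-up data:
nested = Seg(type="seglist", data=[
    Seg(type="reply", data="123456"),
    Seg(type="text", data="你好"),
])
as_dict = nested.to_dict()
# {'type': 'seglist', 'data': [{'type': 'reply', 'data': '123456'}, {'type': 'text', 'data': '你好'}]}
assert Seg.from_dict(as_dict).to_dict() == as_dict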
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
group_id=data.get('group_id'),
group_name=data.get('group_name',None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
user_id=data.get('user_id'),
user_nickname=data.get('user_nickname',None),
user_cardname=data.get('user_cardname',None)
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str,int,None] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo(**data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {}))
return cls(
platform=data.get('platform'),
message_id=data.get('message_id'),
time=data.get('time'),
group_info=group_info,
user_info=user_info
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {
'message_info': self.message_info.to_dict(),
'message_segment': self.message_segment.to_dict()
}
if self.raw_message is not None:
result['raw_message'] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> 'MessageBase':
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message',None)
return cls(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)
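# --- Illustrative sketch (not part of this commit) ---
# Minimal MessageBase round trip with assumed values. Note that from_dict above rebuilds
# the segment via Seg(**...) rather than Seg.from_dict, so this example sticks to a single
# flat text segment.
info = BaseMessageInfo(
    platform="qq",
    message_id="123456",
    time=1710000000,
    user_info=UserInfo(platform="qq", user_id=10001, user_nickname="测试用户"),
)
msg = MessageBase(message_info=info, message_segment=Seg(type="text", data="你好"), raw_message="你好")
restored = MessageBase.from_dict(msg.to_dict())
assert restored.message_info.message_id == "123456"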

View File

@@ -0,0 +1,169 @@
import time
from dataclasses import dataclass
from typing import Dict, Optional
import urllib3
from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
raw_message: str,
group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None,
):
# 调用父类初始化
super().__init__(message_id, user_info, group_info, platform)
if group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
data: Dict
):
# 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg.from_dict(data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, ) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 处理消息段
if self.message_segment.type == 'seglist':
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji':
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -2,224 +2,212 @@ import asyncio
import time
from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet
from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config
class Message_Sender:
"""发送器"""
def __init__(self):
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
self.last_send_time = 0
self._current_bot = None
def set_bot(self, bot: Bot):
"""设置当前bot实例"""
self._current_bot = bot
async def send_group_message(
self,
group_id: int,
send_text: str,
auto_escape: bool = False,
reply_message_id: int = None,
at_user_id: int = None
) -> None:
if not self._current_bot:
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例")
message = send_text
# 如果需要回复
if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message
# 如果需要at
# if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message
typing_time = calculate_typing_time(message)
if typing_time > 10:
typing_time = 10
await asyncio.sleep(typing_time)
# 发送消息
try:
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
async def send_message(
self,
message: MessageSending,
) -> None:
"""发送消息"""
if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send=MessageSendCQ(
data=message_json
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
if message_send.message_info.group_info:
try:
await self._current_bot.send_group_msg(
group_id=message.message_info.group_info.group_id,
message=message_send.raw_message,
auto_escape=False
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception as e:
logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
else:
try:
await self._current_bot.send_private_msg(
user_id=message.message_info.user_info.user_id,
message=message_send.raw_message,
auto_escape=False
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception as e:
logger.error(f"发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
class MessageContainer:
"""单个的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id
"""单个聊天流的发送/思考消息容器"""
def __init__(self, chat_id: str, max_size: int = 100):
self.chat_id = chat_id
self.max_size = max_size
self.messages = []
self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]:
def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time()
timeout_messages = []
for msg in self.messages:
if isinstance(msg, Message_Sending):
if isinstance(msg, MessageSending):
if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg)
# 按thinking_start_time排序时间早的在前面
timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
"""获取thinking_start_time最早的消息对象"""
if not self.messages:
return None
earliest_time = float('inf')
earliest_message = None
for msg in self.messages:
for msg in self.messages:
msg_time = msg.thinking_start_time
if msg_time < earliest_time:
earliest_time = msg_time
earliest_message = msg
earliest_message = msg
return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
"""添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
if isinstance(message, MessageSet):
for single_message in message.messages:
self.messages.append(single_message)
else:
self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False"""
try:
if message in self.messages:
self.messages.remove(message)
return True
return False
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}")
except Exception:
logger.exception("移除消息时发生错误")
return False
def has_messages(self) -> bool:
"""检查是否有待发送的消息"""
return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
"""获取所有消息"""
return list(self.messages)
class MessageManager:
"""管理所有的消息容器"""
"""管理所有聊天流的消息容器"""
def __init__(self):
self.containers: Dict[int, MessageContainer] = {}
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
self.storage = MessageStorage()
self._running = True
def get_container(self, group_id: int) -> MessageContainer:
"""获取或创建的消息容器"""
if group_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id)
return self.containers[group_id]
def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建聊天流的消息容器"""
if chat_id not in self.containers:
self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[chat_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
container = self.get_container(message.group_id)
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
chat_stream = message.chat_stream
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
container.add_message(message)
async def process_group_messages(self, group_id: int):
"""处理消息"""
# if int(time.time() / 3) == time.time() / 3:
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id)
async def process_chat_messages(self, chat_id: str):
"""处理聊天流消息"""
container = self.get_container(chat_id)
if container.has_messages():
#最早的对象,可能是思考消息,也可能是发送消息
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
# print(f"处理有message的容器chat_id: {chat_id}")
message_earliest = container.get_earliest_message()
#如果是思考消息
if isinstance(message_earliest, Message_Thinking):
#优先等待这条消息
if isinstance(message_earliest, MessageThinking):
message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}\033[K\r", end='', flush=True)
print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest)
else:# 如果不是message_thinking就只能是message_sending
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
#直接发,等什么呢
if message_earliest.is_head and message_earliest.update_thinking_time() >30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
else:
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_message(message_earliest.set_reply())
else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
#移除消息
if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None)
await message_sender.send_message(message_earliest)
await message_earliest.process()
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
container.remove_message(message_earliest)
#获取并处理超时消息
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
message_timeout = container.get_timeout_messages()
if message_timeout:
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout:
if msg == message_earliest:
continue # 跳过已经处理过的消息
continue
try:
#发送
if msg.is_head and msg.update_thinking_time() >30:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_message(msg.set_reply())
else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
await message_sender.send_message(msg)
#如果是表情包,则替换为"[表情包]"
if msg.is_emoji:
msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None)
# if msg.is_emoji:
# msg.processed_plain_text = "[表情包]"
await msg.process()
await self.storage.store_message(msg,msg.chat_stream, None)
# 安全地移除消息
if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}")
logger.warning("尝试删除不存在的消息")
except Exception:
logger.exception("处理超时消息时发生错误")
continue
async def start_processor(self):
"""启动消息处理器"""
while self._running:
await asyncio.sleep(1)
tasks = []
for group_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id))
for chat_id in self.containers.keys():
tasks.append(self.process_chat_messages(chat_id))
await asyncio.gather(*tasks)
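# --- Illustrative sketch (not part of this commit) ---
# Rough lifecycle of one reply as handled by MessageManager above: a MessageThinking
# placeholder is queued first, then swapped for MessageSending objects once the LLM has
# answered, and start_processor drains every container once per second. chat_stream,
# bot_user_info, reply_to and segments are placeholders assumed to come from the caller.
async def queue_reply(chat_stream, bot_user_info, reply_to, segments) -> None:
    thinking = MessageThinking(message_id="mt_demo", chat_stream=chat_stream,
                               bot_user_info=bot_user_info, reply=reply_to)
    message_manager.add_message(thinking)
    # ... generation happens elsewhere; replace the placeholder with real sending messages:
    message_manager.get_container(chat_stream.stream_id).remove_message(thinking)
    for index, seg in enumerate(segments):
        sending = MessageSending.from_thinking(thinking, message_segment=seg,
                                               is_head=(index == 0))
        message_manager.add_message(sending)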
# 创建全局消息管理器实例
message_manager = MessageManager()
# 创建全局发送器实例

View File

@@ -1,6 +1,7 @@
import random
import time
from typing import Optional
from loguru import logger
from ...common.database import Database
from ..memory_system.memory import hippocampus, memory_graph
@@ -8,6 +9,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import chat_manager
class PromptBuilder:
@@ -22,7 +24,7 @@ class PromptBuilder:
message_txt: str,
sender_name: str = "某人",
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt
Args:
@@ -33,57 +35,62 @@ class PromptBuilder:
Returns:
str: 构建好的prompt
"""
#先禁用关系
"""
# 先禁用关系
if 0 > 30:
relation_prompt = "关系特别特别好,你很喜欢喜欢他"
relation_prompt_2 = "热情发言或者回复"
elif 0 <-20:
elif 0 < -20:
relation_prompt = "关系很差,你很讨厌他"
relation_prompt_2 = "骂他"
else:
relation_prompt = "关系一般"
relation_prompt_2 = "发言或者回复"
#开始构建prompt
#心情
# 开始构建prompt
# 心情
mood_manager = MoodManager.get_instance()
mood_prompt = mood_manager.get_prompt()
#日程构建
# 日程构建
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
#知识构建
# 知识构建
start_time = time.time()
prompt_info = ''
promt_info_prompt = ''
prompt_info = await self.get_prompt_info(message_txt,threshold=0.5)
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[
知识],之后可能会用到-'''
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文
chat_in_group=True
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 使用新的记忆获取方法
memory_prompt = ''
start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt,
@@ -91,67 +98,64 @@ class PromptBuilder:
similarity_threshold=0.4,
max_memory_num=5
)
if relevant_memories:
# 格式化记忆内容
memory_items = []
for memory in relevant_memories:
memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}")
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
# 打印调试信息
print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:")
logger.debug("[记忆检索]找到以下相关记忆:")
for memory in relevant_memories:
print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
end_time = time.time()
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}")
#激活prompt构建
logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建
activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot:
# is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认'
# else:
# is_bot_prompt = ''
if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 关键词检测与反应
keywords_reaction_prompt = ''
for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
print(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
keywords_reaction_prompt += rule.get("reaction", "") + ''
#人格选择
personality=global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3
prompt_personality = ''
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}'
personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[1]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。'''
#中文高手(新加的好玩功能)
# 中文高手(新加的好玩功能)
prompt_ger = ''
if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句'
@@ -159,23 +163,23 @@ class PromptBuilder:
prompt_ger += '你喜欢用反问句'
if random.random() < 0.01:
prompt_ger += '你喜欢用文言文'
#额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
#合并prompt
# 额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
# 合并prompt
prompt = ""
prompt += f"{prompt_info}\n"
prompt += f"{prompt_date}\n"
prompt += f"{chat_talking_prompt}\n"
prompt += f"{chat_talking_prompt}\n"
prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n"
'''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt += f"{extra_info}\n"
'''读空气prompt处理'''
activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < probability_1: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
@@ -183,34 +187,36 @@ class PromptBuilder:
else: # 第三种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response
def _build_initiative_prompt_select(self,group_id):
prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt, prompt_check_if_response
def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题
all_nodes=memory_graph.dots
all_nodes=filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes)
nodes_for_select=random.sample(all_nodes,5)
topics=[info[0] for info in nodes_for_select]
infos=[info[1] for info in nodes_for_select]
all_nodes = memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select]
infos = [info[1] for info in nodes_for_select]
#激活prompt构建
# 激活prompt构建
activate_prompt = ''
activate_prompt = "以上是群里正在进行的聊天。"
personality=global_config.PROMPT_PERSONALITY
personality = global_config.PROMPT_PERSONALITY
prompt_personality = ''
personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格
@@ -219,32 +225,31 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}'''
topics_str=','.join(f"\"{topics}\"")
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular=f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select,nodes_for_select,prompt_regular
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
memory=random.sample(selected_node['memory_items'],3)
memory='\n'.join(memory)
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check,memory
def _build_initiative_prompt(self,selected_node,prompt_regular,memory):
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
topics_str = ','.join(f"\"{topics}\"")
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular = f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt_check(self, selected_node, prompt_regular):
memory = random.sample(selected_node['memory_items'], 3)
memory = '\n'.join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory
def _build_initiative_prompt(self, selected_node, prompt_regular, memory):
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
return prompt_for_initiative
async def get_prompt_info(self,message:str,threshold:float):
async def get_prompt_info(self, message: str, threshold: float):
related_info = ''
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding,threshold=threshold)
related_info += self.get_info_from_db(embedding, threshold=threshold)
return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
@@ -305,14 +310,15 @@ class PromptBuilder:
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:
return ''
# 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results)
prompt_builder = PromptBuilder()
prompt_builder = PromptBuilder()
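# --- Illustrative sketch (not part of this commit) ---
# The keyword-reaction loop in _build_prompt appends rule["reaction"] whenever any keyword
# appears in the lowercased message. The same check as a standalone helper, with a made-up
# rule for illustration:
def collect_keyword_reactions(message_txt: str, rules: list[dict]) -> str:
    reactions = []
    for rule in rules:
        keywords = rule.get("keywords", [])
        if rule.get("enable", False) and any(k in message_txt.lower() for k in keywords):
            reactions.append(rule.get("reaction", ""))
    return "".join(reactions)


# collect_keyword_reactions("有人说你是人机", [{"enable": True, "keywords": ["人机"], "reaction": "否定这一说法"}])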

View File

@@ -1,115 +1,177 @@
import asyncio
from typing import Optional
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo
from .chat_stream import ChatStream
class Impression:
traits: str = None
called: str = None
know_time: float = None
relationship_value: float = None
class Relationship:
user_id: int = None
# impression: Impression = None
# group_id: int = None
# group_name: str = None
platform: str = None
gender: str = None
age: int = None
nickname: str = None
relationship_value: float = None
saved = False
def __init__(self, user_id: int, data=None, **kwargs):
if isinstance(data, dict):
# 如果输入是字典,使用字典解析
self.user_id = data.get('user_id')
self.gender = data.get('gender')
self.age = data.get('age')
self.nickname = data.get('nickname')
self.relationship_value = data.get('relationship_value', 0.0)
self.saved = data.get('saved', False)
else:
# 如果是直接传入属性值
self.user_id = kwargs.get('user_id')
self.gender = kwargs.get('gender')
self.age = kwargs.get('age')
self.nickname = kwargs.get('nickname')
self.relationship_value = kwargs.get('relationship_value', 0.0)
self.saved = kwargs.get('saved', False)
def __init__(self, chat: ChatStream = None, data: dict = None):
self.user_id = chat.user_info.user_id if chat else data.get('user_id', 0)
self.platform = chat.platform if chat else data.get('platform', '')
self.nickname = chat.user_info.user_nickname if chat else data.get('nickname', '')
self.relationship_value = data.get('relationship_value', 0) if data else 0
self.age = data.get('age', 0) if data else 0
self.gender = data.get('gender', '') if data else ''
class RelationshipManager:
def __init__(self):
self.relationships: dict[int, Relationship] = {}
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self, user_id: int, data=None, **kwargs):
async def update_relationship(self,
chat_stream:ChatStream,
data: dict = None,
**kwargs) -> Optional[Relationship]:
"""更新或创建关系
Args:
chat_stream: 聊天流对象
data: 字典格式的数据(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if chat_stream.user_info is not None:
user_id = chat_stream.user_info.user_id
platform = chat_stream.user_info.platform or 'qq'
else:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
# 检查是否在内存中已存在
relationship = self.relationships.get(user_id)
relationship = self.relationships.get(key)
if relationship:
# 如果存在,更新现有对象
if isinstance(data, dict):
for key, value in data.items():
if hasattr(relationship, key) and value is not None:
setattr(relationship, key, value)
else:
for key, value in kwargs.items():
if hasattr(relationship, key) and value is not None:
setattr(relationship, key, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
# 如果不存在,创建新对象
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs)
self.relationships[user_id] = relationship
# 更新 id_name_nickname_table
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
if chat_stream.user_info is not None:
relationship = Relationship(chat=chat_stream, **kwargs)
else:
raise ValueError("必须提供user_id或user_info")
self.relationships[key] = relationship
# 保存到数据库
await self.storage_relationship(relationship)
relationship.saved = True
return relationship
async def update_relationship_value(self, user_id: int, **kwargs):
async def update_relationship_value(self,
chat_stream:ChatStream,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
# 检查是否在内存中已存在
relationship = self.relationships.get(user_id)
relationship = self.relationships.get(key)
if relationship:
for key, value in kwargs.items():
if key == 'relationship_value':
for k, value in kwargs.items():
if k == 'relationship_value':
relationship.relationship_value += value
await self.storage_relationship(relationship)
relationship.saved = True
return relationship
else:
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
return None
def get_relationship(self, user_id: int) -> Optional[Relationship]:
"""获取用户关系对象"""
if user_id in self.relationships:
return self.relationships[user_id]
def get_relationship(self,
chat_stream:ChatStream) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
else:
return None
async def load_relationship(self, data: dict) -> Relationship:
"""从数据库加载或创建新的关系对象"""
rela = Relationship(user_id=data['user_id'], data=data)
"""从数据库加载或创建新的关系对象"""
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True
self.relationships[rela.user_id] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
@@ -117,39 +179,37 @@ class RelationshipManager:
all_relationships = db.db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
user_id = data['user_id']
relationship = await self.load_relationship(data)
self.relationships[user_id] = relationship
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
await self.load_relationship(data)
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
while True:
print("\033[1;32m[关系管理]\033[0m 正在自动保存关系")
logger.debug("正在自动保存关系")
await asyncio.sleep(300) # 等待300秒(5分钟)
await self._save_all_relationships()
async def _save_all_relationships(self):
"""将所有关系数据保存到数据库"""
"""将所有关系数据保存到数据库"""
# 保存所有关系数据
for userid, relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items():
if not relationship.saved:
relationship.saved = True
await self.storage_relationship(relationship)
async def storage_relationship(self,relationship: Relationship):
"""
将关系记录存储到数据库中
"""
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
user_id = relationship.user_id
platform = relationship.platform
nickname = relationship.nickname
relationship_value = relationship.relationship_value
gender = relationship.gender
age = relationship.age
saved = relationship.saved
db = Database.get_instance()
db.db.relationships.update_one(
{'user_id': user_id},
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,
'nickname': nickname,
'relationship_value': relationship_value,
'gender': gender,
@@ -159,14 +219,38 @@ class RelationshipManager:
upsert=True
)
def get_name(self, user_id: int) -> str:
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 确保user_id是整数类型
user_id = int(user_id)
if user_id in self.relationships:
return self.relationships[user_id].nickname
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
else:
return "某人"
relationship_manager = RelationshipManager()
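A quick illustration (not repository code, values hypothetical) of the new composite key: the manager keeps one Relationship per (user_id, platform) tuple, so the same numeric id on different platforms no longer collides.
manager = RelationshipManager()
rel_qq = Relationship(data={'user_id': 123456, 'platform': 'qq', 'nickname': '小明', 'relationship_value': 5})
rel_tg = Relationship(data={'user_id': 123456, 'platform': 'telegram', 'nickname': 'Ming'})
manager.relationships[(123456, 'qq')] = rel_qq
manager.relationships[(123456, 'telegram')] = rel_tg
assert manager.relationships[(123456, 'qq')].nickname == '小明'
assert manager.relationships[(123456, 'telegram')].nickname == 'Ming'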

View File

@@ -1,49 +1,30 @@
from typing import Optional
from typing import Optional, Union
from ...common.database import Database
from .message import Message
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
if not message.is_emoji:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
message_data = {
"message_id": message.message_info.message_id,
"time": message.message_info.time,
"chat_id":chat_stream.stream_id,
"chat_info": chat_stream.to_dict(),
"user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text,
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
else:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
"processed_plain_text": '[表情包]',
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
self.db.db.messages.insert_one(message_data)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
except Exception:
logger.exception("存储消息失败")
# 如果需要其他存储相关的函数,可以在这里添加
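For orientation, the document that store_message now writes to db.messages looks roughly like this (field names follow the code above; the values and the nested dict contents are hypothetical and abbreviated):
example_message_doc = {
    "message_id": "1741708854123",
    "time": 1741708854.0,
    "chat_id": "qq_g766798517",                          # chat_stream.stream_id
    "chat_info": {"stream_id": "qq_g766798517"},         # chat_stream.to_dict(), truncated here
    "user_info": {"user_id": 123456, "platform": "qq"},  # UserInfo.to_dict(), truncated here
    "processed_plain_text": "你好",
    "detailed_plain_text": "[03-12 00:20:54] 小明: 你好",
    "topic": None,
}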

View File

@@ -4,9 +4,11 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request
from .config import global_config
from loguru import logger
driver = get_driver()
config = driver.config
class TopicIdentifier:
def __init__(self):
@@ -23,19 +25,20 @@ class TopicIdentifier:
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_topic_judge.generate_response(prompt)
if not topic:
print("\033[1;31m[错误]\033[0m LLM API 返回为空")
logger.error("LLM API 返回为空")
return None
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
logger.info(f"主题: {topic_list}")
return topic_list if topic_list else None
topic_identifier = TopicIdentifier()
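A tiny standalone sketch (not repository code) of the parsing step above: the LLM's comma-separated reply is split, stripped, and empty entries are dropped before the list is returned.
raw_reply = "天气, 考试 ,,  游戏"
topic_list = [t.strip() for t in raw_reply.split(",") if t.strip()]
assert topic_list == ["天气", "考试", "游戏"]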

View File

@@ -7,65 +7,44 @@ from typing import Dict, List
import jieba
import numpy as np
from nonebot import get_driver
from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message import Message
from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
"""将消息列表组合成格式化的字符串
Args:
messages: Message对象列表
Returns:
str: 格式化后的消息字符串
"""
result = ""
for message in messages:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n"
return result
def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}")
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = "[(%s)%s]%s" % (
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
except:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
print(f"result: {result}")
logger.debug(f"result: {result}")
return result
def is_mentioned_bot_in_message(message: Message) -> bool:
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
for keyword in keywords:
if keyword in message.processed_plain_text:
return True
return False
def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
if keyword in message:
for nickname in nicknames:
if nickname in message.processed_plain_text:
return True
return False
@@ -98,40 +77,45 @@ def calculate_information_content(text):
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
"""从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
Returns:
list: 消息记录列表,每个记录包含时间和文本信息
"""
chat_records = []
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_id = closest_record['chat_id'] # 获取chat_id
# 获取该时间戳之后的length条消息保持相同的chat_id
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
{
"time": {"$gt": closest_time},
"chat_id": chat_id # 添加chat_id过滤
}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
# 转换记录格式
formatted_records = []
for record in chat_records:
# 检查当前记录的memorized值
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
# print(f"消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
# print(f"消息已读取3次跳过")
return ''
formatted_records.append({
'time': record["time"],
'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
})
return formatted_records
return []
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@@ -145,38 +129,31 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
{"chat_id": chat_id},
).sort("time", -1).limit(limit))
if not recent_messages:
return []
# 转换为 Message对象列表
from .message import Message
message_objects = []
for msg_data in recent_messages:
try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
detailed_plain_text=msg_data.get("detailed_plain_text", "")
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
logger.warning("数据库中存在无效的消息")
continue
# 按时间正序排列
@@ -184,13 +161,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(db, chat_stream_id: str, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"chat_id":1,
"chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}
@@ -292,11 +270,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentence = sentence.replace('，', ' ').replace(',', ' ')
sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}")
logger.info(f"处理后的句子: {sentences_done}")
return sentences_done
def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯
@@ -324,11 +301,10 @@ def random_remove_punctuation(text: str) -> str:
return result
def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content)
if len(text) > 200:
print(f"回复过长 ({len(text)} 字符),返回默认回复")
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说']
# 处理长消息
typo_generator = ChineseTypoGenerator(
@@ -348,9 +324,9 @@ def process_llm_response(text: str) -> List[str]:
else:
sentences.append(sentence)
# 检查分割后的消息数量是否过多超过5条
if len(sentences) > 5:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences
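A worked numeric example (illustrative only) of the mood-based typing-speed factor used in calculate_typing_time below: typing_speed_multiplier = 1.5 ** mood_arousal, and the per-character time is divided by it.
chinese_time = 0.4
for mood_arousal in (-1.0, 0.0, 1.0):
    typing_speed_multiplier = 1.5 ** mood_arousal
    per_char = chinese_time / typing_speed_multiplier
    print(f"arousal={mood_arousal:+.1f} -> {per_char:.3f}s per Chinese character")
# arousal=-1.0 -> 0.600s, arousal=+0.0 -> 0.400s, arousal=+1.0 -> 0.267s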
@@ -372,15 +348,15 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_
mood_arousal = mood_manager.current_mood.arousal
# 映射到0.5到2倍的速度系数
typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度约为1.5倍,为-1时约为0.67倍
chinese_time *= 1/typing_speed_multiplier
english_time *= 1/typing_speed_multiplier
chinese_time *= 1 / typing_speed_multiplier
english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff')
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
return chinese_time * 3 + 0.3 # 加上回车时间
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:

View File

@@ -1,296 +1,353 @@
import base64
import io
import os
import time
import zlib # 用于 CRC32
import aiohttp
import hashlib
from typing import Optional, Union
from loguru import logger
from nonebot import get_driver
from PIL import Image
from ...common.database import Database
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args:
base64_data: base64编码的图片数据
max_size: 最大文件大小KB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保图片目录存在
images_dir = "data/images"
os.makedirs(images_dir, exist_ok=True)
# 连接数据库
db = Database(
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片
collection = db.db['images']
existing_image = collection.find_one({'hash': hash_value})
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
# 如果已经小于目标大小,直接使用原图
if current_size <= max_size:
compressed_data = image_data
else:
# 压缩逻辑
# 先缩放到50%
new_width = int(img.width * 0.5)
new_height = int(img.height * 0.5)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 如果缩放后的最大边长仍然大于400继续缩放
max_dimension = 400
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# 使用固定质量参数压缩
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(images_dir, filename)
# 保存文件
with open(image_path, "wb") as f:
f.write(compressed_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
try:
# 准备数据库记录
image_record = {
'filename': filename,
'path': image_path,
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
}
# 保存记录
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes:
"""
存储表情包到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
if not global_config.EMOJI_SAVE:
return image_data
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(emoji_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
emoji_path = os.path.join(emoji_dir, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
image_dir = "data/image"
os.makedirs(image_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(image_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
Args:
image_hash: 图片哈希值
description_type: 描述类型 ('emoji' 或 'image')
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji' 或 'image')
"""
self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
)
async def save_image(self,
image_data: Union[str, bytes],
url: str = None,
description: str = None,
is_base64: bool = False) -> Optional[str]:
"""保存图像
Args:
image_data: 图像数据(base64字符串或字节)
url: 图像URL
description: 图像描述
is_base64: image_data是否为base64格式
Returns:
str: 保存后的文件路径,失败返回None
"""
try:
# 转换为字节格式
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return None
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return None
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'url': url,
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
return file_path
except Exception as e:
logger.error(f"保存图像失败: {str(e)}")
return None
async def get_image_by_url(self, url: str) -> Optional[str]:
"""根据URL获取图像路径(带查重)
Args:
url: 图像URL
Returns:
str: 本地文件路径,不存在返回None
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
if existing:
return existing['path']
# 下载图像
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
image_bytes = await resp.read()
return await self.save_image(image_bytes, url=url)
return None
except Exception as e:
logger.error(f"获取图像失败: {str(e)}")
return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
return base64.b64encode(compressed_data).decode('utf-8')
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description:
logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'emoji',
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库
self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'image',
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库
self._save_description_to_db(image_hash, description, 'image')
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
# 创建全局单例
image_manager = ImageManager()
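A condensed, standalone sketch (not repository code) of the caching pattern that get_emoji_description / get_image_description implement above: the md5 of the decoded image is the cache key, so the vision model is only queried once per distinct image.
import base64
import hashlib

description_cache = {}  # stands in for the image_descriptions collection

def describe_once(image_base64: str, describe) -> str:
    image_bytes = base64.b64decode(image_base64)
    image_hash = hashlib.md5(image_bytes).hexdigest()
    if image_hash in description_cache:          # cache hit: no model call
        return description_cache[image_hash]
    description = describe(image_bytes)          # stand-in for the VLM request
    description_cache[image_hash] = description
    return description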
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码

View File

@@ -1,10 +1,15 @@
import asyncio
from typing import Dict
from .config import global_config
from .chat_stream import ChatStream
class WillingManager:
def __init__(self):
self.group_reply_willing = {} # 存储每个的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None
self._started = False
@@ -12,20 +17,38 @@ class WillingManager:
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(5)
for group_id in self.group_reply_willing:
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self, group_id: int) -> float:
"""获取指定群组的回复意愿"""
return self.group_reply_willing.get(group_id, 0)
def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定聊天流的回复意愿"""
stream = chat_stream
if stream:
return self.chat_reply_willing.get(stream.stream_id, 0)
return 0
def set_willing(self, group_id: int, willing: float):
"""设置指定群组的回复意愿"""
self.group_reply_willing[group_id] = willing
def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
"""改变指定群组的回复意愿并返回回复概率"""
current_willing = self.group_reply_willing.get(group_id, 0)
async def change_reply_willing_received(self,
chat_stream:ChatStream,
topic: str = None,
is_mentioned_bot: bool = False,
config = None,
is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流
stream = chat_stream
chat_id = stream.stream_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
# print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0:
@@ -49,31 +72,33 @@ class WillingManager:
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
reply_probability = max((current_willing - 0.45) * 2, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
# 检查群组权限(如果是群聊)
if chat_stream.group_info:
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
reply_probability = min(reply_probability, 1)
if reply_probability < 0:
reply_probability = 0
self.group_reply_willing[group_id] = min(current_willing, 3.0)
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
return reply_probability
def change_reply_willing_sent(self, group_id: int):
"""开始思考后降低群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
self.group_reply_willing[group_id] = max(0, current_willing - 2)
def change_reply_willing_sent(self, chat_stream:ChatStream):
"""开始思考后降低聊天流的回复意愿"""
stream = chat_stream
if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int):
"""发送消息后提高群组的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0)
if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
def change_reply_willing_after_sent(self,chat_stream:ChatStream):
"""发送消息后提高聊天流的回复意愿"""
stream = chat_stream
if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
if current_willing < 1:
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""
@@ -83,4 +108,4 @@ class WillingManager:
self._started = True
# 创建全局实例
willing_manager = WillingManager()
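A worked numeric example (illustrative only) of the willingness-to-probability mapping and the 5-second decay loop defined above:
def reply_probability(willing: float) -> float:
    return min(max((willing - 0.45) * 2, 0), 1)

willing = 1.0
print(reply_probability(willing))          # 1.0 -> probability capped at 1
for _ in range(3):                         # three decay ticks, i.e. about 15 seconds idle
    willing = max(0, willing * 0.6)
print(round(willing, 3), reply_probability(willing))  # 0.216 -> probability 0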

View File

@@ -0,0 +1,10 @@
from nonebot import get_app
from .api import router
from loguru import logger
# 获取主应用实例并挂载路由
app = get_app()
app.include_router(router, prefix="/api")
# 打印日志方便确认API已注册
logger.success("配置重载API已注册可通过 /api/reload-config 访问")

View File

@@ -0,0 +1,17 @@
from fastapi import APIRouter, HTTPException
from src.plugins.chat.config import BotConfig
import os
# 创建APIRouter而不是FastAPI实例
router = APIRouter()
@router.post("/reload-config")
async def reload_config():
try:
bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
global_config = BotConfig.load_config(config_path=bot_config_path)
return {"message": "配置重载成功", "status": "success"}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}")

View File

@@ -0,0 +1,3 @@
import requests
response = requests.post("http://localhost:8080/api/reload-config")
print(response.json())

View File

@@ -1,198 +0,0 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 加载根目录下的.env.dev文件
env_path = os.path.join(root_path, ".env.dev")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
)
class KnowledgeLibrary:
def __init__(self):
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self):
"""处理raw_info目录下的所有txt文件"""
for filename in os.listdir(self.raw_info_dir):
if filename.endswith('.txt'):
file_path = os.path.join(self.raw_info_dir, filename)
self.process_single_file(file_path)
def process_single_file(self, file_path: str):
"""处理单个文件"""
try:
# 检查文件是否已处理
if self.db.db.processed_files.find_one({"file_path": file_path}):
print(f"文件已处理过,跳过: {file_path}")
return
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 按600字符分段
segments = [content[i:i+600] for i in range(0, len(content), 600)]
# 处理每个分段
for segment in segments:
if not segment.strip(): # 跳过空段
continue
# 获取embedding
embedding = self.get_embedding(segment)
if not embedding:
continue
# 存储到数据库
doc = {
"content": segment,
"embedding": embedding,
"file_path": file_path,
"segment_length": len(segment)
}
# 使用文本内容的哈希值作为唯一标识
content_hash = hash(segment)
# 更新或插入文档
self.db.db.knowledges.update_one(
{"content_hash": content_hash},
{"$set": doc},
upsert=True
)
# 记录文件已处理
self.db.db.processed_files.insert_one({
"file_path": file_path,
"processed_time": time.time()
})
print(f"成功处理文件: {file_path}")
except Exception as e:
print(f"处理文件 {file_path} 时出错: {str(e)}")
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
# 测试知识库功能
print("开始处理知识库文件...")
knowledge_library.process_files()
# 测试搜索功能
test_query = "麦麦评价一下僕と花"
print(f"\n搜索与'{test_query}'相似的内容:")
results = knowledge_library.search_similar_segments(test_query)
for result in results:
print(f"相似度: {result['similarity']:.4f}")
print(f"内容: {result['content'][:100]}...")
print("-" * 50)

View File

View File

@@ -7,23 +7,27 @@ import jieba
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
@@ -37,7 +41,7 @@ class Memory_graph:
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
@@ -45,20 +49,20 @@ class Memory_graph:
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
@@ -69,7 +73,7 @@ class Memory_graph:
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
@@ -84,42 +88,44 @@ class Memory_graph:
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except:
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
displayname = record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self):
@@ -159,145 +165,86 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph()
memory_graph.load_graph_from_db()
# 只显示一次优化后的图形
visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
logger.debug("第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
logger.debug(item)
logger.debug("第二层记忆:")
for item in second_layer_items:
print(item)
logger.debug(item)
else:
print("未找到相关记忆。")
logger.debug("未找到相关记忆。")
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 保存图到本地
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(G.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = G.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
else:
color = (0, 0, 1) # 如果没有记忆,则为蓝色
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
for node in nodes:
degree = G.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, k=1, iterations=50)
nx.draw(G, pos,
with_labels=True,
node_color=node_colors,
node_size=200,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count < 5 or degree < 2: # 改为小于2而不是小于等于2
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
logger.debug("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 计算节点大小和颜色
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数和最大度数用于归一化
max_memories = 1
max_degree = 1
@@ -307,7 +254,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
@@ -315,37 +262,38 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异明显
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
# 红色分量随着度数增加而增加
red = min(1.0, degree / max_degree)
r = (degree / max_degree) ** 0.3
red = min(1.0, r)
# 蓝色分量随着度数减少而增加
blue = 1.0 - red
color = (red, 0, blue)
blue = max(0.0, 1 - red)
# blue = 1
color = (red, 0.1, blue)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.7)
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.9)
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()
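A small self-contained sketch (not repository code) of the two-layer lookup behind Memory_graph.get_related_item: depth 1 returns the topic's own memory_items, depth 2 also collects those of its neighbours.
import networkx as nx

G = nx.Graph()
G.add_node('天气', memory_items=['昨天下雨了'])
G.add_node('出行', memory_items=['有人打算周末爬山'])
G.add_edge('天气', '出行')

def related(topic, depth=1):
    first = list(G.nodes[topic].get('memory_items', []))
    second = []
    if depth >= 2:
        for neighbor in G.neighbors(topic):
            second.extend(G.nodes[neighbor].get('memory_items', []))
    return first, second

print(related('天气', depth=2))  # (['昨天下雨了'], ['有人打算周末爬山'])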

File diff suppressed because it is too large Load Diff

View File

@@ -10,12 +10,15 @@ from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
import pymongo
from dotenv import load_dotenv
from loguru import logger
import jieba
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database
from src.plugins.memory_system.offline_llm import LLMModel
@@ -34,45 +37,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@@ -86,23 +50,26 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
list: 消息记录字典列表,每个字典包含消息内容和时间信息
"""
chat_records = []
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
group_id = closest_record['group_id']
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in chat_records:
# 检查当前记录的memorized值
for record in records:
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
print("消息已读取3次跳过")
return ''
@@ -112,11 +79,14 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
# 添加到记录列表中
chat_records.append({
'text': record["detailed_plain_text"],
'time': record["time"],
'group_id': record["group_id"]
})
return chat_text
print("消息已读取3次跳过")
return ''
return chat_records
class Memory_graph:
def __init__(self):
@@ -195,7 +165,7 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
@@ -205,22 +175,34 @@ class Hippocampus:
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}):
"""获取记忆样本
Returns:
list: 消息记录列表,每个元素是一个消息记录字典列表
"""
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 循环10次
random_time = current_timestamp - random.randint(1, 3600*4) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 循环10次
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 循环10次
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return [chat for chat in chat_text if chat]
chat_samples = []
# 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
return chat_samples
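A compact sketch (illustrative only) of how the time_frequency dict above translates into sampling windows: 'near' draws timestamps from the last 4 hours, 'mid' from 4-24 hours ago, 'far' from 1-7 days ago, with the dict values giving the number of draws per window.
import random
import datetime

time_frequency = {'near': 3, 'mid': 8, 'far': 5}
windows = {'near': (1, 3600 * 4), 'mid': (3600 * 4, 3600 * 24), 'far': (3600 * 24, 3600 * 24 * 7)}
now = datetime.datetime.now().timestamp()
sample_times = [
    now - random.randint(*windows[kind])
    for kind, count in time_frequency.items()
    for _ in range(count)
]
print(len(sample_times))  # 16 candidate timestamps to pull chat slices around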
def calculate_topic_num(self,text, compress_rate):
"""计算文本的话题数量"""
@@ -231,16 +213,49 @@ class Hippocampus:
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
return topic_num
async def memory_compress(self, input_text, compress_rate=0.1):
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩消息记录为记忆
Args:
messages: 消息记录字典列表每个字典包含text和time字段
compress_rate: 压缩率
Returns:
set: (话题, 记忆) 元组集合
"""
if not messages:
return set()
# 合并消息文本,同时保留时间信息
input_text = ""
time_info = ""
# 计算最早和最晚时间
earliest_time = min(msg['time'] for msg in messages)
latest_time = max(msg['time'] for msg in messages)
earliest_dt = datetime.datetime.fromtimestamp(earliest_time)
latest_dt = datetime.datetime.fromtimestamp(latest_time)
# 如果是同一年
if earliest_dt.year == latest_dt.year:
earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%m-%d %H:%M:%S")
time_info += f"是在{earliest_dt.year}年,{earliest_str}{latest_str} 的对话:\n"
else:
earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S")
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages:
input_text += f"{msg['text']}\n"
print(input_text)
topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
# 修改话题处理逻辑
# 定义需要过滤的关键词
filter_keywords = ['表情包', '图片', '回复', '聊天记录']
# 过滤topics
topics = [topic.strip() for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
@@ -250,7 +265,7 @@ class Hippocampus:
# 创建所有话题的请求任务
tasks = []
for topic in filtered_topics:
topic_what_prompt = self.topic_what(input_text, topic)
topic_what_prompt = self.topic_what(input_text, topic , time_info)
# 创建异步任务
task = self.llm_model_small.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task))
@@ -267,37 +282,35 @@ class Hippocampus:
async def operation_build_memory(self, chat_size=12):
# 最近消息获取频率
time_frequency = {'near': 3, 'mid': 8, 'far': 5}
memory_sample = self.get_memory_sample(chat_size, time_frequency)
memory_samples = self.get_memory_sample(chat_size, time_frequency)
all_topics = [] # 用于存储所有话题
for i, input_text in enumerate(memory_sample, 1):
for i, messages in enumerate(memory_samples, 1):
# 加载进度可视化
all_topics = []
progress = (i / len(memory_sample)) * 100
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_sample))
filled_length = int(bar_length * i // len(memory_samples))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set()
# 生成压缩后记忆
compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate)
compressed_memory = await self.memory_compress(messages, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
# 将记忆加入到图谱中
for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}")
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题
all_topics.append(topic)
# 连接相关话题
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db()
@@ -375,7 +388,7 @@ class Hippocampus:
if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加
logger.info(f"添加新节点: {concept}")
# logger.info(f"添加新节点: {concept}")
node_data = {
'concept': concept,
'memory_items': memory_items,
@@ -389,7 +402,7 @@ class Hippocampus:
# 如果特征值不同,则更新节点
if db_hash != memory_hash:
logger.info(f"更新节点内容: {concept}")
# logger.info(f"更新节点内容: {concept}")
self.memory_graph.db.db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
@@ -402,7 +415,7 @@ class Hippocampus:
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
logger.info(f"删除多余节点: {db_node['concept']}")
# logger.info(f"删除多余节点: {db_node['concept']}")
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息
@@ -460,9 +473,10 @@ class Hippocampus:
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt
def topic_what(self,text, topic):
def topic_what(self,text, topic, time_info):
# prompt = f'这是一段文字:{text}。我想知道这段文字里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
# 在提示中加入这段对话发生的时间范围信息
prompt = f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt
def remove_node_from_db(self, topic):
@@ -597,7 +611,7 @@ class Hippocampus:
print(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(merged_text, 0.1)
compressed_memories = await self.memory_compress(selected_memories, 0.1)
# 从原记忆列表中移除被选中的记忆
for memory in selected_memories:
@@ -647,6 +661,164 @@ class Hippocampus:
else:
print("\n本次检查没有需要合并的节点")
async def _identify_topics(self, text: str) -> list:
"""从文本中识别可能的主题"""
topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(text, 5))
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
return topics
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
"""查找与给定主题相似的记忆主题"""
all_memory_topics = list(self.memory_graph.G.nodes())
all_similar_topics = []
for topic in topics:
if debug_info:
pass
topic_vector = text_to_vector(topic)
has_similar_topic = False
for memory_topic in all_memory_topics:
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= similarity_threshold:
has_similar_topic = True
all_similar_topics.append((memory_topic, similarity))
return all_similar_topics
def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list:
"""获取相似度最高的主题"""
seen_topics = set()
top_topics = []
for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True):
if topic not in seen_topics and len(top_topics) < max_topics:
seen_topics.add(topic)
top_topics.append((topic, score))
return top_topics
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度"""
logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
identified_topics = await self._identify_topics(text)
if not identified_topics:
return 0
all_similar_topics = self._find_similar_topics(
identified_topics,
similarity_threshold=similarity_threshold,
debug_info="记忆激活"
)
if not all_similar_topics:
return 0
top_topics = self._get_top_topics(all_similar_topics, max_topics)
if len(top_topics) == 1:
topic, score = top_topics[0]
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
activation = int(score * 50 * penalty)
print(f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
return activation
matched_topics = set()
topic_similarities = {}
for memory_topic, similarity in top_topics:
memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1))
for input_topic in identified_topics:
topic_vector = text_to_vector(input_topic)
memory_vector = text_to_vector(memory_topic)
all_words = set(topic_vector.keys()) | set(memory_vector.keys())
v1 = [topic_vector.get(word, 0) for word in all_words]
v2 = [memory_vector.get(word, 0) for word in all_words]
sim = cosine_similarity(v1, v2)
if sim >= similarity_threshold:
matched_topics.add(input_topic)
adjusted_sim = sim * penalty
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
print(f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
topic_match = len(matched_topics) / len(identified_topics)
average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0
activation = int((topic_match + average_similarities) / 2 * 100)
print(f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
return activation
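# 数值示意(取值为假设,仅帮助理解上面的激活值公式):
# penalty = 1 / (1 + ln(内容数 + 1));多主题时 activation = (匹配率 + 平均相似度) / 2 * 100。
import math

def demo_penalty(content_count=4):
    return 1.0 / (1 + math.log(content_count + 1))  # 内容越多,惩罚越重,激活越低

def demo_activation(matched=2, total=3, avg_adjusted_sim=0.5):
    return int((matched / total + avg_adjusted_sim) / 2 * 100)  # 2/3 与 0.5 时约为 58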
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5) -> list:
"""根据输入文本获取相关的记忆内容"""
identified_topics = await self._identify_topics(text)
all_similar_topics = self._find_similar_topics(
identified_topics,
similarity_threshold=similarity_threshold,
debug_info="记忆检索"
)
relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
relevant_memories = []
for topic, score in relevant_topics:
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer:
if len(first_layer) > max_memory_num/2:
first_layer = random.sample(first_layer, max_memory_num//2)
for memory in first_layer:
relevant_memories.append({
'topic': topic,
'similarity': score,
'content': memory
})
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
if len(relevant_memories) > max_memory_num:
relevant_memories = random.sample(relevant_memories, max_memory_num)
return relevant_memories
def segment_text(text):
"""使用jieba进行文本分词"""
seg_text = list(jieba.cut(text))
return seg_text
def text_to_vector(text):
"""将文本转换为词频向量"""
words = segment_text(text)
vector = {}
for word in words:
vector[word] = vector.get(word, 0) + 1
return vector
def cosine_similarity(v1, v2):
"""计算两个向量的余弦相似度"""
dot_product = sum(a * b for a, b in zip(v1, v2))
norm1 = math.sqrt(sum(a * a for a in v1))
norm2 = math.sqrt(sum(b * b for b in v2))
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
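# 用法示意(假设性示例,直接复用上面定义的 text_to_vector 与 cosine_similarity):
# 两段文本先转为词频向量,再在并集词表上对齐后求余弦相似度,结果越接近 1 越相似。
def demo_text_similarity():
    va = text_to_vector("今天天气不错")
    vb = text_to_vector("今天的天气很好")
    words = set(va) | set(vb)
    v1 = [va.get(w, 0) for w in words]
    v2 = [vb.get(w, 0) for w in words]
    return cosine_similarity(v1, v2)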
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
@@ -732,59 +904,67 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare['do_build_memory']:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare['do_query']:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
@@ -799,9 +979,6 @@ async def main():
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

File diff suppressed because it is too large

View File

@@ -7,10 +7,11 @@ from typing import Tuple, Union
import aiohttp
from loguru import logger
from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver()
config = driver.config
@@ -28,10 +29,10 @@ class LLM_request:
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"]
self.params = kwargs
self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0)
# 获取数据库实例
self.db = Database.get_instance()
self._init_database()
@@ -44,12 +45,12 @@ class LLM_request:
self.db.db.llm_usage.create_index([("model_name", 1)])
self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败: {e}")
except Exception:
logger.error("创建数据库索引失败")
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat",
endpoint: str = "/chat/completions"):
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat",
endpoint: str = "/chat/completions"):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
@@ -79,8 +80,8 @@ class LLM_request:
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {e}")
except Exception:
logger.error("记录token使用情况失败")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
@@ -140,12 +141,12 @@ class LLM_request:
}
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
#判断是否为流式
# 判断是否为流式
stream_mode = self.params.get("stream", False)
if self.params.get("stream", False) is True:
logger.info(f"进入流式输出模式发送请求到URL: {api_url}")
logger.debug(f"进入流式输出模式发送请求到URL: {api_url}")
else:
logger.info(f"发送请求到URL: {api_url}")
logger.debug(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}")
# 构建请求体
@@ -158,7 +159,7 @@ class LLM_request:
try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
#似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode:
headers["Accept"] = "text/event-stream"
@@ -182,19 +183,38 @@ class LLM_request:
continue
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
if response.status == 403 :
if global_config.llm_normal == "Pro/deepseek-ai/DeepSeek-V3":
logger.error("可能是没有给硅基流动充钱普通模型自动退化至非Pro模型反应速度可能会变慢")
global_config.llm_normal = "deepseek-ai/DeepSeek-V3"
if global_config.llm_reasoning == "Pro/deepseek-ai/DeepSeek-R1":
logger.error("可能是没有给硅基流动充钱推理模型自动退化至非Pro模型反应速度可能会变慢")
global_config.llm_reasoning = "deepseek-ai/DeepSeek-R1"
if response.status == 403:
#只针对硅基流动的V3和R1进行降级处理
if self.model_name.startswith(
"Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get('name') == old_model_name:
global_config.llm_normal['name'] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get('name') == old_model_name:
global_config.llm_reasoning['name'] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
# 更新payload中的模型名
if payload and 'model' in payload:
payload['model'] = self.model_name
# 重新尝试请求
retry -= 1 # 不计入重试次数
continue
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status()
#将流式输出转化为非流式输出
# 将流式输出转化为非流式输出
if stream_mode:
flag_delta_content_finished = False
accumulated_content = ""
async for line_bytes in response.content:
line = line_bytes.decode("utf-8").strip()
@@ -206,13 +226,25 @@ class LLM_request:
break
try:
chunk = json.loads(data_str)
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
except Exception as e:
logger.error(f"解析流式输出错误: {e}")
if flag_delta_content_finished:
usage = chunk.get("usage", None) # 获取token用量
else:
delta = chunk["choices"][0]["delta"]
delta_content = delta.get("content")
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
# 检测流式输出文本是否结束
finish_reason = chunk["choices"][0]["finish_reason"]
if finish_reason == "stop":
usage = chunk.get("usage", None)
if usage:
break
# 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
flag_delta_content_finished = True
except Exception:
logger.exception("解析流式输出错误")
content = accumulated_content
reasoning_content = ""
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
@@ -220,12 +252,15 @@ class LLM_request:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
result = {
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}], "usage": usage}
return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
else:
result = await response.json()
# 使用自定义处理器或默认处理
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
except Exception as e:
if retry < policy["max_retries"] - 1:
@@ -239,8 +274,8 @@ class LLM_request:
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数,API请求仍然失败")
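# 流式响应解析示意(独立的假设性示例,不依赖上面的类):SSE 每行形如 "data: {json}",
# 结束标记为 "data: [DONE]",逐块累加 choices[0].delta.content 即可还原完整文本。
import json

def accumulate_sse(lines):
    content = ""
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        content += chunk["choices"][0]["delta"].get("content") or ""
    return content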
async def _transform_parameters(self, params: dict) ->dict:
async def _transform_parameters(self, params: dict) -> dict:
"""
根据模型名称转换参数:
- 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
@@ -249,7 +284,8 @@ class LLM_request:
# 复制一份参数,避免直接修改原始数据
new_params = dict(params)
# 定义需要转换的模型列表
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"]
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"]
if self.model_name.lower() in models_needing_transformation:
# 删除 'temperature' 参数(如果存在)
new_params.pop("temperature", None)
@@ -285,13 +321,13 @@ class LLM_request:
**params_copy
}
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
def _default_response_handler(self, result: dict, user_id: str = "system",
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
def _default_response_handler(self, result: dict, user_id: str = "system",
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
"""默认响应解析"""
if "choices" in result and result["choices"]:
message = result["choices"][0]["message"]
@@ -336,15 +372,15 @@ class LLM_request:
"""构建请求头"""
if no_key:
return {
"Authorization": f"Bearer **********",
"Authorization": "Bearer **********",
"Content-Type": "application/json"
}
else:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 防止小朋友们截图自己的key
}
# 防止小朋友们截图自己的key
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的异步响应"""
@@ -391,6 +427,7 @@ class LLM_request:
Returns:
list: embedding向量如果失败则返回None
"""
def embedding_handler(result):
"""处理响应"""
if "data" in result and len(result["data"]) > 0:
@@ -413,3 +450,77 @@ class LLM_request:
)
return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小(字节),默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果图片已经小于2MB,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
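# 用法示意(假设本地存在 example.png,仅作演示,调用上面定义的 compress_base64_image_by_scale):
import base64

def demo_compress(path="example.png"):
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return compress_base64_image_by_scale(b64, target_size=int(0.8 * 1024 * 1024))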

View File

@@ -4,7 +4,7 @@ import time
from dataclasses import dataclass
from ..chat.config import global_config
from loguru import logger
@dataclass
class MoodState:
@@ -210,7 +210,7 @@ class MoodManager:
def print_mood_status(self) -> None:
"""打印当前情绪状态"""
print(f"\033[1;35m[情绪状态]\033[0m 愉悦度: {self.current_mood.valence:.2f}, "
logger.info(f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
f"唤醒度: {self.current_mood.arousal:.2f}, "
f"心情: {self.current_mood.text}")

View File

@@ -1,3 +1,4 @@
import os
import datetime
import json
from typing import Dict, Union
@@ -13,21 +14,21 @@ from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
Database.initialize(
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator:
def __init__(self):
#根据global_config.llm_normal这一字典配置指定模型
# 根据global_config.llm_normal这一字典配置指定模型
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
self.db = Database.get_instance()
self.today_schedule_text = ""
self.today_schedule = {}
@@ -35,39 +36,41 @@ class ScheduleGenerator:
self.tomorrow_schedule = {}
self.yesterday_schedule_text = ""
self.yesterday_schedule = {}
async def initialize(self):
today = datetime.datetime.now()
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today)
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,read_only=True)
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,
read_only=True)
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(
target_date=yesterday, read_only=True)
async def generate_daily_schedule(self, target_date: datetime.datetime = None, read_only: bool = False) -> Dict[
str, str]:
date_str = target_date.strftime("%Y-%m-%d")
weekday = target_date.strftime("%A")
schedule_text = ""
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
if existing_schedule:
print(f"{date_str}的日程已存在:")
logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"]
# print(self.schedule_text)
elif read_only == False:
print(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:"""+\
"""
elif not read_only:
logger.debug(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:""" + \
"""
1. 早上的学习和工作安排
2. 下午的活动和任务
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表仅返回内容不要返回注释时间采用24小时制格式为{"时间": "活动","时间": "活动",...}。"""
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用JSON格式返回日程表,仅返回内容,不要返回注释,不要添加任何markdown或代码块样式,时间采用24小时制,格式为{"时间": "活动","时间": "活动",...}。"""
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
@@ -76,36 +79,35 @@ class ScheduleGenerator:
schedule_text = "生成日程时出错了"
# print(self.schedule_text)
else:
print(f"{date_str}的日程不存在。")
logger.debug(f"{date_str}的日程不存在。")
schedule_text = "忘了"
return schedule_text,None
return schedule_text, None
schedule_form = self._parse_schedule(schedule_text)
return schedule_text,schedule_form
return schedule_text, schedule_form
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
"""解析日程文本,转换为时间和活动的字典"""
try:
try:
schedule_dict = json.loads(schedule_text)
return schedule_dict
except json.JSONDecodeError as e:
print(schedule_text)
print(f"解析日程失败: {str(e)}")
except json.JSONDecodeError:
logger.exception("解析日程失败: {}".format(schedule_text))
return False
def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间"""
return datetime.datetime.strptime(time_str, "%H:%M")
def get_current_task(self) -> str:
"""获取当前时间应该进行的任务"""
current_time = datetime.datetime.now().strftime("%H:%M")
# 找到最接近当前时间的任务
closest_time = None
min_diff = float('inf')
# 检查今天的日程
if not self.today_schedule:
return "摸鱼"
@@ -114,7 +116,7 @@ class ScheduleGenerator:
if closest_time is None or diff < min_diff:
closest_time = time_str
min_diff = diff
# 检查昨天的日程中的晚间任务
if self.yesterday_schedule:
for time_str in self.yesterday_schedule.keys():
@@ -125,17 +127,17 @@ class ScheduleGenerator:
closest_time = time_str
min_diff = diff
return closest_time, self.yesterday_schedule[closest_time]
if closest_time:
return closest_time, self.today_schedule[closest_time]
return "摸鱼"
def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差"""
if time1=="24:00":
time1="23:59"
if time2=="24:00":
time2="23:59"
if time1 == "24:00":
time1 = "23:59"
if time2 == "24:00":
time2 = "23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
@@ -146,17 +148,18 @@ class ScheduleGenerator:
diff -= 1440 # 减一天的分钟
# print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟")
return diff
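# 跨午夜分钟差示意(独立的假设性示例,按上面相同的思路把差值折回 ±720 分钟以内):
import datetime

def minute_diff(time1: str, time2: str) -> int:
    t1 = datetime.datetime.strptime(time1, "%H:%M")
    t2 = datetime.datetime.strptime(time2, "%H:%M")
    diff = int((t2 - t1).total_seconds() / 60)
    if diff < -720:       # 例如 23:50 -> 00:10,原始差为 -1420,加一天后为 20
        diff += 1440
    elif diff > 720:
        diff -= 1440
    return diff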
def print_schedule(self):
"""打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text):
print("今日日程有误,将在下次运行时重新生成")
logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
print("\n=== 今日日程安排 ===")
logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n")
logger.info(f"时间[{time_str}]: 活动[{activity}]")
logger.info("==================")
# def main():
# # 使用示例
@@ -165,7 +168,7 @@ class ScheduleGenerator:
# scheduler.print_schedule()
# print("\n当前任务")
# print(scheduler.get_current_task())
# print("昨天日程:")
# print(scheduler.yesterday_schedule)
# print("今天日程:")
@@ -174,6 +177,6 @@ class ScheduleGenerator:
# print(scheduler.tomorrow_schedule)
# if __name__ == "__main__":
# main()
# main()
bot_schedule = ScheduleGenerator()

View File

@@ -3,6 +3,7 @@ import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict
from loguru import logger
from ...common.database import Database
@@ -153,8 +154,8 @@ class LLMStatistics:
try:
all_stats = self._collect_all_statistics()
self._save_statistics(all_stats)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")
except Exception:
logger.exception("统计数据处理失败")
# 等待1分钟
for _ in range(60):

View File

@@ -0,0 +1,383 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import Database
# 加载根目录下的 .env.prod 文件
env_path = os.path.join(root_path, ".env.prod")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
# 初始化数据库连接
if Database._instance is None:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_KEY 环境变量未设置")
self.console = Console()
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def read_file(self, file_path: str) -> str:
"""读取文件内容"""
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def split_content(self, content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,保持段落完整性
Args:
content: 要分割的文本内容
max_length: 每个块的最大长度
Returns:
list: 分割后的文本块列表
"""
# 首先按段落分割
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
chunks = []
current_chunk = []
current_length = 0
for para in paragraphs:
para_length = len(para)
# 如果单个段落就超过最大长度
if para_length > max_length:
# 如果当前chunk不为空先保存
if current_chunk:
chunks.append('\n'.join(current_chunk))
current_chunk = []
current_length = 0
# 将长段落按句子分割
sentences = [s.strip() for s in para.replace('。', '\n').replace('!', '\n').replace('?', '\n').split('\n') if s.strip()]
temp_chunk = []
temp_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if sentence_length > max_length:
# 如果单个句子超长,强制按长度分割
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
temp_chunk = []
temp_length = 0
for i in range(0, len(sentence), max_length):
chunks.append(sentence[i:i + max_length])
elif temp_length + sentence_length + 1 <= max_length:
temp_chunk.append(sentence)
temp_length += sentence_length + 1
else:
chunks.append('\n'.join(temp_chunk))
temp_chunk = [sentence]
temp_length = sentence_length
if temp_chunk:
chunks.append('\n'.join(temp_chunk))
# 如果当前段落加上现有chunk不超过最大长度
elif current_length + para_length + 1 <= max_length:
current_chunk.append(para)
current_length += para_length + 1
else:
# 保存当前chunk并开始新的chunk
chunks.append('\n'.join(current_chunk))
current_chunk = [para]
current_length = para_length
# 添加最后一个chunk
if current_chunk:
chunks.append('\n'.join(current_chunk))
return chunks
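# 用法示意(假设 library 为已初始化的 KnowledgeLibrary 实例,仅作演示):
# 每个分块长度不超过 max_length,且尽量保持段落与句子完整。
def demo_split(library, max_length=128):
    text = ("这是第一段。" * 20) + "\n\n" + ("这是第二段。" * 20)
    chunks = library.split_content(text, max_length=max_length)
    return [len(c) for c in chunks]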
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()['data'][0]['embedding']
def process_files(self, knowledge_length:int=512):
"""处理raw_info目录下的所有txt文件"""
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
if not txt_files:
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
return
total_stats = {
"processed_files": 0,
"total_chunks": 0,
"failed_files": [],
"skipped_files": []
}
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
for filename in tqdm(txt_files, desc="处理文件进度"):
file_path = os.path.join(self.raw_info_dir, filename)
result = self.process_single_file(file_path, knowledge_length)
self._update_stats(total_stats, result, filename)
self._display_processing_results(total_stats)
def process_single_file(self, file_path: str, knowledge_length: int = 512):
"""处理单个文件"""
result = {
"status": "success",
"chunks_processed": 0,
"error": None
}
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
if knowledge_length in processed_record.get("split_by", []):
result["status"] = "skipped"
return result
content = self.read_file(file_path)
chunks = self.split_content(content, knowledge_length)
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
embedding = self.get_embedding(chunk)
if embedding:
knowledge = {
"content": chunk,
"embedding": embedding,
"source_file": file_path,
"split_length": knowledge_length,
"created_at": datetime.now()
}
self.db.db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
self.db.db.processed_files.update_one(
{"file_path": file_path},
{
"$set": {
"hash": current_hash,
"last_processed": datetime.now(),
"split_by": split_by
}
},
upsert=True
)
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
return result
def _update_stats(self, total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
total_stats["total_chunks"] += result["chunks_processed"]
elif result["status"] == "failed":
total_stats["failed_files"].append((filename, result["error"]))
elif result["status"] == "skipped":
total_stats["skipped_files"].append(filename)
def _display_processing_results(self, stats):
"""显示处理结果统计"""
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
table = Table(show_header=True, header_style="bold magenta")
table.add_column("统计项", style="dim")
table.add_column("数值")
table.add_row("成功处理文件数", str(stats["processed_files"]))
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
table.add_row("失败的文件数", str(len(stats["failed_files"])))
self.console.print(table)
if stats["failed_files"]:
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
for filename, error in stats["failed_files"]:
self.console.print(f"[red]- {filename}: {error}[/red]")
if stats["skipped_files"]:
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
def calculate_file_hash(self, file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]}
]}
]
}
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
}
}
}
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
return results
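# 等价思路示意(假设性示例,字段名沿用上面集合中的 "content" 与 "embedding"):
# 对少量文档在 Python 端直接计算余弦相似度并取前 limit 条,可用来核对聚合管道的结果。
import math

def rank_by_cosine(query_vec, docs, limit=5):
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(y * y for y in b))
        return dot / (na * nb) if na and nb else 0.0
    scored = [dict(d, similarity=cos(query_vec, d["embedding"])) for d in docs]
    return sorted(scored, key=lambda d: d["similarity"], reverse=True)[:limit]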
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
console = Console()
console.print("[bold green]知识库处理工具[/bold green]")
while True:
console.print("\n请选择要执行的操作:")
console.print("[1] 麦麦开始学习")
console.print("[2] 麦麦全部忘光光(仅知识)")
console.print("[q] 退出程序")
choice = input("\n请输入选项: ").strip()
if choice.lower() == 'q':
console.print("[yellow]程序退出[/yellow]")
sys.exit(0)
elif choice == '2':
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == 'y':
knowledge_library.db.db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == '1':
if not os.path.exists(knowledge_library.raw_info_dir):
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
# 询问分割长度
while True:
try:
length_input = input("请输入知识分割长度默认512输入q退出回车使用默认值: ").strip()
if length_input.lower() == 'q':
break
if not length_input: # 如果直接回车,使用默认值
knowledge_length = 512
break
knowledge_length = int(length_input)
if knowledge_length <= 0:
print("分割长度必须大于0请重新输入")
continue
break
except ValueError:
print("请输入有效的数字")
continue
if length_input.lower() == 'q':
continue
# 测试知识库功能
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
knowledge_library.process_files(knowledge_length=knowledge_length)
else:
console.print("[red]无效的选项,请重新选择[/red]")
continue

View File

@@ -11,12 +11,14 @@ from pathlib import Path
import random
import math
import time
from loguru import logger
class ChineseTypoGenerator:
def __init__(self,
error_rate=0.3,
min_freq=5,
tone_error_rate=0.2,
def __init__(self,
error_rate=0.3,
min_freq=5,
tone_error_rate=0.2,
word_replace_rate=0.3,
max_freq_diff=200):
"""
@@ -34,27 +36,27 @@ class ChineseTypoGenerator:
self.tone_error_rate = tone_error_rate
self.word_replace_rate = word_replace_rate
self.max_freq_diff = max_freq_diff
# 加载数据
print("正在加载汉字数据库,请稍候...")
logger.debug("正在加载汉字数据库,请稍候...")
self.pinyin_dict = self._create_pinyin_dict()
self.char_frequency = self._load_or_create_char_frequency()
def _load_or_create_char_frequency(self):
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
@@ -63,15 +65,15 @@ class ChineseTypoGenerator:
for char in word:
if self._is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
def _create_pinyin_dict(self):
@@ -81,7 +83,7 @@ class ChineseTypoGenerator:
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
@@ -89,7 +91,7 @@ class ChineseTypoGenerator:
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def _is_chinese_char(self, char):
@@ -107,7 +109,7 @@ class ChineseTypoGenerator:
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
@@ -117,7 +119,7 @@ class ChineseTypoGenerator:
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def _get_similar_tone_pinyin(self, py):
@@ -127,19 +129,19 @@ class ChineseTypoGenerator:
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + '1'
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
@@ -152,11 +154,11 @@ class ChineseTypoGenerator:
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > self.max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / self.max_freq_diff)
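# 概率衰减示意(与上面方法使用相同的公式,写成独立函数便于单独验证):
# 频率差为 0 时概率为 1;等于 max_freq_diff 时约为 e^-3 ≈ 0.05;超过则为 0。
import math

def replacement_probability(orig_freq, target_freq, max_freq_diff=200):
    if target_freq > orig_freq:
        return 1.0
    diff = orig_freq - target_freq
    if diff > max_freq_diff:
        return 0.0
    return math.exp(-3 * diff / max_freq_diff)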
@@ -166,42 +168,42 @@ class ChineseTypoGenerator:
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有一定概率使用错误声调
if random.random() < self.tone_error_rate:
wrong_tone_py = self._get_similar_tone_pinyin(py)
homophones.extend(self.pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(self.pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = self.char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, self.char_frequency.get(h, 0))
for h in homophones
if h != char and self.char_frequency.get(h, 0) >= self.min_freq]
freq_diff = [(h, self.char_frequency.get(h, 0))
for h in homophones
if h != char and self.char_frequency.get(h, 0) >= self.min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = self._calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
@@ -223,10 +225,10 @@ class ChineseTypoGenerator:
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = self._get_word_pinyin(word)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
@@ -234,11 +236,11 @@ class ChineseTypoGenerator:
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率
@@ -249,11 +251,11 @@ class ChineseTypoGenerator:
word_text = parts[0]
word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
@@ -268,7 +270,7 @@ class ChineseTypoGenerator:
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= self.min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
@@ -286,19 +288,19 @@ class ChineseTypoGenerator:
"""
result = []
typo_info = []
# 分词
words = self._segment_sentence(sentence)
for word in words:
# 如果是标点符号或空格,直接添加
if all(not self._is_chinese_char(c) for c in word):
result.append(word)
continue
# 获取词语的拼音
word_pinyin = self._get_word_pinyin(word)
# 尝试整词替换
if len(word) > 1 and random.random() < self.word_replace_rate:
word_homophones = self._get_word_homophones(word)
@@ -307,15 +309,15 @@ class ChineseTypoGenerator:
# 计算词的平均频率
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(self._get_word_pinyin(typo_word)),
orig_freq, typo_freq))
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(self._get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# 如果不进行整词替换,则进行单字替换
if len(word) == 1:
char = word
@@ -339,7 +341,7 @@ class ChineseTypoGenerator:
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = self._get_similar_frequency_chars(char, py)
if similar_chars:
@@ -354,7 +356,7 @@ class ChineseTypoGenerator:
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_typo_info(self, typo_info):
@@ -369,7 +371,7 @@ class ChineseTypoGenerator:
"""
if not typo_info:
return "未生成错别字"
result = []
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换
@@ -379,12 +381,12 @@ class ChineseTypoGenerator:
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")
return "\n".join(result)
def set_params(self, **kwargs):
"""
设置参数
@@ -399,9 +401,10 @@ class ChineseTypoGenerator:
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
print(f"参数 {key} 已设置为 {value}")
logger.debug(f"参数 {key} 已设置为 {value}")
else:
print(f"警告: 参数 {key} 不存在")
logger.warning(f"警告: 参数 {key} 不存在")
def main():
# 创建错别字生成器实例
@@ -411,27 +414,27 @@ def main():
tone_error_rate=0.02,
word_replace_rate=0.3
)
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
start_time = time.time()
typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
logger.debug("原句:", sentence)
logger.debug("错字版:", typo_sentence)
# 打印错别字信息
if typo_info:
print("\n错别字信息:")
print(typo_generator.format_typo_info(typo_info))
logger.debug(f"错别字信息:{typo_generator.format_typo_info(typo_info)})")
# 计算并打印总耗时
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
logger.debug(f"总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()

View File

@@ -5,13 +5,18 @@ PORT=8080
PLUGINS=["src2.plugins.chat"]
# 默认配置
MONGODB_HOST=127.0.0.1 # 如果工作在Docker下请改成 MONGODB_HOST=mongodb
# 如果工作在Docker下请改成 MONGODB_HOST=mongodb
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # 默认空值
MONGODB_PASSWORD = "" # 默认空值
MONGODB_AUTH_SOURCE = "" # 默认空值
# 也可以使用 URI 连接数据库(优先级比上面的高)
# MONGODB_URI=mongodb://127.0.0.1:27017/MegBot
# MongoDB 认证信息,若需要认证,请取消注释以下三行并填写正确的信息
# MONGODB_USERNAME=user
# MONGODB_PASSWORD=password
# MONGODB_AUTH_SOURCE=admin
#key and url
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
@@ -21,4 +26,4 @@ DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
#定义你要用的api的base_url
DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
SILICONFLOW_KEY=

View File

@@ -1,48 +0,0 @@
import os
import sys
from pathlib import Path
import tomli
import tomli_w
def sync_configs():
# 读取两个配置文件
try:
with open('bot_config_dev.toml', 'rb') as f: # tomli需要使用二进制模式读取
dev_config = tomli.load(f)
with open('bot_config.toml', 'rb') as f:
prod_config = tomli.load(f)
except FileNotFoundError as e:
print(f"错误:找不到配置文件 - {e}")
sys.exit(1)
except tomli.TOMLDecodeError as e:
print(f"错误TOML格式解析失败 - {e}")
sys.exit(1)
# 递归合并配置
def merge_configs(source, target):
for key, value in source.items():
if key not in target:
target[key] = value
elif isinstance(value, dict) and isinstance(target[key], dict):
merge_configs(value, target[key])
# 将dev配置的新属性合并到prod配置中
merge_configs(dev_config, prod_config)
# 保存更新后的配置
try:
with open('bot_config.toml', 'wb') as f: # tomli_w需要使用二进制模式写入
tomli_w.dump(prod_config, f)
print("配置文件同步完成!")
except Exception as e:
print(f"错误:保存配置文件失败 - {e}")
sys.exit(1)
if __name__ == '__main__':
# 确保在正确的目录下运行
script_dir = Path(__file__).parent
os.chdir(script_dir)
sync_configs()

View File

@@ -1,9 +1,21 @@
[inner]
version = "0.0.3"
version = "0.0.6"
#如果你想要修改配置文件请在修改后将version的值进行变更
#如果新增项目请在BotConfig类下新增相应的变量
#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{
#"func":memory,
#"support":">=0.0.0", #新的版本号
#"necessary":False #是否必须
#}
#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断:
# if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
# config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
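# 版本门控示意(假设性写法,仅帮助理解上面的说明;SpecifierSet 来自 packaging 库):
# from packaging.specifiers import SpecifierSet
# def load_memory_section(config, memory_config: dict):
#     if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
#         config.memory_ban_words = set(memory_config.get("memory_ban_words", []))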
[bot]
qq = 123
nickname = "麦麦"
alias_names = ["小麦", "阿麦"]
[personality]
prompt_personality = [
@@ -29,6 +41,13 @@ ban_words = [
# "403","张三"
]
ban_msgs_regex = [
# 需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤,支持CQ码,若不了解正则表达式请勿修改
#"https?://[^\\s]+", # 匹配https链接
#"\\d{4}-\\d{2}-\\d{2}", # 匹配日期
# "\\[CQ:at,qq=\\d+\\]" # 匹配@
]
[emoji]
check_interval = 120 # 检查表情包的时间间隔
register_interval = 10 # 注册表情包的时间间隔
@@ -49,6 +68,10 @@ max_response_length = 1024 # 麦麦回答的最大token数
build_memory_interval = 300 # 记忆构建间隔 单位秒
forget_memory_interval = 300 # 记忆遗忘间隔 单位秒
memory_ban_words = [ #不希望记忆的词
# "403","张三"
]
[mood]
mood_update_interval = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate = 0.95 # 情绪衰减率
@@ -69,14 +92,15 @@ reaction = "回答“测试成功”"
[chinese_typo]
enable = true # 是否启用中文错别字生成器
error_rate=0.03 # 单字替换概率
error_rate=0.006 # 单字替换概率
min_freq=7 # 最小字频阈值
tone_error_rate=0.2 # 声调错误概率
word_replace_rate=0.02 # 整词替换概率
word_replace_rate=0.006 # 整词替换概率
[others]
enable_advance_output = true # 是否启用高级输出
enable_kuuki_read = true # 是否启用读空气功能
enable_debug_output = false # 是否启用调试输出
[groups]
talk_allowed = [

View File

@@ -0,0 +1,4 @@
更新版本后,建议删除数据库 messages 中的所有内容,不然会出现报错
该操作不会影响你的记忆
如果显示配置文件版本过低,请运行根目录下的 bat 脚本

View File

@@ -0,0 +1,45 @@
@echo off
setlocal enabledelayedexpansion
chcp 65001
cd /d %~dp0
echo =====================================
echo 选择Python环境:
echo 1 - venv (推荐)
echo 2 - conda
echo =====================================
choice /c 12 /n /m "输入数字(1或2): "
if errorlevel 2 (
echo =====================================
set "CONDA_ENV="
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
:: 检查输入是否为空
if "!CONDA_ENV!"=="" (
echo 错误:环境名称不能为空
pause
exit /b 1
)
call conda activate !CONDA_ENV!
if errorlevel 1 (
echo 激活 conda 环境失败
pause
exit /b 1
)
echo Conda 环境 "!CONDA_ENV!" 激活成功
python config/auto_update.py
) else (
if exist "venv\Scripts\python.exe" (
venv\Scripts\python config/auto_update.py
) else (
echo =====================================
echo 错误: venv环境不存在请先创建虚拟环境
pause
exit /b 1
)
)
endlocal
pause

45
麦麦开始学习.bat Normal file
View File

@@ -0,0 +1,45 @@
@echo off
setlocal enabledelayedexpansion
chcp 65001
cd /d %~dp0
echo =====================================
echo 选择Python环境:
echo 1 - venv (推荐)
echo 2 - conda
echo =====================================
choice /c 12 /n /m "输入数字(1或2): "
if errorlevel 2 (
echo =====================================
set "CONDA_ENV="
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
:: 检查输入是否为空
if "!CONDA_ENV!"=="" (
echo 错误:环境名称不能为空
pause
exit /b 1
)
call conda activate !CONDA_ENV!
if errorlevel 1 (
echo 激活 conda 环境失败
pause
exit /b 1
)
echo Conda 环境 "!CONDA_ENV!" 激活成功
python src/plugins/zhishi/knowledge_library.py
) else (
if exist "venv\Scripts\python.exe" (
venv\Scripts\python src/plugins/zhishi/knowledge_library.py
) else (
echo =====================================
echo 错误: venv环境不存在请先创建虚拟环境
pause
exit /b 1
)
)
endlocal
pause