diff --git a/.dockerignore b/.dockerignore index 0ed9090fd..6c2d07736 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,4 +3,6 @@ __pycache__ *.pyc *.pyo *.pyd -.DS_Store \ No newline at end of file +.DS_Store +mongodb +napcat \ No newline at end of file diff --git a/.envrc b/.envrc new file mode 100644 index 000000000..8392d159f --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..cf5cffa22 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.bat text eol=crlf +*.cmd text eol=crlf \ No newline at end of file diff --git a/.gitignore b/.gitignore index 51a11d8c2..4e1606a54 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,10 @@ config/bot_config.toml __pycache__/ *.py[cod] *$py.class +llm_statistics.txt +mongodb +napcat +run_dev.bat # C extensions *.so @@ -188,3 +192,10 @@ cython_debug/ # jieba jieba.cache + + +# vscode +/.vscode + +# direnv +/.direnv \ No newline at end of file diff --git a/README.md b/README.md index a3365b934..533d38383 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,19 @@ **🍔麦麦是一个基于大语言模型的智能QQ群聊机器人** -- 🤖 基于 nonebot2 框架开发 -- 🧠 LLM 提供对话能力 -- 💾 MongoDB 提供数据持久化支持 -- 🐧 NapCat 作为QQ协议端支持 +- 基于 nonebot2 框架开发 +- LLM 提供对话能力 +- MongoDB 提供数据持久化支持 +- NapCat 作为QQ协议端支持 + +**最新版本: v0.5.***
麦麦演示视频
👆 点击观看麦麦演示视频 👆 +
@@ -31,13 +34,32 @@ > - 文档未完善,有问题可以提交 Issue 或者 Discussion > - QQ机器人存在被限制风险,请自行了解,谨慎使用 > - 由于持续迭代,可能存在一些已知或未知的bug +> - 由于开发中,可能消耗较多token **交流群**: 766798517 一群人较多,建议加下面的(开发和建议相关讨论)不一定有空回复,会优先写文档和代码 **交流群**: 571780722 另一个群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码 -## 📚 文档 +## +
+

📚 文档 ⬇️ 快速开始使用麦麦 ⬇️

+
+ +### 部署方式 + +如果你不知道Docker是什么,建议寻找相关教程或使用手动部署 + +- [🐳 Docker部署指南](docs/docker_deploy.md) + +- [📦 手动部署指南](docs/manual_deploy.md) + +### 配置说明 +- [🎀 新手配置指南](docs/installation_cute.md) - 通俗易懂的配置教程,适合初次使用的猫娘 +- [⚙️ 标准配置指南](docs/installation_standard.md) - 简明专业的配置说明,适合有经验的用户 + +
+

了解麦麦

+
-- [安装与配置指南](docs/installation.md) - 详细的部署和配置说明 - [项目架构说明](docs/doc1.md) - 项目结构和核心功能实现细节 ## 🎯 功能介绍 @@ -70,6 +92,12 @@ ## 开发计划TODO:LIST + +规划主线 +0.6.0:记忆系统更新 +0.7.0: 麦麦RunTime + + - 人格功能:WIP - 群氛围功能:WIP - 图片发送,转发功能:WIP @@ -87,10 +115,21 @@ - 改进表情包发送逻辑 - 自动生成的回复逻辑,例如自生成的回复方向,回复风格 - 采用截断生成加快麦麦的反应速度 -- 改进发送消息的触发: +- 改进发送消息的触发 + + +## 设计理念 + +- **千石可乐说:** +- 这个项目最初只是为了给牛牛bot添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。可以目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在. +- 程序的功能设计理念基于一个核心的原则:"最像而不是好" +- 主打一个陪伴 +- 如果人类真的需要一个AI来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的helpful assistant,而是一个会犯错的,拥有自己感知和想法的"生命形式"。 +- 代码会保持开源和开放,但个人希望MaiMbot的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试.我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器. + ## 📌 注意事项 -纯编程外行,面向cursor编程,很多代码史一样多多包涵 +SengokuCola纯编程外行,面向cursor编程,很多代码史一样多多包涵 > ⚠️ **警告**:本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本人观点和立场。 diff --git a/bot.py b/bot.py index 50c8cfaa4..51979a5ea 100644 --- a/bot.py +++ b/bot.py @@ -1,14 +1,15 @@ import os + import nonebot -from nonebot.adapters.onebot.v11 import Adapter from dotenv import load_dotenv from loguru import logger +from nonebot.adapters.onebot.v11 import Adapter '''彩蛋''' -from colorama import init, Fore +from colorama import Fore, init init() -text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午" +text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午" rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA] rainbow_text = "" for i, char in enumerate(text): @@ -17,11 +18,15 @@ print(rainbow_text) '''彩蛋''' # 初次启动检测 -if not os.path.exists("config/bot_config.toml") or not os.path.exists(".env"): - logger.info("检测到bot_config.toml不存在,正在从模板复制") +if not os.path.exists("config/bot_config.toml"): + logger.warning("检测到bot_config.toml不存在,正在从模板复制") import shutil + # 检查config目录是否存在 + if not os.path.exists("config"): + os.makedirs("config") + logger.info("创建config目录") - shutil.copy("config/bot_config_template.toml", "config/bot_config.toml") + shutil.copy("template/bot_config_template.toml", 
"config/bot_config.toml") logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动") # 初始化.env 默认ENVIRONMENT=prod diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml deleted file mode 100644 index 28ffb0ce3..000000000 --- a/config/bot_config_template.toml +++ /dev/null @@ -1,98 +0,0 @@ -[bot] -qq = 123 -nickname = "麦麦" - -[personality] -prompt_personality = [ - "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # 贴吧人格 - "是一个女大学生,你有黑色头发,你会刷小红书" # 小红书人格 - ] -prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" - -[message] -min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息 -max_context_size = 15 # 麦麦获得的上文数量 -emoji_chance = 0.2 # 麦麦使用表情包的概率 -ban_words = [ - # "403","张三" - ] - -[emoji] -check_interval = 120 # 检查表情包的时间间隔 -register_interval = 10 # 注册表情包的时间间隔 - -[cq_code] -enable_pic_translate = false - -[response] -model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率 -model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 -model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 - -[memory] -build_memory_interval = 300 # 记忆构建间隔 单位秒 -forget_memory_interval = 300 # 记忆遗忘间隔 单位秒 - -[others] -enable_advance_output = true # 是否启用高级输出 -enable_kuuki_read = true # 是否启用读空气功能 - -[groups] -talk_allowed = [ - 123, - 123, -] #可以回复消息的群 -talk_frequency_down = [] #降低回复频率的群 -ban_user_id = [] #禁止回复消息的QQ号 - - -#V3 -#name = "deepseek-chat" -#base_url = "DEEP_SEEK_BASE_URL" -#key = "DEEP_SEEK_KEY" - -#R1 -#name = "deepseek-reasoner" -#base_url = "DEEP_SEEK_BASE_URL" -#key = "DEEP_SEEK_KEY" - -#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env.prod自定义的宏,使用自定义模型则选择定位相似的模型自己填写 - -[model.llm_reasoning] #R1 -name = "Pro/deepseek-ai/DeepSeek-R1" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_reasoning_minor] #R1蒸馏 -name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal] #V3 -name = "Pro/deepseek-ai/DeepSeek-V3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal_minor] #V2.5 -name = 
"deepseek-ai/DeepSeek-V2.5" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.vlm] #图像识别 -name = "deepseek-ai/deepseek-vl2" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.embedding] #嵌入 -name = "BAAI/bge-m3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -# 主题提取,jieba和snownlp不用api,llm需要api -[topic] -topic_extract='snownlp' # 只支持jieba,snownlp,llm三种选项 - -[topic.llm_topic] -name = "Pro/deepseek-ai/DeepSeek-V3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" diff --git a/docker-compose.yml b/docker-compose.yml index cfe787c04..dd2650b23 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: volumes: - napcatQQ:/app/.config/QQ - napcatCONFIG:/app/napcat/config - - maimbotDATA:/MaiMBot/data #麦麦的图片等要给napcat不然发送图片会有问题 + - maimbotDATA:/MaiMBot/data # 麦麦的图片等要给napcat不然发送图片会有问题 image: mlikiowa/napcat-docker:latest mongodb: @@ -39,7 +39,8 @@ services: - mongodb - napcat volumes: - - maimbotCONFIG:/MaiMBot/config + - napcatCONFIG:/MaiMBot/napcat # 自动根据配置中的qq号创建ws反向客户端配置 + - ./bot_config.toml:/MaiMBot/config/bot_config.toml - maimbotDATA:/MaiMBot/data - ./.env.prod:/MaiMBot/.env.prod image: sengokucola/maimbot:latest diff --git a/docs/doc1.md b/docs/doc1.md index 34de628ed..158136b9c 100644 --- a/docs/doc1.md +++ b/docs/doc1.md @@ -83,7 +83,6 @@ 14. **`topic_identifier.py`**: - 识别消息中的主题,帮助机器人理解用户的意图。 - - 使用多种方法(LLM、jieba、snownlp)进行主题识别。 15. **`utils.py`** 和 **`utils_*.py`** 系列文件: - 存放各种工具函数,提供辅助功能以支持其他模块。 diff --git a/docs/docker_deploy.md b/docs/docker_deploy.md new file mode 100644 index 000000000..c9b069309 --- /dev/null +++ b/docs/docker_deploy.md @@ -0,0 +1,24 @@ +# 🐳 Docker 部署指南 + +## 部署步骤(推荐,但不一定是最新) + +1. 获取配置文件: +```bash +wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml +``` + +2. 启动服务: +```bash +NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d +``` + +3. 
修改配置后重启: +```bash +NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart +``` + +## ⚠️ 注意事项 + +- 目前部署方案仍在测试中,可能存在未知问题 +- 配置文件中的API密钥请妥善保管,不要泄露 +- 建议先在测试环境中运行,确认无误后再部署到生产环境 \ No newline at end of file diff --git a/docs/installation.md b/docs/installation.md deleted file mode 100644 index c988eb7c9..000000000 --- a/docs/installation.md +++ /dev/null @@ -1,145 +0,0 @@ -# 🔧 安装与配置指南 - -## 部署方式 - -如果你不知道Docker是什么,建议寻找相关教程或使用手动部署 - -### 🐳 Docker部署(推荐,但不一定是最新) - -1. 获取配置文件: -```bash -wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml -``` - -2. 启动服务: -```bash -NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d -``` - -3. 修改配置后重启: -```bash -NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart -``` - -### 📦 手动部署 - -1. **环境准备** -```bash -# 创建虚拟环境(推荐) -python -m venv venv -venv\\Scripts\\activate # Windows -# 安装依赖 -pip install -r requirements.txt -``` - -2. **配置MongoDB** -- 安装并启动MongoDB服务 -- 默认连接本地27017端口 - -3. **配置NapCat** -- 安装并登录NapCat -- 添加反向WS:`ws://localhost:8080/onebot/v11/ws` - -4. **配置文件设置** -- 修改环境配置文件:`.env.prod` -- 修改机器人配置文件:`bot_config.toml` - -5. **启动麦麦机器人** -- 打开命令行,cd到对应路径 -```bash -nb run -``` - -6. 
**其他组件** -- `run_thingking.bat`: 启动可视化推理界面(未完善) - -- ~~`knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库~~ -- 直接运行 knowledge.py生成知识库 - -## ⚙️ 配置说明 - -### 环境配置 (.env.prod) -```ini -# API配置,你可以在这里定义你的密钥和base_url -# 你可以选择定义其他服务商提供的KEY,完全可以自定义 -SILICONFLOW_KEY=your_key -SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ -DEEP_SEEK_KEY=your_key -DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 - -# 服务配置,如果你不知道这是什么,保持默认 -HOST=127.0.0.1 -PORT=8080 - -# 数据库配置,如果你不知道这是什么,保持默认 -MONGODB_HOST=127.0.0.1 -MONGODB_PORT=27017 -DATABASE_NAME=MegBot -``` - -### 机器人配置 (bot_config.toml) -```toml -[bot] -qq = "你的机器人QQ号" -nickname = "麦麦" - -[message] -min_text_length = 2 -max_context_size = 15 -emoji_chance = 0.2 - -[emoji] -check_interval = 120 -register_interval = 10 - -[cq_code] -enable_pic_translate = false - -[response] -#现已移除deepseek或硅基流动选项,可以直接切换分别配置任意模型 -model_r1_probability = 0.8 #推理模型权重 -model_v3_probability = 0.1 #非推理模型权重 -model_r1_distill_probability = 0.1 - -[memory] -build_memory_interval = 300 - -[others] -enable_advance_output = true # 是否启用详细日志输出 - -[groups] -talk_allowed = [] # 允许回复的群号列表 -talk_frequency_down = [] # 降低回复频率的群号列表 -ban_user_id = [] # 禁止回复的用户QQ号列表 - -[model.llm_reasoning] -name = "Pro/deepseek-ai/DeepSeek-R1" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_reasoning_minor] -name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal] -name = "Pro/deepseek-ai/DeepSeek-V3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal_minor] -name = "deepseek-ai/DeepSeek-V2.5" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.vlm] -name = "deepseek-ai/deepseek-vl2" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" -``` - -## ⚠️ 注意事项 - -- 目前部署方案仍在测试中,可能存在未知问题 -- 配置文件中的API密钥请妥善保管,不要泄露 -- 建议先在测试环境中运行,确认无误后再部署到生产环境 \ No newline at end of file diff --git a/docs/installation_cute.md b/docs/installation_cute.md new file mode 
100644 index 000000000..278cbfe20 --- /dev/null +++ b/docs/installation_cute.md @@ -0,0 +1,215 @@ +# 🔧 配置指南 喵~ + +## 👋 你好呀! + +让咱来告诉你我们要做什么喵: +1. 我们要一起设置一个可爱的AI机器人 +2. 这个机器人可以在QQ上陪你聊天玩耍哦 +3. 需要设置两个文件才能让机器人工作呢 + +## 📝 需要设置的文件喵 + +要设置这两个文件才能让机器人跑起来哦: +1. `.env.prod` - 这个文件告诉机器人要用哪些AI服务呢 +2. `bot_config.toml` - 这个文件教机器人怎么和你聊天喵 + +## 🔑 密钥和域名的对应关系 + +想象一下,你要进入一个游乐园,需要: +1. 知道游乐园的地址(这就是域名 base_url) +2. 有入场的门票(这就是密钥 key) + +在 `.env.prod` 文件里,我们定义了三个游乐园的地址和门票喵: +```ini +# 硅基流动游乐园 +SILICONFLOW_KEY=your_key # 硅基流动的门票 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ # 硅基流动的地址 + +# DeepSeek游乐园 +DEEP_SEEK_KEY=your_key # DeepSeek的门票 +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 # DeepSeek的地址 + +# ChatAnyWhere游乐园 +CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere的门票 +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere的地址 +``` + +然后在 `bot_config.toml` 里,机器人会用这些门票和地址去游乐园玩耍: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" # 告诉机器人:去硅基流动游乐园玩 +key = "SILICONFLOW_KEY" # 用硅基流动的门票进去 + +[model.llm_normal] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" # 还是去硅基流动游乐园 +key = "SILICONFLOW_KEY" # 用同一张门票就可以啦 +``` + +### 🎪 举个例子喵: + +如果你想用DeepSeek官方的服务,就要这样改: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "DEEP_SEEK_BASE_URL" # 改成去DeepSeek游乐园 +key = "DEEP_SEEK_KEY" # 用DeepSeek的门票 + +[model.llm_normal] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "DEEP_SEEK_BASE_URL" # 也去DeepSeek游乐园 +key = "DEEP_SEEK_KEY" # 用同一张DeepSeek门票 +``` + +### 🎯 简单来说: +- `.env.prod` 文件就像是你的票夹,存放着各个游乐园的门票和地址 +- `bot_config.toml` 就是告诉机器人:用哪张票去哪个游乐园玩 +- 所有模型都可以用同一个游乐园的票,也可以去不同的游乐园玩耍 +- 如果用硅基流动的服务,就保持默认配置不用改呢~ + +记住:门票(key)要保管好,不能给别人看哦,不然别人就可以用你的票去玩了喵! 
+ +## ---让我们开始吧--- + +### 第一个文件:环境配置 (.env.prod) + +这个文件就像是机器人的"身份证"呢,告诉它要用哪些AI服务喵~ + +```ini +# 这些是AI服务的密钥,就像是魔法钥匙一样呢 +# 要把 your_key 换成真正的密钥才行喵 +# 比如说:SILICONFLOW_KEY=sk-123456789abcdef +SILICONFLOW_KEY=your_key +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ +DEEP_SEEK_KEY=your_key +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +CHAT_ANY_WHERE_KEY=your_key +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 + +# 如果你不知道这是什么,那么下面这些不用改,保持原样就好啦 +HOST=127.0.0.1 +PORT=8080 + +# 这些是数据库设置,一般也不用改呢 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot +MONGODB_USERNAME = "" # 如果数据库需要用户名,就在这里填写喵 +MONGODB_PASSWORD = "" # 如果数据库需要密码,就在这里填写呢 +MONGODB_AUTH_SOURCE = "" # 数据库认证源,一般不用改哦 + +# 插件设置喵 +PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢 +``` + +### 第二个文件:机器人配置 (bot_config.toml) + +这个文件就像是教机器人"如何说话"的魔法书呢! + +```toml +[bot] +qq = "把这里改成你的机器人QQ号喵" # 填写你的机器人QQ号 +nickname = "麦麦" # 机器人的名字,你可以改成你喜欢的任何名字哦 + +[personality] +# 这里可以设置机器人的性格呢,让它更有趣一些喵 +prompt_personality = [ + "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # 贴吧风格的性格 + "是一个女大学生,你有黑色头发,你会刷小红书" # 小红书风格的性格 +] +prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + +[message] +min_text_length = 2 # 机器人每次至少要说几个字呢 +max_context_size = 15 # 机器人能记住多少条消息喵 +emoji_chance = 0.2 # 机器人使用表情的概率哦(0.2就是20%的机会呢) +ban_words = ["脏话", "不文明用语"] # 在这里填写不让机器人说的词 + +[emoji] +auto_save = true # 是否自动保存看到的表情包呢 +enable_check = false # 是否要检查表情包是不是合适的喵 +check_prompt = "符合公序良俗" # 检查表情包的标准呢 + +[groups] +talk_allowed = [123456, 789012] # 比如:让机器人在群123456和789012里说话 +talk_frequency_down = [345678] # 比如:在群345678里少说点话 +ban_user_id = [111222] # 比如:不回复QQ号为111222的人的消息 + +[others] +enable_advance_output = true # 是否要显示更多的运行信息呢 +enable_kuuki_read = true # 让机器人能够"察言观色"喵 + +# 模型配置部分的详细说明喵~ + + +#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成在.env.prod自己指定的密钥和域名,使用自定义模型则选择定位相似的模型自己填写 + +[model.llm_reasoning] #推理模型R1,用来理解和思考的喵 +name = "Pro/deepseek-ai/DeepSeek-R1" # 模型名字 +# name = "Qwen/QwQ-32B" # 如果想用千问模型,可以把上面那行注释掉,用这个呢 +base_url = "SILICONFLOW_BASE_URL" # 
使用在.env.prod里设置的服务地址 +key = "SILICONFLOW_KEY" # 使用在.env.prod里设置的密钥 + +[model.llm_reasoning_minor] #R1蒸馏模型,是个轻量版的推理模型喵 +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal] #V3模型,用来日常聊天的喵 +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal_minor] #V2.5模型,是V3的前代版本呢 +name = "deepseek-ai/DeepSeek-V2.5" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.vlm] #图像识别模型,让机器人能看懂图片喵 +name = "deepseek-ai/deepseek-vl2" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.embedding] #嵌入模型,帮助机器人理解文本的相似度呢 +name = "BAAI/bge-m3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +# 如果选择了llm方式提取主题,就用这个模型配置喵 +[topic.llm_topic] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +``` + +## 💡 模型配置说明喵 + +1. **关于模型服务**: + - 如果你用硅基流动的服务,这些配置都不用改呢 + - 如果用DeepSeek官方API,要把base_url和key改成你在.env.prod里设置的值喵 + - 如果要用自定义模型,选择一个相似功能的模型配置来改呢 + +2. **主要模型功能**: + - `llm_reasoning`: 负责思考和推理的大脑喵 + - `llm_normal`: 负责日常聊天的嘴巴呢 + - `vlm`: 负责看图片的眼睛哦 + - `embedding`: 负责理解文字含义的理解力喵 + - `topic`: 负责理解对话主题的能力呢 + +## 🌟 小提示 +- 如果你刚开始使用,建议保持默认配置呢 +- 不同的模型有不同的特长,可以根据需要调整它们的使用比例哦 + +## 🌟 小贴士喵 +- 记得要好好保管密钥(key)哦,不要告诉别人呢 +- 配置文件要小心修改,改错了机器人可能就不能和你玩了喵 +- 如果想让机器人更聪明,可以调整 personality 里的设置呢 +- 不想让机器人说某些话,就把那些词放在 ban_words 里面喵 +- QQ群号和QQ号都要用数字填写,不要加引号哦(除了机器人自己的QQ号) + +## ⚠️ 注意事项 +- 这个机器人还在测试中呢,可能会有一些小问题喵 +- 如果不知道怎么改某个设置,就保持原样不要动它哦~ +- 记得要先有AI服务的密钥,不然机器人就不能和你说话了呢 +- 修改完配置后要重启机器人才能生效喵~ \ No newline at end of file diff --git a/docs/installation_standard.md b/docs/installation_standard.md new file mode 100644 index 000000000..6e4920220 --- /dev/null +++ b/docs/installation_standard.md @@ -0,0 +1,154 @@ +# 🔧 配置指南 + +## 简介 + +本项目需要配置两个主要文件: +1. `.env.prod` - 配置API服务和系统环境 +2. 
`bot_config.toml` - 配置机器人行为和模型 + +## API配置说明 + +`.env.prod`和`bot_config.toml`中的API配置关系如下: + +### 在.env.prod中定义API凭证: +```ini +# API凭证配置 +SILICONFLOW_KEY=your_key # 硅基流动API密钥 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ # 硅基流动API地址 + +DEEP_SEEK_KEY=your_key # DeepSeek API密钥 +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 # DeepSeek API地址 + +CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere API密钥 +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere API地址 +``` + +### 在bot_config.toml中引用API凭证: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" # 引用.env.prod中定义的地址 +key = "SILICONFLOW_KEY" # 引用.env.prod中定义的密钥 +``` + +如需切换到其他API服务,只需修改引用: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "DEEP_SEEK_BASE_URL" # 切换为DeepSeek服务 +key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥 +``` + +## 配置文件详解 + +### 环境配置文件 (.env.prod) +```ini +# API配置 +SILICONFLOW_KEY=your_key +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ +DEEP_SEEK_KEY=your_key +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +CHAT_ANY_WHERE_KEY=your_key +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 + +# 服务配置 +HOST=127.0.0.1 +PORT=8080 + +# 数据库配置 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot +MONGODB_USERNAME = "" # 数据库用户名 +MONGODB_PASSWORD = "" # 数据库密码 +MONGODB_AUTH_SOURCE = "" # 认证数据库 + +# 插件配置 +PLUGINS=["src2.plugins.chat"] +``` + +### 机器人配置文件 (bot_config.toml) +```toml +[bot] +qq = "机器人QQ号" # 必填 +nickname = "麦麦" # 机器人昵称 + +[personality] +prompt_personality = [ + "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", + "是一个女大学生,你有黑色头发,你会刷小红书" +] +prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + +[message] +min_text_length = 2 # 最小回复长度 +max_context_size = 15 # 上下文记忆条数 +emoji_chance = 0.2 # 表情使用概率 +ban_words = [] # 禁用词列表 + +[emoji] +auto_save = true # 自动保存表情 +enable_check = false # 启用表情审核 +check_prompt = "符合公序良俗" + +[groups] +talk_allowed = [] # 允许对话的群号 +talk_frequency_down = [] # 降低回复频率的群号 
+ban_user_id = [] # 禁止回复的用户QQ号 + +[others] +enable_advance_output = true # 启用详细日志 +enable_kuuki_read = true # 启用场景理解 + +# 模型配置 +[model.llm_reasoning] # 推理模型 +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_reasoning_minor] # 轻量推理模型 +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal] # 对话模型 +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal_minor] # 备用对话模型 +name = "deepseek-ai/DeepSeek-V2.5" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.vlm] # 图像识别模型 +name = "deepseek-ai/deepseek-vl2" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.embedding] # 文本向量模型 +name = "BAAI/bge-m3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + + +[topic.llm_topic] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +``` + +## 注意事项 + +1. API密钥安全: + - 妥善保管API密钥 + - 不要将含有密钥的配置文件上传至公开仓库 + +2. 配置修改: + - 修改配置后需重启服务 + - 使用默认服务(硅基流动)时无需修改模型配置 + - QQ号和群号使用数字格式(机器人QQ号除外) + +3. 其他说明: + - 项目处于测试阶段,可能存在未知问题 + - 建议初次使用保持默认配置 \ No newline at end of file diff --git a/docs/manual_deploy.md b/docs/manual_deploy.md new file mode 100644 index 000000000..6d53beb4e --- /dev/null +++ b/docs/manual_deploy.md @@ -0,0 +1,100 @@ +# 📦 如何手动部署MaiMbot麦麦? + +## 你需要什么? + +- 一台电脑,能够上网的那种 + +- 一个QQ小号(QQ框架的使用可能导致qq被风控,严重(小概率)可能会导致账号封禁,强烈不推荐使用大号) + +- 可用的大模型API + +- 一个AI助手,网上随便搜一家打开来用都行,可以帮你解决一些不懂的问题 + +## 你需要知道什么? + +- 如何正确向AI助手提问,来学习新知识 + +- Python是什么 + +- Python的虚拟环境是什么?如何创建虚拟环境 + +- 命令行是什么 + +- 数据库是什么?如何安装并启动MongoDB + +- 如何运行一个QQ机器人,以及NapCat框架是什么 + +## 如果准备好了,就可以开始部署了 + +### 1️⃣ **首先,我们需要安装正确版本的Python** + +在创建虚拟环境之前,请确保你的电脑上安装了Python 3.9及以上版本。如果没有,可以按以下步骤安装: + +1. 访问Python官网下载页面:https://www.python.org/downloads/release/python-3913/ +2. 下载Windows安装程序 (64-bit): `python-3.9.13-amd64.exe` +3. 
运行安装程序,并确保勾选"Add Python 3.9 to PATH"选项 +4. 点击"Install Now"开始安装 + +或者使用PowerShell自动下载安装(需要管理员权限): +```powershell +# 下载并安装Python 3.9.13 +$pythonUrl = "https://www.python.org/ftp/python/3.9.13/python-3.9.13-amd64.exe" +$pythonInstaller = "$env:TEMP\python-3.9.13-amd64.exe" +Invoke-WebRequest -Uri $pythonUrl -OutFile $pythonInstaller +Start-Process -Wait -FilePath $pythonInstaller -ArgumentList "/quiet", "InstallAllUsers=0", "PrependPath=1" -Verb RunAs +``` + +### 2️⃣ **创建Python虚拟环境来运行程序** + + 你可以选择使用以下两种方法之一来创建Python环境: + +```bash +# ---方法1:使用venv(Python自带) +# 在命令行中创建虚拟环境(环境名为maimbot) +# 这会让你在运行命令的目录下创建一个虚拟环境 +# 请确保你已通过cd命令前往到了对应路径,不然之后你可能找不到你的python环境 +python -m venv maimbot + +maimbot\\Scripts\\activate + +# 安装依赖 +pip install -r requirements.txt +``` +```bash +# ---方法2:使用conda +# 创建一个新的conda环境(环境名为maimbot) +# Python版本为3.9 +conda create -n maimbot python=3.9 + +# 激活环境 +conda activate maimbot + +# 安装依赖 +pip install -r requirements.txt +``` + +### 2️⃣ **然后你需要启动MongoDB数据库,来存储信息** +- 安装并启动MongoDB服务 +- 默认连接本地27017端口 + +### 3️⃣ **配置NapCat,让麦麦bot与qq取得联系** +- 安装并登录NapCat(用你的qq小号) +- 添加反向WS:`ws://localhost:8080/onebot/v11/ws` + +### 4️⃣ **配置文件设置,让麦麦Bot正常工作** +- 修改环境配置文件:`.env.prod` +- 修改机器人配置文件:`bot_config.toml` + +### 5️⃣ **启动麦麦机器人** +- 打开命令行,cd到对应路径 +```bash +nb run +``` +- 或者cd到对应路径后 +```bash +python bot.py +``` + +### 6️⃣ **其他组件(可选)** +- `run_thingking.bat`: 启动可视化推理界面(未完善) +- 直接运行 knowledge.py生成知识库 diff --git a/flagged/log.csv b/flagged/log.csv deleted file mode 100644 index daeef4a9b..000000000 --- a/flagged/log.csv +++ /dev/null @@ -1,2 +0,0 @@ -输入消息,推理内容,flag,username,timestamp -显示内容,,,,2025-02-18 16:50:53.643238 diff --git a/flake.lock b/flake.lock new file mode 100644 index 000000000..dd215f1c6 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": 
"flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1741196730, + "narHash": "sha256-0Sj6ZKjCpQMfWnN0NURqRCQn2ob7YtXTAOTwCuz7fkA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "48913d8f9127ea6530a2a2f1bd4daa1b8685d8a3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 000000000..54737d640 --- /dev/null +++ b/flake.nix @@ -0,0 +1,61 @@ +{ + description = "MaiMBot Nix Dev Env"; + # 本配置仅方便用于开发,但是因为 nb-cli 上游打包中并未包含 nonebot2,因此目前本配置并不能用于运行和调试 + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { + inherit system; + }; + + pythonEnv = pkgs.python3.withPackages ( + ps: with ps; [ + pymongo + python-dotenv + pydantic + jieba + openai + aiohttp + requests + urllib3 + numpy + pandas + matplotlib + networkx + python-dateutil + APScheduler + loguru + tomli + customtkinter + colorama + pypinyin + pillow + setuptools + ] + ); + in + { + devShell = pkgs.mkShell { + buildInputs = [ + pythonEnv + pkgs.nb-cli + ]; + + shellHook = '' + ''; + }; + } + ); +} diff --git a/pyproject.toml 
b/pyproject.toml index 4f06cd5ae..e54dcdacd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,4 +5,19 @@ description = "New Bot Project" [tool.nonebot] plugins = ["src.plugins.chat"] -plugin_dirs = ["src/plugins"] \ No newline at end of file +plugin_dirs = ["src/plugins"] + +[tool.ruff] +# 设置 Python 版本 +target-version = "py39" + +# 启用的规则 +select = [ + "E", # pycodestyle 错误 + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear +] + +# 行长度设置 +line-length = 88 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 49c102dc6..4f969682f 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run.bat b/run.bat new file mode 100644 index 000000000..1d1385671 --- /dev/null +++ b/run.bat @@ -0,0 +1,6 @@ +@ECHO OFF +chcp 65001 +REM python -m venv venv +call venv\Scripts\activate.bat +REM pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade -r requirements.txt +python run.py \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 000000000..0a195544f --- /dev/null +++ b/run.py @@ -0,0 +1,122 @@ +import os +import subprocess +import zipfile + +import requests +from tqdm import tqdm + + +def extract_files(zip_path, target_dir): + """ + 解压 + + Args: + zip_path: 源ZIP压缩包路径(需确保是有效压缩包) + target_dir: 目标文件夹路径(会自动创建不存在的目录) + """ + # 打开ZIP压缩包(上下文管理器自动处理关闭) + with zipfile.ZipFile(zip_path) as zip_ref: + # 通过第一个文件路径推断顶层目录名(格式如:top_dir/) + top_dir = zip_ref.namelist()[0].split("/")[0] + "/" + + # 遍历压缩包内所有文件条目 + for file in zip_ref.namelist(): + # 跳过目录条目,仅处理文件 + if file.startswith(top_dir) and not file.endswith("/"): + # 截取顶层目录后的相对路径(如:sub_dir/file.txt) + rel_path = file[len(top_dir) :] + + # 创建目标目录结构(含多级目录) + os.makedirs( + os.path.dirname(f"{target_dir}/{rel_path}"), + exist_ok=True, # 忽略已存在目录的错误 + ) + + # 读取压缩包内文件内容并写入目标路径 + with open(f"{target_dir}/{rel_path}", "wb") as f: + f.write(zip_ref.read(file)) + + +def run_cmd(command: str, open_new_window: bool = False): + """ + 
运行 cmd 命令 + + Args: + command (str): 指定要运行的命令 + open_new_window (bool): 指定是否新建一个 cmd 窗口运行 + """ + creationflags = 0 + if open_new_window: + creationflags = subprocess.CREATE_NEW_CONSOLE + subprocess.Popen( + [ + "cmd.exe", + "/c", + command, + ], + creationflags=creationflags, + ) + + +def run_maimbot(): + run_cmd(r"napcat\NapCatWinBootMain.exe 10001", False) + run_cmd( + r"mongodb\bin\mongod.exe --dbpath=" + os.getcwd() + r"\mongodb\db --port 27017", + True, + ) + run_cmd("nb run", True) + + +def install_mongodb(): + """ + 安装 MongoDB + """ + print("下载 MongoDB") + resp = requests.get( + "https://fastdl.mongodb.org/windows/mongodb-windows-x86_64-latest.zip", + stream=True, + ) + total = int(resp.headers.get("content-length", 0)) # 计算文件大小 + with open("mongodb.zip", "w+b") as file, tqdm( # 展示下载进度条,并解压文件 + desc="mongodb.zip", + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in resp.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) + extract_files("mongodb.zip", "mongodb") + print("MongoDB 下载完成") + os.remove("mongodb.zip") + + +def install_napcat(): + run_cmd("start https://github.com/NapNeko/NapCatQQ/releases", True) + print("请检查弹出的浏览器窗口,点击**第一个**蓝色的“Win64无头” 下载 napcat") + napcat_filename = input( + "下载完成后请把文件复制到此文件夹,并将**不包含后缀的文件名**输入至此窗口,如 NapCat.32793.Shell:" + ) + extract_files(napcat_filename + ".zip", "napcat") + print("NapCat 安装完成") + os.remove(napcat_filename + ".zip") + + +if __name__ == "__main__": + os.system("cls") + choice = input( + "请输入要进行的操作:\n" + "1.首次安装\n" + "2.运行麦麦\n" + "3.运行麦麦并启动可视化推理界面\n" + ) + os.system("cls") + if choice == "1": + install_napcat() + install_mongodb() + elif choice == "2": + run_maimbot() + elif choice == "3": + run_maimbot() + run_cmd("python src/gui/reasoning_gui.py", True) diff --git a/run_db.bat b/script/run_db.bat similarity index 100% rename from run_db.bat rename to script/run_db.bat diff --git a/run_maimai.bat b/script/run_maimai.bat similarity index 65% 
rename from run_maimai.bat rename to script/run_maimai.bat index ff00cc5c1..3a099fd7f 100644 --- a/run_maimai.bat +++ b/script/run_maimai.bat @@ -1,5 +1,5 @@ chcp 65001 -call conda activate niuniu +call conda activate maimbot cd . REM 执行nb run命令 diff --git a/run_thingking.bat b/script/run_thingking.bat similarity index 100% rename from run_thingking.bat rename to script/run_thingking.bat diff --git a/script/run_windows.bat b/script/run_windows.bat new file mode 100644 index 000000000..bea397ddc --- /dev/null +++ b/script/run_windows.bat @@ -0,0 +1,68 @@ +@echo off +setlocal enabledelayedexpansion +chcp 65001 + +REM 修正路径获取逻辑 +cd /d "%~dp0" || ( + echo 错误:切换目录失败 + exit /b 1 +) + +if not exist "venv\" ( + echo 正在初始化虚拟环境... + + where python >nul 2>&1 + if %errorlevel% neq 0 ( + echo 未找到Python解释器 + exit /b 1 + ) + + for /f "tokens=2" %%a in ('python --version 2^>^&1') do set version=%%a + for /f "tokens=1,2 delims=." %%b in ("!version!") do ( + set major=%%b + set minor=%%c + ) + + if !major! lss 3 ( + echo 需要Python大于等于3.0,当前版本 !version! + exit /b 1 + ) + + if !major! equ 3 if !minor! lss 9 ( + echo 需要Python大于等于3.9,当前版本 !version! + exit /b 1 + ) + + echo 正在安装virtualenv... + python -m pip install virtualenv || ( + echo virtualenv安装失败 + exit /b 1 + ) + + echo 正在创建虚拟环境... + python -m virtualenv venv || ( + echo 虚拟环境创建失败 + exit /b 1 + ) + + call venv\Scripts\activate.bat + +) else ( + call venv\Scripts\activate.bat +) + +echo 正在更新依赖... 
+pip install -r requirements.txt + +echo 当前代理设置: +echo HTTP_PROXY=%HTTP_PROXY% +echo HTTPS_PROXY=%HTTPS_PROXY% + +set HTTP_PROXY= +set HTTPS_PROXY= +echo 代理已取消。 + +set no_proxy=0.0.0.0/32 + +call nb run +pause \ No newline at end of file diff --git a/setup.py b/setup.py index a6152a972..2598a38a8 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name="maimai-bot", diff --git a/src/common/database.py b/src/common/database.py index 5928abc42..45ac05dac 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,6 +1,8 @@ -from pymongo import MongoClient from typing import Optional +from pymongo import MongoClient + + class Database: _instance: Optional["Database"] = None diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index 61fb34560..340791ee3 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -1,12 +1,12 @@ -import customtkinter as ctk -from typing import Dict, List -import json -from datetime import datetime -import time -import threading +import os import queue import sys -import os +import threading +import time +from datetime import datetime +from typing import Dict, List + +import customtkinter as ctk from dotenv import load_dotenv # 获取当前文件的目录 @@ -25,9 +25,11 @@ else: print("未找到环境配置文件") sys.exit(1) -from pymongo import MongoClient from typing import Optional +from pymongo import MongoClient + + class Database: _instance: Optional["Database"] = None diff --git a/src/plugins/chat/Segment_builder.py b/src/plugins/chat/Segment_builder.py index 09673a044..ed75f7092 100644 --- a/src/plugins/chat/Segment_builder.py +++ b/src/plugins/chat/Segment_builder.py @@ -1,6 +1,5 @@ -from typing import Dict, List, Union, Optional, Any import base64 -import os +from typing import Any, Dict, List, Union """ OneBot v11 Message Segment Builder diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 
ab99f6477..0bffaed19 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -1,20 +1,29 @@ -from loguru import logger -from nonebot import on_message, on_command, require, get_driver -from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment -from nonebot.typing import T_State -from ...common.database import Database -from .config import global_config -import os import asyncio +import os import random -from .relationship_manager import relationship_manager -from ..schedule.schedule_generator import bot_schedule -from .willing_manager import willing_manager -from nonebot.rule import to_me -from .bot import chat_bot -from .emoji_manager import emoji_manager import time +from loguru import logger +from nonebot import get_driver, on_command, on_message, require +from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment +from nonebot.rule import to_me +from nonebot.typing import T_State + +from ...common.database import Database +from ..moods.moods import MoodManager # 导入情绪管理器 +from ..schedule.schedule_generator import bot_schedule +from ..utils.statistic import LLMStatistics +from .bot import chat_bot +from .config import global_config +from .emoji_manager import emoji_manager +from .relationship_manager import relationship_manager +from .willing_manager import willing_manager + +# 创建LLM统计实例 +llm_stats = LLMStatistics("llm_statistics.txt") + +# 添加标志变量 +_message_manager_started = False # 获取驱动器 driver = get_driver() @@ -32,12 +41,11 @@ print("\033[1;32m[初始化数据库完成]\033[0m") # 导入其他模块 +from ..memory_system.memory import hippocampus, memory_graph from .bot import ChatBot -from .emoji_manager import emoji_manager + # from .message_send_control import message_sender -from .relationship_manager import relationship_manager -from .message_sender import message_manager,message_sender -from ..memory_system.memory import memory_graph,hippocampus +from .message_sender import message_manager, 
message_sender # 初始化表情管理器 emoji_manager.initialize() @@ -55,6 +63,15 @@ scheduler = require("nonebot_plugin_apscheduler").scheduler @driver.on_startup async def start_background_tasks(): """启动后台任务""" + # 启动LLM统计 + llm_stats.start() + print("\033[1;32m[初始化]\033[0m LLM统计功能已启动") + + # 初始化并启动情绪管理器 + mood_manager = MoodManager.get_instance() + mood_manager.start_mood_update(update_interval=global_config.mood_update_interval) + print("\033[1;32m[初始化]\033[0m 情绪管理器已启动") + # 只启动表情包管理任务 asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) await bot_schedule.initialize() @@ -70,18 +87,20 @@ async def init_relationships(): @driver.on_bot_connect async def _(bot: Bot): """Bot连接成功时的处理""" + global _message_manager_started print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m") await willing_manager.ensure_started() - message_sender.set_bot(bot) print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m") - asyncio.create_task(message_manager.start_processor()) - print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m") + + if not _message_manager_started: + asyncio.create_task(message_manager.start_processor()) + _message_manager_started = True + print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m") asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") - # 启动消息发送控制任务 @group_msg.handle() async def _(bot: Bot, event: GroupMessageEvent, state: T_State): @@ -90,7 +109,7 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State): # 添加build_memory定时任务 @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") async def build_memory_task(): - """每30秒执行一次记忆构建""" + """每build_memory_interval秒执行一次记忆构建""" print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------") 
start_time = time.time() await hippocampus.operation_build_memory(chat_size=20) @@ -110,4 +129,10 @@ async def merge_memory_task(): # print("\033[1;32m[记忆整合]\033[0m 开始整合") # await hippocampus.operation_merge_memory(percentage=0.1) # print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") + +@scheduler.scheduled_job("interval", seconds=30, id="print_mood") +async def print_mood_task(): + """每30秒打印一次情绪状态""" + mood_manager = MoodManager.get_instance() + mood_manager.print_mood_status() diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index e3525b3bb..a02c4a059 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -1,22 +1,27 @@ -from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot -from .message import Message, MessageSet, Message_Sending -from .config import BotConfig, global_config -from .storage import MessageStorage -from .llm_generator import ResponseGenerator -# from .message_stream import MessageStream, MessageStreamContainer -from .topic_identifier import topic_identifier -from random import random, choice -from .emoji_manager import emoji_manager # 导入表情包管理器 import time -import os -from .cq_code import CQCode # 导入CQCode模块 -from .message_sender import message_manager # 导入新的消息管理器 -from .message import Message_Thinking # 导入 Message_Thinking 类 -from .relationship_manager import relationship_manager -from .willing_manager import willing_manager # 导入意愿管理器 -from .utils import is_mentioned_bot_in_txt, calculate_typing_time -from ..memory_system.memory import memory_graph +from random import random + from loguru import logger +from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent + +from ..memory_system.memory import hippocampus +from ..moods.moods import MoodManager # 导入情绪管理器 +from .config import global_config +from .cq_code import CQCode # 导入CQCode模块 +from .emoji_manager import emoji_manager # 导入表情包管理器 +from .llm_generator import ResponseGenerator +from .message import ( + Message, + Message_Sending, + 
Message_Thinking, # 导入 Message_Thinking 类 + MessageSet, +) +from .message_sender import message_manager # 导入新的消息管理器 +from .relationship_manager import relationship_manager +from .storage import MessageStorage +from .utils import calculate_typing_time, is_mentioned_bot_in_txt +from .willing_manager import willing_manager # 导入意愿管理器 + class ChatBot: def __init__(self): @@ -24,6 +29,8 @@ class ChatBot: self.gpt = ResponseGenerator() self.bot = None # bot 实例引用 self._started = False + self.mood_manager = MoodManager.get_instance() # 获取情绪管理器单例 + self.mood_manager.start_mood_update() # 启动情绪更新 self.emoji_chance = 0.2 # 发送表情包的基础概率 # self.message_streams = MessageStreamContainer() @@ -58,6 +65,7 @@ class ChatBot: plain_text=event.get_plaintext(), reply_message=event.reply, ) + await message.initialize() # 过滤词 for word in global_config.ban_words: @@ -70,24 +78,12 @@ class ChatBot: - topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) - - - # topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text) - # topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text) - # topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text) - logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") - - all_num = 0 - interested_num = 0 - if topic: - for current_topic in topic: - all_num += 1 - first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) - if first_layer_items: - interested_num += 1 - print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象") - interested_rate = interested_num / all_num if all_num > 0 else 0 + # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) + topic = '' + interested_rate = 0 + interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text)/100 + print(f"\033[1;32m[记忆激活]\033[0m 
对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n") + # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") await self.storage.store_message(message, topic[0] if topic else None) @@ -119,14 +115,9 @@ class ChatBot: willing_manager.change_reply_willing_sent(thinking_message.group_id) - response, emotion = await self.gpt.generate_response(message) - - # if response is None: - # thinking_message.interupt=True + response,raw_content = await self.gpt.generate_response(message) if response: - # print(f"\033[1;32m[思考结束]\033[0m 思考结束,已得到回复,开始回复") - # 找到并删除对应的thinking消息 container = message_manager.get_container(event.group_id) thinking_message = None # 找到message,删除 @@ -134,8 +125,13 @@ class ChatBot: if isinstance(msg, Message_Thinking) and msg.message_id == think_id: thinking_message = msg container.messages.remove(msg) - print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除") + # print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除") break + + # 如果找不到思考消息,直接返回 + if not thinking_message: + print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除") + return #记录开始思考的时间,避免从思考到回复的时间太久 thinking_start_time = thinking_message.thinking_start_time @@ -144,6 +140,7 @@ class ChatBot: accu_typing_time = 0 # print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器") + mark_head = False for msg in response: # print(f"\033[1;32m[回复内容]\033[0m {msg}") #通过时间改变时间戳 @@ -164,16 +161,25 @@ class ChatBot: thinking_start_time=thinking_start_time, #记录了思考开始的时间 reply_message_id=message.message_id ) + await bot_message.initialize() + if not mark_head: + bot_message.is_head = True + mark_head = True message_set.add_message(bot_message) #message_set 可以直接加入 message_manager - print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") + # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") message_manager.add_message(message_set) bot_response_time = tinking_time_point + if random() < global_config.emoji_chance: - emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) - if 
emoji_path: + emoji_raw = await emoji_manager.get_emoji_for_text(response) + + # 检查是否 <没有找到> emoji + if emoji_raw != None: + emoji_path,discription = emoji_raw + emoji_cq = CQCode.create_emoji_cq(emoji_path) if random() < 0.5: @@ -188,6 +194,7 @@ class ChatBot: raw_message=emoji_cq, plain_text=emoji_cq, processed_plain_text=emoji_cq, + detailed_plain_text=discription, user_nickname=global_config.BOT_NICKNAME, group_name=message.group_name, time=bot_response_time, @@ -196,9 +203,24 @@ class ChatBot: thinking_start_time=thinking_start_time, # reply_message_id=message.message_id ) + await bot_message.initialize() message_manager.add_message(bot_message) + emotion = await self.gpt._get_emotion_tags(raw_content) + print(f"为 '{response}' 获取到的情感标签为:{emotion}") + valuedict={ + 'happy': 0.5, + 'angry': -1, + 'sad': -0.5, + 'surprised': 0.2, + 'disgusted': -1.5, + 'fearful': -0.7, + 'neutral': 0.1 + } + await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) + # 使用情绪管理器更新情绪 + self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) - willing_manager.change_reply_willing_after_sent(event.group_id) + # willing_manager.change_reply_willing_after_sent(event.group_id) # 创建全局ChatBot实例 chat_bot = ChatBot() \ No newline at end of file diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index be599f48a..fd65c116d 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -1,12 +1,9 @@ -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, Set import os -import configparser -import tomli -import sys -from loguru import logger -from nonebot import get_driver +from dataclasses import dataclass, field +from typing import Dict, Optional +import tomli +from loguru import logger @dataclass @@ -24,43 +21,61 @@ class BotConfig: talk_allowed_groups = set() talk_frequency_down_groups = set() + thinking_timeout: int = 100 # 思考时间 + + 
response_willing_amplifier: float = 1.0 # 回复意愿放大系数 + response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数 + down_frequency_rate: float = 3.5 # 降低回复频率的群组回复意愿降低系数 + ban_user_id = set() build_memory_interval: int = 30 # 记忆构建间隔(秒) forget_memory_interval: int = 300 # 记忆遗忘间隔(秒) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) + EMOJI_SAVE: bool = True # 偷表情包 + EMOJI_CHECK: bool = False #是否开启过滤 + EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 ban_words = set() + + max_response_length: int = 1024 # 最大回复长度 # 模型配置 llm_reasoning: Dict[str, str] = field(default_factory=lambda: {}) llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {}) llm_normal: Dict[str, str] = field(default_factory=lambda: {}) llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {}) + llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {}) + llm_summary_by_topic: Dict[str, str] = field(default_factory=lambda: {}) + llm_emotion_judge: Dict[str, str] = field(default_factory=lambda: {}) embedding: Dict[str, str] = field(default_factory=lambda: {}) vlm: Dict[str, str] = field(default_factory=lambda: {}) + moderation: Dict[str, str] = field(default_factory=lambda: {}) - # 主题提取配置 - topic_extract: str = 'snownlp' # 只支持jieba,snownlp,llm - llm_topic_extract: Dict[str, str] = field(default_factory=lambda: {}) - - API_USING: str = "siliconflow" # 使用的API - API_PAID: bool = False # 是否使用付费API MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率 MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率 MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率 enable_advance_output: bool = False # 是否启用高级输出 enable_kuuki_read: bool = True # 是否启用读空气功能 + + mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒 + mood_decay_rate: float = 0.95 # 情绪衰减率 + mood_intensity_factor: float = 0.7 # 情绪强度因子 # 默认人设 PROMPT_PERSONALITY=[ "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", - "是一个女大学生,你有黑色头发,你会刷小红书" + "是一个女大学生,你有黑色头发,你会刷小红书", + "是一个女大学生,你会刷b站,对ACG文化感兴趣" ] 
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + PERSONALITY_1: float = 0.6 # 第一种人格概率 + PERSONALITY_2: float = 0.3 # 第二种人格概率 + PERSONALITY_3: float = 0.1 # 第三种人格概率 + @staticmethod def get_config_dir() -> str: """获取配置文件目录""" @@ -78,7 +93,11 @@ class BotConfig: config = cls() if os.path.exists(config_path): with open(config_path, "rb") as f: - toml_dict = tomli.load(f) + try: + toml_dict = tomli.load(f) + except(tomli.TOMLDecodeError) as e: + logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") + exit(1) if 'personality' in toml_dict: personality_config=toml_dict['personality'] @@ -88,11 +107,17 @@ class BotConfig: config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY) logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}") config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN) + config.PERSONALITY_1=personality_config.get('personality_1_probability',config.PERSONALITY_1) + config.PERSONALITY_2=personality_config.get('personality_2_probability',config.PERSONALITY_2) + config.PERSONALITY_3=personality_config.get('personality_3_probability',config.PERSONALITY_3) if "emoji" in toml_dict: emoji_config = toml_dict["emoji"] config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) + config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT) + config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE) + config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK) if "cq_code" in toml_dict: cq_code_config = toml_dict["cq_code"] @@ -110,8 +135,7 @@ class BotConfig: config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY) config.MODEL_V3_PROBABILITY = 
response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY) config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY) - config.API_USING = response_config.get("api_using", config.API_USING) - config.API_PAID = response_config.get("api_paid", config.API_PAID) + config.max_response_length = response_config.get("max_response_length", config.max_response_length) # 加载模型配置 if "model" in toml_dict: @@ -125,10 +149,18 @@ class BotConfig: if "llm_normal" in model_config: config.llm_normal = model_config["llm_normal"] - config.llm_topic_extract = config.llm_normal if "llm_normal_minor" in model_config: config.llm_normal_minor = model_config["llm_normal_minor"] + + if "llm_topic_judge" in model_config: + config.llm_topic_judge = model_config["llm_topic_judge"] + + if "llm_summary_by_topic" in model_config: + config.llm_summary_by_topic = model_config["llm_summary_by_topic"] + + if "llm_emotion_judge" in model_config: + config.llm_emotion_judge = model_config["llm_emotion_judge"] if "vlm" in model_config: config.vlm = model_config["vlm"] @@ -136,14 +168,8 @@ class BotConfig: if "embedding" in model_config: config.embedding = model_config["embedding"] - if 'topic' in toml_dict: - topic_config=toml_dict['topic'] - if 'topic_extract' in topic_config: - config.topic_extract=topic_config.get('topic_extract',config.topic_extract) - logger.info(f"载入自定义主题提取为{config.topic_extract}") - if config.topic_extract=='llm' and 'llm_topic' in topic_config: - config.llm_topic_extract=topic_config['llm_topic'] - logger.info(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}") + if "moderation" in model_config: + config.moderation = model_config["moderation"] # 消息配置 if "message" in toml_dict: @@ -152,11 +178,21 @@ class BotConfig: config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE) config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance) 
config.ban_words=msg_config.get("ban_words",config.ban_words) + config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout) + config.response_willing_amplifier = msg_config.get("response_willing_amplifier", config.response_willing_amplifier) + config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier", config.response_interested_rate_amplifier) + config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate) if "memory" in toml_dict: memory_config = toml_dict["memory"] config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval) config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval) + + if "mood" in toml_dict: + mood_config = toml_dict["mood"] + config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval) + config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate) + config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor) # 群组配置 if "groups" in toml_dict: @@ -178,13 +214,13 @@ class BotConfig: bot_config_floder_path = BotConfig.get_config_dir() print(f"正在品鉴配置文件目录: {bot_config_floder_path}") -bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml") -if not os.path.exists(bot_config_path): +bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") +if os.path.exists(bot_config_path): # 如果开发环境配置文件不存在,则使用默认配置文件 - bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") + print(f"异常的新鲜,异常的美味: {bot_config_path}") logger.info("使用bot配置文件") else: - logger.info("已找到开发bot配置文件") + logger.info("没有找到美味") global_config = BotConfig.load_config(config_path=bot_config_path) diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 4d70736cd..4a295e3d5 100644 --- a/src/plugins/chat/cq_code.py +++ 
b/src/plugins/chat/cq_code.py @@ -1,23 +1,23 @@ -from dataclasses import dataclass -from typing import Dict, Optional, List, Union -import html -import requests import base64 -from PIL import Image +import html import os -from random import random -from nonebot.adapters.onebot.v11 import Bot -from .config import global_config import time -import asyncio -from .utils_image import storage_image,storage_emoji -from .utils_user import get_user_nickname -from ..models.utils_model import LLM_request -#解析各种CQ码 -#包含CQ码类 +from dataclasses import dataclass +from typing import Dict, Optional + +import requests + +# 解析各种CQ码 +# 包含CQ码类 import urllib3 -from urllib3.util import create_urllib3_context from nonebot import get_driver +from urllib3.util import create_urllib3_context + +from ..models.utils_model import LLM_request +from .config import global_config +from .mapper import emojimapper +from .utils_image import storage_emoji, storage_image +from .utils_user import get_user_nickname driver = get_driver() config = driver.config @@ -27,6 +27,7 @@ ctx = create_urllib3_context() ctx.load_default_certs() ctx.set_ciphers("AES128-GCM-SHA256") + class TencentSSLAdapter(requests.adapters.HTTPAdapter): def __init__(self, ssl_context=None, **kwargs): self.ssl_context = ssl_context @@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter): num_pools=connections, maxsize=maxsize, block=block, ssl_context=self.ssl_context) + @dataclass class CQCode: """ @@ -64,29 +66,29 @@ class CQCode: """初始化LLM实例""" self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) - def translate(self): + async def translate(self): """根据CQ码类型进行相应的翻译处理""" if self.type == 'text': self.translated_plain_text = self.params.get('text', '') elif self.type == 'image': if self.params.get('sub_type') == '0': - self.translated_plain_text = self.translate_image() + self.translated_plain_text = await self.translate_image() else: - self.translated_plain_text = self.translate_emoji() + 
self.translated_plain_text = await self.translate_emoji() elif self.type == 'at': user_nickname = get_user_nickname(self.params.get('qq', '')) if user_nickname: self.translated_plain_text = f"[@{user_nickname}]" else: - self.translated_plain_text = f"@某人" + self.translated_plain_text = "@某人" elif self.type == 'reply': - self.translated_plain_text = self.translate_reply() + self.translated_plain_text = await self.translate_reply() elif self.type == 'face': face_id = self.params.get('id', '') # self.translated_plain_text = f"[表情{face_id}]" - self.translated_plain_text = f"[表情]" + self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]" elif self.type == 'forward': - self.translated_plain_text = self.translate_forward() + self.translated_plain_text = await self.translate_forward() else: self.translated_plain_text = f"[{self.type}]" @@ -133,7 +135,7 @@ class CQCode: # 腾讯服务器特殊状态码处理 if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: return None - + if response.status_code != 200: raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") @@ -157,8 +159,8 @@ class CQCode: return None return None - - def translate_emoji(self) -> str: + + async def translate_emoji(self) -> str: """处理表情包类型的CQ码""" if 'url' not in self.params: return '[表情包]' @@ -167,50 +169,51 @@ class CQCode: # 将 base64 字符串转换为字节类型 image_bytes = base64.b64decode(base64_str) storage_emoji(image_bytes) - return self.get_emoji_description(base64_str) + return await self.get_emoji_description(base64_str) else: return '[表情包]' - - - def translate_image(self) -> str: + + async def translate_image(self) -> str: """处理图片类型的CQ码,区分普通图片和表情包""" - #没有url,直接返回默认文本 + # 没有url,直接返回默认文本 if 'url' not in self.params: return '[图片]' base64_str = self.get_img() if base64_str: image_bytes = base64.b64decode(base64_str) storage_image(image_bytes) - return self.get_image_description(base64_str) + return await self.get_image_description(base64_str) else: return '[图片]' - def 
get_emoji_description(self, image_base64: str) -> str: + async def get_emoji_description(self, image_base64: str) -> str: """调用AI接口获取表情包描述""" try: prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[表情包:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[表情包]" - def get_image_description(self, image_base64: str) -> str: + async def get_image_description(self, image_base64: str) -> str: """调用AI接口获取普通图片描述""" try: prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[图片:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[图片]" - - def translate_forward(self) -> str: + + async def translate_forward(self) -> str: """处理转发消息""" try: if 'content' not in self.params: return '[转发消息]' - + # 解析content内容(需要先反转义) content = self.unescape(self.params['content']) # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") @@ -221,17 +224,17 @@ class CQCode: except ValueError as e: print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") return '[转发消息]' - + # 处理每条消息 formatted_messages = [] for msg in messages: sender = msg.get('sender', {}) nickname = sender.get('card') or sender.get('nickname', '未知用户') - + # 获取消息内容并使用Message类处理 raw_message = msg.get('raw_message', '') message_array = msg.get('message', []) - + if message_array and isinstance(message_array, list): # 检查是否包含嵌套的转发消息 for message_part in message_array: @@ -249,6 +252,7 @@ class CQCode: plain_text=raw_message, 
group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' @@ -263,23 +267,24 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' - + formatted_msg = f"{nickname}: {content}" formatted_messages.append(formatted_msg) - + # 合并所有消息 combined_messages = '\n'.join(formatted_messages) print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") return f"[转发消息:\n{combined_messages}]" - + except Exception as e: print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") return '[转发消息]' - def translate_reply(self) -> str: + async def translate_reply(self) -> str: """处理回复类型的CQ码""" # 创建Message对象 @@ -287,7 +292,7 @@ class CQCode: if self.reply_message == None: # print(f"\033[1;31m[错误]\033[0m 回复消息为空") return '[回复某人消息]' - + if self.reply_message.sender.user_id: message_obj = Message( user_id=self.reply_message.sender.user_id, @@ -295,22 +300,23 @@ class CQCode: raw_message=str(self.reply_message.message), group_id=self.group_id ) + await message_obj.initialize() if message_obj.user_id == global_config.BOT_QQ: return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" else: return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]" else: - print(f"\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空") + print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空") return '[回复某人消息]' @staticmethod def unescape(text: str) -> str: """反转义CQ码中的特殊字符""" return text.replace(',', ',') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace('&', '&') + .replace('[', '[') \ + .replace(']', ']') \ + .replace('&', '&') @staticmethod def create_emoji_cq(file_path: str) -> str: @@ -325,15 +331,16 @@ class CQCode: abs_path = os.path.abspath(file_path) # 转义特殊字符 escaped_path = abs_path.replace('&', '&') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace(',', 
',') + .replace('[', '[') \ + .replace(']', ']') \ + .replace(',', ',') # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" - + + class CQCode_tool: @staticmethod - def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: + async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: """ 将CQ码字典转换为CQCode对象 @@ -352,7 +359,7 @@ class CQCode_tool: params['text'] = cq_code.get('data', {}).get('text', '') else: params = cq_code.get('data', {}) - + instance = CQCode( type=cq_type, params=params, @@ -360,11 +367,11 @@ class CQCode_tool: user_id=0, reply_message=reply ) - + # 进行翻译处理 - instance.translate() + await instance.translate() return instance - + @staticmethod def create_reply_cq(message_id: int) -> str: """ @@ -375,6 +382,6 @@ class CQCode_tool: 回复CQ码字符串 """ return f"[CQ:reply,id={message_id}]" - - + + cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 2311b2459..4f2637738 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -1,22 +1,17 @@ -from typing import List, Dict, Optional -import random -from ...common.database import Database -import os -import json -from dataclasses import dataclass -import jieba.analyse as jieba_analyse -import aiohttp -import hashlib -from datetime import datetime -import base64 -import shutil import asyncio +import os +import random import time -from PIL import Image -import io +import traceback +from typing import Optional +from loguru import logger from nonebot import get_driver + +from ...common.database import Database from ..chat.config import global_config +from ..chat.utils import get_embedding +from ..chat.utils_image import image_path_to_base64 from ..models.utils_model import LLM_request driver = get_driver() @@ -27,16 +22,6 @@ class EmojiManager: _instance = None EMOJI_DIR = "data/emoji" # 表情包存储目录 - EMOTION_KEYWORDS = { - 'happy': ['开心', 
'快乐', '高兴', '欢喜', '笑', '喜悦', '兴奋', '愉快', '乐', '好'], - 'angry': ['生气', '愤怒', '恼火', '不爽', '火大', '怒', '气愤', '恼怒', '发火', '不满'], - 'sad': ['伤心', '难过', '悲伤', '痛苦', '哭', '忧伤', '悲痛', '哀伤', '委屈', '失落'], - 'surprised': ['惊讶', '震惊', '吃惊', '意外', '惊', '诧异', '惊奇', '惊喜', '不敢相信', '目瞪口呆'], - 'disgusted': ['恶心', '讨厌', '厌恶', '反感', '嫌弃', '恶', '嫌恶', '憎恶', '不喜欢', '烦'], - 'fearful': ['害怕', '恐惧', '惊恐', '担心', '怕', '惊吓', '惊慌', '畏惧', '胆怯', '惧'], - 'neutral': ['普通', '一般', '还行', '正常', '平静', '平淡', '一般般', '凑合', '还好', '就这样'] - } - def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) @@ -47,7 +32,8 @@ class EmojiManager: def __init__(self): self.db = Database.get_instance() self._scan_task = None - self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50) + self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) + self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,temperature=0.8) #更高的温度,更少的token(后续可以根据情绪来调整温度) def _ensure_emoji_dir(self): """确保表情存储目录存在""" @@ -64,7 +50,7 @@ class EmojiManager: # 启动时执行一次完整性检查 self.check_emoji_file_integrity() except Exception as e: - print(f"\033[1;31m[错误]\033[0m 初始化表情管理器失败: {str(e)}") + logger.error(f"初始化表情管理器失败: {str(e)}") def _ensure_db(self): """确保数据库已初始化""" @@ -74,9 +60,20 @@ class EmojiManager: raise RuntimeError("EmojiManager not initialized") def _ensure_emoji_collection(self): - """确保emoji集合存在并创建索引""" + """确保emoji集合存在并创建索引 + + 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 + + 索引的作用是加快数据库查询速度: + - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 + - tags字段的普通索引: 加快按标签搜索表情包的速度 + - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 + + 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 + """ if 'emoji' not in self.db.db.list_collection_names(): self.db.db.create_collection('emoji') + self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('tags', 1)]) self.db.db.emoji.create_index([('filename', 1)], unique=True) @@ -89,228 +86,128 @@ class 
EmojiManager: {'$inc': {'usage_count': 1}} ) except Exception as e: - print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}") + logger.error(f"记录表情使用失败: {str(e)}") - async def _get_emotion_from_text(self, text: str) -> List[str]: - """从文本中识别情感关键词 - Args: - text: 输入文本 - Returns: - List[str]: 匹配到的情感标签列表 - """ - try: - prompt = f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。' - - content, _ = await self.llm.generate_response(prompt) - emotion = content.strip().lower() - - if emotion in self.EMOTION_KEYWORDS: - print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}") - return [emotion] - - return ['neutral'] - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}") - return ['neutral'] - - async def get_emoji_for_emotion(self, emotion_tag: str) -> Optional[str]: - try: - self._ensure_db() - - # 构建查询条件:标签匹配任一情感 - query = {'tags': {'$in': emotion_tag}} - - # print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") - - try: - # 随机获取一个匹配的表情 - emoji = self.db.db.emoji.aggregate([ - {'$match': query}, - {'$sample': {'size': 1}} - ]).next() - print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - # 如果没有匹配的表情,从所有表情中随机选择一个 - print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") - try: - emoji = self.db.db.emoji.aggregate([ - {'$sample': {'size': 1}} - ]).next() - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") - return None - - return None - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") - return None - - async def get_emoji_for_text(self, text: str) -> Optional[str]: """根据文本内容获取相关表情包 Args: text: 输入文本 Returns: Optional[str]: 表情包文件路径,如果没有找到则返回None + + + 可不可以通过 
配置文件中的指令 来自定义使用表情包的逻辑? + 我觉得可行 + """ try: self._ensure_db() - # 获取情感标签 - emotions = await self._get_emotion_from_text(text) - print("为 ‘"+ str(text) + "’ 获取到的情感标签为:" + str(emotions)) - if not emotions: - return None - - # 构建查询条件:标签匹配任一情感 - query = {'tags': {'$in': emotions}} - print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") - print(f"\033[1;34m[调试]\033[0m 匹配到的情感: {emotions}") + # 获取文本的embedding + text_for_search= await self._get_kimoji_for_text(text) + if not text_for_search: + logger.error("无法获取文本的情绪") + return None + text_embedding = await get_embedding(text_for_search) + if not text_embedding: + logger.error("无法获取文本的embedding") + return None try: - # 随机获取一个匹配的表情 - emoji = self.db.db.emoji.aggregate([ - {'$match': query}, - {'$sample': {'size': 1}} - ]).next() - print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") - if emoji and 'path' in emoji: + # 获取所有表情包 + all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1})) + + if not all_emojis: + logger.warning("数据库中没有任何表情包") + return None + + # 计算余弦相似度并排序 + def cosine_similarity(v1, v2): + if not v1 or not v2: + return 0 + dot_product = sum(a * b for a, b in zip(v1, v2)) + norm_v1 = sum(a * a for a in v1) ** 0.5 + norm_v2 = sum(b * b for b in v2) ** 0.5 + if norm_v1 == 0 or norm_v2 == 0: + return 0 + return dot_product / (norm_v1 * norm_v2) + + # 计算所有表情包与输入文本的相似度 + emoji_similarities = [ + (emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) + for emoji in all_emojis + ] + + # 按相似度降序排序 + emoji_similarities.sort(key=lambda x: x[1], reverse=True) + + # 获取前3个最相似的表情包 + top_3_emojis = emoji_similarities[:3] + + if not top_3_emojis: + logger.warning("未找到匹配的表情包") + return None + + # 从前3个中随机选择一个 + selected_emoji, similarity = random.choice(top_3_emojis) + + if selected_emoji and 'path' in selected_emoji: # 更新使用次数 self.db.db.emoji.update_one( - {'_id': emoji['_id']}, + {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) - return emoji['path'] - except StopIteration: - 
# 如果没有匹配的表情,从所有表情中随机选择一个 - print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") - try: - emoji = self.db.db.emoji.aggregate([ - {'$sample': {'size': 1}} - ]).next() - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") - return None + logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})") + # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 + return selected_emoji['path'],"[ %s ]" % selected_emoji.get('discription', '无描述') + + except Exception as search_error: + logger.error(f"搜索表情包失败: {str(search_error)}") + return None return None except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") + logger.error(f"获取表情包失败: {str(e)}") return None - async def _get_emoji_tag(self, image_base64: str) -> str: + async def _get_emoji_discription(self, image_base64: str) -> str: """获取表情包的标签""" try: - prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好' + prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' - content, _ = await self.llm.generate_response_for_image(prompt, image_base64) - tag_result = content.strip().lower() - - valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"] - for tag_match in valid_tags: - if tag_match in tag_result or tag_match == tag_result: - return tag_match - print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过") + content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) + logger.debug(f"输出描述: {content}") + return content except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}") - return "skip" - - print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral") - return "skip" # 默认标签 - - async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]: - """压缩图片并返回base64编码 - Args: - 
image_path: 图片文件路径 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - Optional[str]: 成功返回base64编码的图片数据,失败返回None - """ - try: - file_size = os.path.getsize(image_path) - if file_size <= target_size: - # 如果文件已经小于目标大小,直接读取并返回base64 - with open(image_path, 'rb') as f: - return base64.b64encode(f.read()).decode('utf-8') - - # 打开图片 - with Image.open(image_path) as img: - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / file_size) ** 0.5) - - # 计算新的尺寸 - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - for frame_idx in range(img.n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) - frames.append(new_frame) - - # 保存到缓冲区 - frames[0].save( - output_buffer, - format='GIF', - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get('duration', 100), - loop=img.info.get('loop', 0) - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): - resized_img.save(output_buffer, format='PNG', optimize=True) - else: - resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - print(f"\033[1;32m[成功]\033[0m 压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})") - - return base64.b64encode(compressed_data).decode('utf-8') - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}") + logger.error(f"获取标签失败: {str(e)}") return None + + async def _check_emoji(self, image_base64: str) -> str: + try: + prompt = 
f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' + content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) + logger.debug(f"输出描述: {content}") + return content + + except Exception as e: + logger.error(f"获取标签失败: {str(e)}") + return None + + async def _get_kimoji_for_text(self, text:str): + try: + prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' + + content, _ = await self.llm_emotion_judge.generate_response_async(prompt) + logger.info(f"输出描述: {content}") + return content + + except Exception as e: + logger.error(f"获取标签失败: {str(e)}") + return None + async def scan_new_emojis(self): """扫描新的表情包""" try: @@ -329,41 +226,51 @@ class EmojiManager: continue # 压缩图片并获取base64编码 - image_base64 = await self._compress_image(image_path) + image_base64 = image_path_to_base64(image_path) if image_base64 is None: os.remove(image_path) continue - # 获取表情包的情感标签 - tag = await self._get_emoji_tag(image_base64) - if not tag == "skip": + # 获取表情包的描述 + discription = await self._get_emoji_discription(image_base64) + if global_config.EMOJI_CHECK: + check = await self._check_emoji(image_base64) + if '是' not in check: + os.remove(image_path) + logger.info(f"描述: {discription}") + logger.info(f"其不满足过滤规则,被剔除 {check}") + continue + logger.info(f"check通过 {check}") + embedding = await get_embedding(discription) + if discription is not None: # 准备数据库记录 emoji_record = { 'filename': filename, 'path': image_path, - 'tags': [tag], + 'embedding':embedding, + 'discription': discription, 'timestamp': int(time.time()) } # 保存到数据库 self.db.db['emoji'].insert_one(emoji_record) - print(f"\033[1;32m[成功]\033[0m 注册新表情包: {filename}") - print(f"标签: {tag}") + logger.success(f"注册新表情包: {filename}") + logger.info(f"描述: {discription}") else: - print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}") + logger.warning(f"跳过表情包: {filename}") except Exception as e: - 
print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}") - import traceback - print(traceback.format_exc()) - + logger.error(f"扫描表情包失败: {str(e)}") + logger.error(traceback.format_exc()) + async def _periodic_scan(self, interval_MINS: int = 10): """定期扫描新表情包""" while True: - print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...") + print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...") await self.scan_new_emojis() await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 + def check_emoji_file_integrity(self): """检查表情包文件完整性 如果文件已被删除,则从数据库中移除对应记录 @@ -378,44 +285,42 @@ class EmojiManager: for emoji in all_emojis: try: if 'path' not in emoji: - print(f"\033[1;33m[提示]\033[0m 发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") + logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") + self.db.db.emoji.delete_one({'_id': emoji['_id']}) + removed_count += 1 + continue + + if 'embedding' not in emoji: + logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") self.db.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue # 检查文件是否存在 if not os.path.exists(emoji['path']): - print(f"\033[1;33m[提示]\033[0m 表情包文件已被删除: {emoji['path']}") + logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) if result.deleted_count > 0: - print(f"\033[1;32m[成功]\033[0m 成功删除数据库记录: {emoji['_id']}") + logger.success(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 else: - print(f"\033[1;31m[错误]\033[0m 删除数据库记录失败: {emoji['_id']}") + logger.error(f"删除数据库记录失败: {emoji['_id']}") except Exception as item_error: - print(f"\033[1;31m[错误]\033[0m 处理表情包记录时出错: {str(item_error)}") + logger.error(f"处理表情包记录时出错: {str(item_error)}") continue # 验证清理结果 remaining_count = self.db.db.emoji.count_documents({}) if removed_count > 0: - print(f"\033[1;32m[成功]\033[0m 已清理 {removed_count} 个失效的表情包记录") - print(f"\033[1;34m[统计]\033[0m 清理前总数: {total_count} | 清理后总数: {remaining_count}") - # print(f"\033[1;34m[统计]\033[0m 应删除数量: {removed_count} | 实际删除数量: 
{total_count - remaining_count}") - # 执行数据库压缩 - try: - self.db.db.command({"compact": "emoji"}) - print(f"\033[1;32m[成功]\033[0m 数据库集合压缩完成") - except Exception as compact_error: - print(f"\033[1;31m[错误]\033[0m 数据库压缩失败: {str(compact_error)}") + logger.success(f"已清理 {removed_count} 个失效的表情包记录") + logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") else: - print(f"\033[1;36m[表情包]\033[0m 已检查 {total_count} 个表情包记录") + logger.info(f"已检查 {total_count} 个表情包记录") except Exception as e: - print(f"\033[1;31m[错误]\033[0m 检查表情包完整性失败: {str(e)}") - import traceback - print(f"\033[1;31m[错误追踪]\033[0m\n{traceback.format_exc()}") + logger.error(f"检查表情包完整性失败: {str(e)}") + logger.error(traceback.format_exc()) async def start_periodic_check(self, interval_MINS: int = 120): while True: diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 04f2e73ad..1ac421e6b 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -1,19 +1,16 @@ -from typing import Dict, Any, List, Optional, Union, Tuple -from openai import OpenAI -import asyncio -from functools import partial -from .message import Message -from .config import global_config -from ...common.database import Database import random import time -import numpy as np -from .relationship_manager import relationship_manager -from .prompt_builder import prompt_builder -from .config import global_config -from .utils import process_llm_response +from typing import List, Optional, Tuple, Union + from nonebot import get_driver + +from ...common.database import Database from ..models.utils_model import LLM_request +from .config import global_config +from .message import Message +from .prompt_builder import prompt_builder +from .relationship_manager import relationship_manager +from .utils import process_llm_response driver = get_driver() config = driver.config @@ -21,9 +18,10 @@ config = driver.config class ResponseGenerator: def __init__(self): - self.model_r1 = 
LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000) + self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000) self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000) + self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000) self.db = Database.get_instance() self.current_model_type = 'r1' # 默认使用 R1 @@ -44,19 +42,15 @@ class ResponseGenerator: print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") model_response = await self._generate_response_with_model(message, current_model) + raw_content=model_response if model_response: print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') - model_response, emotion = await self._process_response(model_response) + model_response = await self._process_response(model_response) if model_response: - print(f"为 '{model_response}' 获取到的情感标签为:{emotion}") - valuedict={ - 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25 - } - await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) - return model_response, emotion - return None, [] + return model_response ,raw_content + return None,raw_content async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]: """使用指定的模型生成回复""" @@ -67,10 +61,11 @@ class ResponseGenerator: # 获取关系值 relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0 if relationship_value != 0.0: - print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + # print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + pass # 构建prompt - prompt, 
prompt_check = prompt_builder._build_prompt( + prompt, prompt_check = await prompt_builder._build_prompt( message_txt=message.processed_plain_text, sender_name=sender_name, relationship_value=relationship_value, @@ -142,7 +137,7 @@ class ResponseGenerator: 内容:{content} 输出: ''' - content, _ = await self.model_v3.generate_response(prompt) + content, _ = await self.model_v25.generate_response(prompt) content=content.strip() if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']: return [content] @@ -158,10 +153,9 @@ class ResponseGenerator: if not content: return None, [] - emotion_tags = await self._get_emotion_tags(content) processed_response = process_llm_response(content) - return processed_response, emotion_tags + return processed_response class InitiativeMessageGenerate: @@ -197,6 +191,6 @@ class InitiativeMessageGenerate: prompt = prompt_builder._build_initiative_prompt( select_dot, prompt_template, memory ) - content, reasoning = self.model_r1.generate_response(prompt) + content, reasoning = self.model_r1.generate_response_async(prompt) print(f"[DEBUG] {content} {reasoning}") return content diff --git a/src/plugins/chat/mapper.py b/src/plugins/chat/mapper.py new file mode 100644 index 000000000..67fa801e2 --- /dev/null +++ b/src/plugins/chat/mapper.py @@ -0,0 +1,26 @@ +emojimapper = {5: "流泪", 311: "打 call", 312: "变形", 314: "仔细分析", 317: "菜汪", 318: "崇拜", 319: "比心", + 320: "庆祝", 324: "吃糖", 325: "惊吓", 337: "花朵脸", 338: "我想开了", 339: "舔屏", 341: "打招呼", + 342: "酸Q", 343: "我方了", 344: "大怨种", 345: "红包多多", 346: "你真棒棒", 181: "戳一戳", 74: "太阳", + 75: "月亮", 351: "敲敲", 349: "坚强", 350: "贴贴", 395: "略略略", 114: "篮球", 326: "生气", 53: "蛋糕", + 137: "鞭炮", 333: "烟花", 424: "续标识", 415: "划龙舟", 392: "龙年快乐", 425: "求放过", 427: "偷感", + 426: "玩火", 419: "火车", 429: "蛇年快乐", + 14: "微笑", 1: "撇嘴", 2: "色", 3: "发呆", 4: "得意", 6: "害羞", 7: "闭嘴", 8: "睡", 9: "大哭", + 10: "尴尬", 11: "发怒", 12: "调皮", 13: "呲牙", 0: "惊讶", 15: "难过", 16: "酷", 96: "冷汗", 18: "抓狂", + 19: "吐", 20: "偷笑", 21: "可爱", 
22: "白眼", 23: "傲慢", 24: "饥饿", 25: "困", 26: "惊恐", 27: "流汗", + 28: "憨笑", 29: "悠闲", 30: "奋斗", 31: "咒骂", 32: "疑问", 33: "嘘", 34: "晕", 35: "折磨", 36: "衰", + 37: "骷髅", 38: "敲打", 39: "再见", 97: "擦汗", 98: "抠鼻", 99: "鼓掌", 100: "糗大了", 101: "坏笑", + 102: "左哼哼", 103: "右哼哼", 104: "哈欠", 105: "鄙视", 106: "委屈", 107: "快哭了", 108: "阴险", + 305: "右亲亲", 109: "左亲亲", 110: "吓", 111: "可怜", 172: "眨眼睛", 182: "笑哭", 179: "doge", + 173: "泪奔", 174: "无奈", 212: "托腮", 175: "卖萌", 178: "斜眼笑", 177: "喷血", 176: "小纠结", + 183: "我最美", 262: "脑阔疼", 263: "沧桑", 264: "捂脸", 265: "辣眼睛", 266: "哦哟", 267: "头秃", + 268: "问号脸", 269: "暗中观察", 270: "emm", 271: "吃瓜", 272: "呵呵哒", 277: "汪汪", 307: "喵喵", + 306: "牛气冲天", 281: "无眼笑", 282: "敬礼", 283: "狂笑", 284: "面无表情", 285: "摸鱼", 293: "摸锦鲤", + 286: "魔鬼笑", 287: "哦", 289: "睁眼", 294: "期待", 297: "拜谢", 298: "元宝", 299: "牛啊", 300: "胖三斤", + 323: "嫌弃", 332: "举牌牌", 336: "豹富", 353: "拜托", 355: "耶", 356: "666", 354: "尊嘟假嘟", 352: "咦", + 357: "裂开", 334: "虎虎生威", 347: "大展宏兔", 303: "右拜年", 302: "左拜年", 295: "拿到红包", 49: "拥抱", + 66: "爱心", 63: "玫瑰", 64: "凋谢", 187: "幽灵", 146: "爆筋", 116: "示爱", 67: "心碎", 60: "咖啡", + 185: "羊驼", 76: "赞", 124: "OK", 118: "抱拳", 78: "握手", 119: "勾引", 79: "胜利", 120: "拳头", + 121: "差劲", 77: "踩", 123: "NO", 201: "点赞", 273: "我酸了", 46: "猪头", 112: "菜刀", 56: "刀", + 169: "手枪", 171: "茶", 59: "便便", 144: "喝彩", 147: "棒棒糖", 89: "西瓜", 41: "发抖", 125: "转圈", + 42: "爱情", 43: "跳跳", 86: "怄火", 129: "挥手", 85: "飞吻", 428: "收到", + 423: "复兴号", 432: "灵蛇献瑞"} diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index d6e400e15..f1fc5569d 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -1,16 +1,12 @@ -from dataclasses import dataclass -from typing import List, Optional, Dict, Tuple, ForwardRef import time -import jieba.analyse as jieba_analyse -import os -from datetime import datetime -from ...common.database import Database -from PIL import Image -from .config import global_config +from dataclasses import dataclass +from typing import Dict, ForwardRef, List, Optional 
+ import urllib3 -from .utils_user import get_user_nickname,get_user_cardname,get_groupname + +from .cq_code import CQCode, cq_code_tool from .utils_cq import parse_cq_code -from .cq_code import cq_code_tool,CQCode +from .utils_user import get_groupname, get_user_cardname, get_user_nickname Message = ForwardRef('Message') # 添加这行 # 禁用SSL警告 @@ -27,58 +23,66 @@ class Message: """消息数据类""" message_id: int = None time: float = None - + group_id: int = None - group_name: str = None # 群名称 - + group_name: str = None # 群名称 + user_id: int = None user_nickname: str = None # 用户昵称 - user_cardname: str=None # 用户群昵称 - - raw_message: str = None # 原始消息,包含未解析的cq码 - plain_text: str = None # 纯文本 - + user_cardname: str = None # 用户群昵称 + + raw_message: str = None # 原始消息,包含未解析的cq码 + plain_text: str = None # 纯文本 + + reply_message: Dict = None # 存储 回复的 源消息 + + # 延迟初始化字段 + _initialized: bool = False message_segments: List[Dict] = None # 存储解析后的消息片段 processed_plain_text: str = None # 用于存储处理后的plain_text detailed_plain_text: str = None # 用于存储详细可读文本 - - reply_message: Dict = None # 存储 回复的 源消息 - - is_emoji: bool = False # 是否是表情包 - has_emoji: bool = False # 是否包含表情包 - - translate_cq: bool = True # 是否翻译cq码 - - def __post_init__(self): - if self.time is None: - self.time = int(time.time()) - - if not self.group_name: - self.group_name = get_groupname(self.group_id) - - if not self.user_nickname: - self.user_nickname = get_user_nickname(self.user_id) - - if not self.user_cardname: - self.user_cardname=get_user_cardname(self.user_id) - - if not self.processed_plain_text: - if self.raw_message: - self.message_segments = self.parse_message_segments(str(self.raw_message)) + + # 状态标志 + is_emoji: bool = False + has_emoji: bool = False + translate_cq: bool = True + + async def initialize(self): + """显式异步初始化方法(必须调用)""" + if self._initialized: + return + + # 异步获取补充信息 + self.group_name = self.group_name or get_groupname(self.group_id) + self.user_nickname = self.user_nickname or get_user_nickname(self.user_id) + 
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id) + + # 消息解析 + if self.raw_message: + if not isinstance(self,Message_Sending): + self.message_segments = await self.parse_message_segments(self.raw_message) self.processed_plain_text = ' '.join( seg.translated_plain_text for seg in self.message_segments ) - #将详细翻译为详细可读文本 + + # 构建详细文本 + if self.time is None: + self.time = int(time.time()) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) - try: - name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" - except: - name = self.user_nickname or f"用户{self.user_id}" - content = self.processed_plain_text - self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" + name = ( + f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" + if self.user_cardname + else f"{self.user_nickname or f'用户{self.user_id}'}" + ) + if isinstance(self,Message_Sending) and self.is_emoji: + self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n" + else: + self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n" + + self._initialized = True - def parse_message_segments(self, message: str) -> List[CQCode]: + async def parse_message_segments(self, message: str) -> List[CQCode]: """ 将消息解析为片段列表,包括纯文本和CQ码 返回的列表中每个元素都是字典,包含: @@ -136,7 +140,7 @@ class Message: #翻译作为字典的CQ码 for _code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) + message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) trans_list.append(message_obj) return trans_list @@ -169,6 +173,8 @@ class Message_Sending(Message): reply_message_id: int = None # 存储 回复的 源消息ID + is_head: bool = False # 是否是头部消息 + def update_thinking_time(self): self.thinking_time = round(time.time(), 2) - self.thinking_start_time return self.thinking_time diff --git a/src/plugins/chat/message_sender.py 
b/src/plugins/chat/message_sender.py index 970fd3682..050c59d74 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -1,14 +1,15 @@ -from typing import Union, List, Optional, Dict -from collections import deque -from .message import Message, Message_Thinking, MessageSet, Message_Sending -import time import asyncio +import time +from typing import Dict, List, Optional, Union + from nonebot.adapters.onebot.v11 import Bot -from .config import global_config -from .storage import MessageStorage + from .cq_code import cq_code_tool -import random +from .message import Message, Message_Sending, Message_Thinking, MessageSet +from .storage import MessageStorage from .utils import calculate_typing_time +from .config import global_config + class Message_Sender: """发送器""" @@ -103,7 +104,7 @@ class MessageContainer: def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: """添加消息到队列""" - print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") + # print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") if isinstance(message, MessageSet): for single_message in message.messages: self.messages.append(single_message) @@ -156,26 +157,25 @@ class MessageManager: #最早的对象,可能是思考消息,也可能是发送消息 message_earliest = container.get_earliest_message() #一个message_thinking or message_sending - #一个月后删了 - if not message_earliest: - print(f"\033[1;34m[BUG,如果出现这个,说明有BUG,3月4日留]\033[0m ") - return - #如果是思考消息 if isinstance(message_earliest, Message_Thinking): #优先等待这条消息 message_earliest.update_thinking_time() thinking_time = message_earliest.thinking_time - print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒") + print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒\033[K\r", end='', flush=True) + + # 检查是否超时 + if thinking_time > global_config.thinking_timeout: + print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息") + container.remove_message(message_earliest) else:# 如果不是message_thinking就只能是message_sending print(f"\033[1;34m[调试]\033[0m 
消息'{message_earliest.processed_plain_text}'正在发送中") #直接发,等什么呢 - if message_earliest.update_thinking_time() < 30: - await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) - else: + if message_earliest.is_head and message_earliest.update_thinking_time() >30: await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id) - - #移除消息 + else: + await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) + #移除消息 if message_earliest.is_emoji: message_earliest.processed_plain_text = "[表情包]" await self.storage.store_message(message_earliest, None) @@ -192,10 +192,11 @@ class MessageManager: try: #发送 - if msg.update_thinking_time() < 30: - await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) - else: + if msg.is_head and msg.update_thinking_time() >30: await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) + else: + await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) + #如果是表情包,则替换为"[表情包]" if msg.is_emoji: @@ -204,7 +205,7 @@ class MessageManager: # 安全地移除消息 if not container.remove_message(msg): - print(f"\033[1;33m[警告]\033[0m 尝试删除不存在的消息") + print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") except Exception as e: print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}") continue diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index ba22a403d..e337cef45 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -1,13 +1,13 @@ -import time import random -from ..schedule.schedule_generator import bot_schedule -import os -from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text +import time +from typing import Optional + from ...common.database 
import Database +from ..memory_system.memory import hippocampus, memory_graph +from ..moods.moods import MoodManager +from ..schedule.schedule_generator import bot_schedule from .config import global_config -from .topic_identifier import topic_identifier -from ..memory_system.memory import memory_graph -from random import choice +from .utils import get_embedding, get_recent_group_detailed_plain_text class PromptBuilder: @@ -16,11 +16,13 @@ class PromptBuilder: self.activate_messages = '' self.db = Database.get_instance() - def _build_prompt(self, + + + async def _build_prompt(self, message_txt: str, sender_name: str = "某人", relationship_value: float = 0.0, - group_id: int = None) -> str: + group_id: Optional[int] = None) -> tuple[str, str]: """构建prompt Args: @@ -31,60 +33,7 @@ class PromptBuilder: Returns: str: 构建好的prompt - """ - - - memory_prompt = '' - start_time = time.time() # 记录开始时间 - # topic = await topic_identifier.identify_topic_llm(message_txt) - topic = topic_identifier.identify_topic_snownlp(message_txt) - - # print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}") - - all_first_layer_items = [] # 存储所有第一层记忆 - all_second_layer_items = {} # 用字典存储每个topic的第二层记忆 - overlapping_second_layer = set() # 存储重叠的第二层记忆 - - if topic: - # 遍历所有topic - for current_topic in topic: - first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) - # if first_layer_items: - # print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") - - # 记录第一层数据 - all_first_layer_items.extend(first_layer_items) - - # 记录第二层数据 - all_second_layer_items[current_topic] = second_layer_items - - # 检查是否有重叠的第二层数据 - for other_topic, other_second_layer in all_second_layer_items.items(): - if other_topic != current_topic: - # 找到重叠的记忆 - overlap = set(second_layer_items) & set(other_second_layer) - if overlap: - # print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}") - overlapping_second_layer.update(overlap) - - 
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else [] - selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else [] - - # 合并并去重 - all_memories = list(set(selected_first_layer + selected_second_layer)) - if all_memories: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}") - random_item = " ".join(all_memories) - memory_prompt = f"看到这些聊天,你想起来{random_item}\n" - else: - memory_prompt = "" # 如果没有记忆,则返回空字符串 - - end_time = time.time() # 记录结束时间 - print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") # 输出耗时 - - - - + """ #先禁用关系 if 0 > 30: relation_prompt = "关系特别特别好,你很喜欢喜欢他" @@ -98,6 +47,12 @@ class PromptBuilder: #开始构建prompt + + #心情 + mood_manager = MoodManager.get_instance() + mood_prompt = mood_manager.get_prompt() + + #日程构建 current_date = time.strftime("%Y-%m-%d", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime()) @@ -109,49 +64,83 @@ class PromptBuilder: prompt_info = '' promt_info_prompt = '' - prompt_info = self.get_prompt_info(message_txt,threshold=0.5) + prompt_info = await self.get_prompt_info(message_txt,threshold=0.5) if prompt_info: prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' - # promt_info_prompt = '你有一些[知识],在上面可以参考。' end_time = time.time() print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒") - # print(f"\033[1;34m[调试]\033[0m 获取知识库内容结果: {prompt_info}") - - # print(f"\033[1;34m[调试信息]\033[0m 正在构建聊天上下文") - + # 获取聊天上下文 chat_talking_prompt = '' if group_id: chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" - # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 
的消息记录:{chat_talking_prompt}") + + + # 使用新的记忆获取方法 + memory_prompt = '' + start_time = time.time() + + # 调用 hippocampus 的 get_relevant_memories 方法 + relevant_memories = await hippocampus.get_relevant_memories( + text=message_txt, + max_topics=5, + similarity_threshold=0.4, + max_memory_num=5 + ) + + if relevant_memories: + # 格式化记忆内容 + memory_items = [] + for memory in relevant_memories: + memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}") + + memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n" + + # 打印调试信息 + print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:") + for memory in relevant_memories: + print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}") + + end_time = time.time() + print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") + + + #激活prompt构建 activate_prompt = '' - activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。" + activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。" #检测机器人相关词汇 bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords) if is_bot: - is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。' + is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认' else: is_bot_prompt = '' #人格选择 personality=global_config.PROMPT_PERSONALITY + probability_1 = global_config.PERSONALITY_1 + probability_2 = global_config.PERSONALITY_2 + probability_3 = global_config.PERSONALITY_3 prompt_personality = '' personality_choice = random.random() - if personality_choice < 4/6: # 第一种人格 + if personality_choice < probability_1: # 第一种人格 prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt} 
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。''' - elif personality_choice < 1: # 第二种人格 + elif personality_choice < probability_1 + probability_2: # 第二种人格 prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt} 请你表达自己的见解和观点。可以有个性。''' + else: # 第三种人格 + prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}, 你正在浏览qq群,{promt_info_prompt}, + 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt} + 请你表达自己的见解和观点。可以有个性。''' #中文高手(新加的好玩功能) prompt_ger = '' @@ -162,36 +151,28 @@ class PromptBuilder: if random.random() < 0.01: prompt_ger += '你喜欢用文言文' - #额外信息要求 - extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' - - + extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' #合并prompt prompt = "" prompt += f"{prompt_info}\n" prompt += f"{prompt_date}\n" prompt += f"{chat_talking_prompt}\n" - - # prompt += f"{memory_prompt}\n" - - # prompt += f"{activate_prompt}\n" prompt += f"{prompt_personality}\n" prompt += f"{prompt_ger}\n" prompt += f"{extra_info}\n" - - '''读空气prompt处理''' - activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" prompt_personality_check = '' extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。" - if personality_choice < 4/6: # 第一种人格 + if personality_choice < probability_1: # 第一种人格 prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' - elif personality_choice < 1: # 第二种人格 + elif personality_choice < probability_1 + 
probability_2: # 第二种人格 prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + else: # 第三种人格 + prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[2]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" @@ -219,14 +200,16 @@ class PromptBuilder: #激活prompt构建 activate_prompt = '' - activate_prompt = f"以上是群里正在进行的聊天。" + activate_prompt = "以上是群里正在进行的聊天。" personality=global_config.PROMPT_PERSONALITY prompt_personality = '' personality_choice = random.random() - if personality_choice < 4/6: # 第一种人格 + if personality_choice < probability_1: # 第一种人格 prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}''' - elif personality_choice < 1: # 第二种人格 + elif personality_choice < probability_1 + probability_2: # 第二种人格 prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}''' + else: # 第三种人格 + prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}''' topics_str=','.join(f"\"{topics}\"") prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)" @@ -247,10 +230,10 @@ class PromptBuilder: return prompt_for_initiative - def get_prompt_info(self,message:str,threshold:float): + async def get_prompt_info(self,message:str,threshold:float): related_info = '' print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - embedding = get_embedding(message) + embedding = await get_embedding(message) related_info += self.get_info_from_db(embedding,threshold=threshold) return related_info diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index 29a4334e8..4ed7a2f11 100644 --- 
a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -1,8 +1,8 @@ -import time -from ...common.database import Database -from nonebot.adapters.onebot.v11 import Bot -from typing import Optional, Tuple import asyncio +from typing import Optional + +from ...common.database import Database + class Impression: traits: str = None @@ -123,7 +123,7 @@ class RelationshipManager: print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") while True: - print(f"\033[1;32m[关系管理]\033[0m 正在自动保存关系") + print("\033[1;32m[关系管理]\033[0m 正在自动保存关系") await asyncio.sleep(300) # 等待300秒(5分钟) await self._save_all_relationships() diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index 08b52b7ca..6a87480b7 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -1,10 +1,8 @@ -from typing import Dict, List, Any, Optional -import time -import threading -from collections import defaultdict -import asyncio -from .message import Message +from typing import Optional + from ...common.database import Database +from .message import Message + class MessageStorage: def __init__(self): diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index 812d4e321..3296d0895 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -1,21 +1,17 @@ -from typing import Optional, Dict, List -from openai import OpenAI -from .message import Message -import jieba +from typing import List, Optional + from nonebot import get_driver -from .config import global_config -from snownlp import SnowNLP + from ..models.utils_model import LLM_request +from .config import global_config driver = get_driver() config = driver.config class TopicIdentifier: def __init__(self): - self.llm_client = LLM_request(model=global_config.llm_topic_extract) - self.select=global_config.topic_extract + self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge) - async def 
identify_topic_llm(self, text: str) -> Optional[List[str]]: """识别消息主题,返回主题列表""" @@ -26,10 +22,10 @@ class TopicIdentifier: 消息内容:{text}""" # 使用 LLM_request 类进行请求 - topic, _ = await self.llm_client.generate_response(prompt) + topic, _ = await self.llm_topic_judge.generate_response(prompt) if not topic: - print(f"\033[1;31m[错误]\033[0m LLM API 返回为空") + print("\033[1;31m[错误]\033[0m LLM API 返回为空") return None # 直接在这里处理主题解析 @@ -42,25 +38,4 @@ class TopicIdentifier: print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}") return topic_list if topic_list else None - def identify_topic_snownlp(self, text: str) -> Optional[List[str]]: - """使用 SnowNLP 进行主题识别 - - Args: - text (str): 需要识别主题的文本 - - Returns: - Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None - """ - if not text or len(text.strip()) == 0: - return None - - try: - s = SnowNLP(text) - # 提取前3个关键词作为主题 - keywords = s.keywords(5) - return keywords if keywords else None - except Exception as e: - print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}") - return None - topic_identifier = TopicIdentifier() \ No newline at end of file diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index aa16268ef..b2583e86f 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -1,16 +1,17 @@ -import time -import random -from typing import List -from .message import Message -import requests -import numpy as np -from .config import global_config -import re -from typing import Dict -from collections import Counter import math +import random +import time +from collections import Counter +from typing import Dict, List + +import jieba +import numpy as np from nonebot import get_driver + from ..models.utils_model import LLM_request +from ..utils.typo_generator import ChineseTypoGenerator +from .config import global_config +from .message import Message driver = get_driver() config = driver.config @@ -30,16 +31,18 @@ def combine_messages(messages: List[Message]) -> str: time_str = time.strftime("%m-%d %H:%M:%S", 
time.localtime(message.time)) name = message.user_nickname or f"用户{message.user_id}" content = message.processed_plain_text or message.plain_text - + result += f"[{time_str}] {name}: {content}\n" - + return result -def db_message_to_str (message_dict: Dict) -> str: + +def db_message_to_str(message_dict: Dict) -> str: print(f"message_dict: {message_dict}") time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) try: - name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", "")) + name = "[(%s)%s]%s" % ( + message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", "")) except: name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" content = message_dict.get("processed_plain_text", "") @@ -56,6 +59,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool: return True return False + def is_mentioned_bot_in_txt(message: str) -> bool: """检查消息是否提到了机器人""" keywords = [global_config.BOT_NICKNAME] @@ -64,10 +68,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool: return True return False -def get_embedding(text): + +async def get_embedding(text): """获取文本的embedding向量""" llm = LLM_request(model=global_config.embedding) - return llm.get_embedding_sync(text) + # return llm.get_embedding_sync(text) + return await llm.get_embedding(text) + def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) @@ -75,52 +82,55 @@ def cosine_similarity(v1, v2): norm2 = np.linalg.norm(v2) return dot_product / (norm1 * norm2) + def calculate_information_content(text): """计算文本的信息量(熵)""" char_count = Counter(text) total_chars = len(text) - + entropy = 0 for count in char_count.values(): probability = count / total_chars entropy -= probability * math.log2(probability) - + return entropy + def get_cloest_chat_from_db(db, length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数""" chat_text = '' closest_record = 
db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) - - if closest_record and closest_record.get('memorized', 0) < 4: + + if closest_record and closest_record.get('memorized', 0) < 4: closest_time = closest_record['time'] group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_records = list(db.db.messages.find( {"time": {"$gt": closest_time}, "group_id": group_id} ).sort('time', 1).limit(length)) - + # 更新每条消息的memorized属性 for record in chat_records: # 检查当前记录的memorized值 current_memorized = record.get('memorized', 0) - if current_memorized > 3: + if current_memorized > 3: # print(f"消息已读取3次,跳过") return '' - + # 更新memorized值 db.db.messages.update_one( {"_id": record["_id"]}, {"$set": {"memorized": current_memorized + 1}} ) - + chat_text += record["detailed_plain_text"] - + return chat_text - print(f"消息已读取3次,跳过") + # print(f"消息已读取3次,跳过") return '' -def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: + +async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 Args: @@ -132,7 +142,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: list: Message对象列表,按时间正序排列 """ - # 从数据库获取最近消息 + # 从数据库获取最近消息 recent_messages = list(db.db.messages.find( {"group_id": group_id}, # { @@ -147,7 +157,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: if not recent_messages: return [] - + # 转换为 Message对象列表 from .message import Message message_objects = [] @@ -162,16 +172,18 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: processed_plain_text=msg_data.get("processed_text", ""), group_id=group_id ) + await msg.initialize() message_objects.append(msg) except KeyError: print("[WARNING] 数据库中存在无效的消息") continue - + # 按时间正序排列 message_objects.reverse() return message_objects -def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,combine = False): + +def 
get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False): recent_messages = list(db.db.messages.find( {"group_id": group_id}, { @@ -185,16 +197,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb if not recent_messages: return [] - + message_detailed_plain_text = '' message_detailed_plain_text_list = [] - + # 反转消息列表,使最新的消息在最后 recent_messages.reverse() - + if combine: for msg_db_data in recent_messages: - message_detailed_plain_text+=str(msg_db_data["detailed_plain_text"]) + message_detailed_plain_text += str(msg_db_data["detailed_plain_text"]) return message_detailed_plain_text else: for msg_db_data in recent_messages: @@ -202,7 +214,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb return message_detailed_plain_text_list - def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: """将文本分割成句子,但保持书名号中的内容完整 Args: @@ -222,30 +233,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: split_strength = 0.7 else: split_strength = 0.9 - #先移除换行符 + # 先移除换行符 # print(f"split_strength: {split_strength}") - + # print(f"处理前的文本: {text}") - + # 统一将英文逗号转换为中文逗号 text = text.replace(',', ',') text = text.replace('\n', ' ') - + # print(f"处理前的文本: {text}") - + text_no_1 = '' for letter in text: # print(f"当前字符: {letter}") - if letter in ['!','!','?','?']: + if letter in ['!', '!', '?', '?']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < split_strength: letter = '' - if letter in ['。','…']: + if letter in ['。', '…']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < 1 - split_strength: letter = '' text_no_1 += letter - + # 对每个逗号单独判断是否分割 sentences = [text_no_1] new_sentences = [] @@ -274,84 +285,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: sentences_done = [] for sentence in sentences: sentence = sentence.rstrip(',,') - if random.random() < split_strength*0.5: + if 
random.random() < split_strength * 0.5: sentence = sentence.replace(',', '').replace(',', '') elif random.random() < split_strength: sentence = sentence.replace(',', ' ').replace(',', ' ') sentences_done.append(sentence) - + print(f"处理后的句子: {sentences_done}") return sentences_done -# 常见的错别字映射 -TYPO_DICT = { - '的': '地得', - '了': '咯啦勒', - '吗': '嘛麻', - '吧': '八把罢', - '是': '事', - '在': '再在', - '和': '合', - '有': '又', - '我': '沃窝喔', - '你': '泥尼拟', - '他': '它她塔祂', - '们': '门', - '啊': '阿哇', - '呢': '呐捏', - '都': '豆读毒', - '很': '狠', - '会': '回汇', - '去': '趣取曲', - '做': '作坐', - '想': '相像', - '说': '说税睡', - '看': '砍堪刊', - '来': '来莱赖', - '好': '号毫豪', - '给': '给既继', - '过': '锅果裹', - '能': '嫩', - '为': '位未', - '什': '甚深伸', - '么': '末麽嘛', - '话': '话花划', - '知': '织直值', - '道': '到', - '听': '听停挺', - '见': '见件建', - '觉': '觉脚搅', - '得': '得德锝', - '着': '着找招', - '像': '向象想', - '等': '等灯登', - '谢': '谢写卸', - '对': '对队', - '里': '里理鲤', - '啦': '啦拉喇', - '吃': '吃持迟', - '哦': '哦喔噢', - '呀': '呀压', - '要': '药', - '太': '太抬台', - '快': '块', - '点': '店', - '以': '以已', - '因': '因应', - '啥': '啥沙傻', - '行': '行型形', - '哈': '哈蛤铪', - '嘿': '嘿黑嗨', - '嗯': '嗯恩摁', - '哎': '哎爱埃', - '呜': '呜屋污', - '喂': '喂位未', - '嘛': '嘛麻马', - '嗨': '嗨害亥', - '哇': '哇娃蛙', - '咦': '咦意易', - '嘻': '嘻西希' -} + def random_remove_punctuation(text: str) -> str: """随机处理标点符号,模拟人类打字习惯 @@ -364,7 +307,7 @@ def random_remove_punctuation(text: str) -> str: """ result = '' text_len = len(text) - + for i, char in enumerate(text): if char == '。' and i == text_len - 1: # 结尾的句号 if random.random() > 0.4: # 80%概率删除结尾句号 @@ -379,32 +322,30 @@ def random_remove_punctuation(text: str) -> str: result += char return result -def add_typos(text: str) -> str: - TYPO_RATE = 0.02 # 控制错别字出现的概率(2%) - result = "" - for char in text: - if char in TYPO_DICT and random.random() < TYPO_RATE: - # 从可能的错别字中随机选择一个 - typos = TYPO_DICT[char] - result += random.choice(typos) - else: - result += char - return result + def process_llm_response(text: str) -> List[str]: # processed_response = process_text_with_typos(content) - if 
len(text) > 200: - print(f"回复过长 ({len(text)} 字符),返回默认回复") - return ['懒得说'] + if len(text) > 300: + print(f"回复过长 ({len(text)} 字符),返回默认回复") + return ['懒得说'] # 处理长消息 - sentences = split_into_sentences_w_remove_punctuation(add_typos(text)) + typo_generator = ChineseTypoGenerator( + error_rate=0.03, + min_freq=7, + tone_error_rate=0.2, + word_replace_rate=0.02 + ) + typoed_text = typo_generator.create_typo_sentence(text)[0] + sentences = split_into_sentences_w_remove_punctuation(typoed_text) # 检查分割后的消息数量是否过多(超过3条) - if len(sentences) > 3: + if len(sentences) > 4: print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") return [f'{global_config.BOT_NICKNAME}不知道哦'] - + return sentences + def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float: """ 计算输入字符串所需的时间,中文和英文字符有不同的输入时间 @@ -417,7 +358,46 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_ if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符 total_time += chinese_time else: # 其他字符(如英文) - total_time += english_time + total_time += english_time return total_time +def cosine_similarity(v1, v2): + """计算余弦相似度""" + dot_product = np.dot(v1, v2) + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 == 0 or norm2 == 0: + return 0 + return dot_product / (norm1 * norm2) + + +def text_to_vector(text): + """将文本转换为词频向量""" + # 分词 + words = jieba.lcut(text) + # 统计词频 + word_freq = Counter(words) + return word_freq + + +def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: + """使用简单的余弦相似度计算文本相似度""" + # 将输入文本转换为词频向量 + text_vector = text_to_vector(text) + + # 计算每个主题的相似度 + similarities = [] + for topic in topics: + topic_vector = text_to_vector(topic) + # 获取所有唯一词 + all_words = set(text_vector.keys()) | set(topic_vector.keys()) + # 构建向量 + v1 = [text_vector.get(word, 0) for word in all_words] + v2 = [topic_vector.get(word, 0) for word in all_words] + # 计算相似度 + similarity = cosine_similarity(v1, v2) + similarities.append((topic, 
similarity)) + + # 按相似度降序排序并返回前k个 + return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 922ab5228..8a8b3ce5a 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -1,13 +1,15 @@ -import io -from PIL import Image -import hashlib -import time -import os -from ...common.database import Database -import zlib # 用于 CRC32 import base64 -from nonebot import get_driver +import io +import os +import time +import zlib # 用于 CRC32 + from loguru import logger +from nonebot import get_driver +from PIL import Image + +from ...common.database import Database +from ..chat.config import global_config driver = get_driver() config = driver.config @@ -118,7 +120,7 @@ def storage_compress_image(base64_data: str, max_size: int = 200) -> str: # 保存记录 collection.insert_one(image_record) - print(f"\033[1;32m[成功]\033[0m 保存图片记录到数据库") + print("\033[1;32m[成功]\033[0m 保存图片记录到数据库") except Exception as db_error: print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}") @@ -143,6 +145,8 @@ def storage_emoji(image_data: bytes) -> bytes: Returns: bytes: 原始图片数据 """ + if not global_config.EMOJI_SAVE: + return image_data try: # 使用 CRC32 计算哈希值 hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') @@ -227,7 +231,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 image_data = base64.b64decode(base64_data) # 如果已经小于目标大小,直接返回原图 - if len(image_data) <= target_size: + if len(image_data) <= 2*1024*1024: return base64_data # 将字节数据转换为图片对象 @@ -252,7 +256,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 for frame_idx in range(img.n_frames): img.seek(frame_idx) new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) + new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折 frames.append(new_frame) # 保存到缓冲区 @@ -286,4 +290,19 
@@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 logger.error(f"压缩图片失败: {str(e)}") import traceback logger.error(traceback.format_exc()) - return base64_data \ No newline at end of file + return base64_data + +def image_path_to_base64(image_path: str) -> str: + """将图片路径转换为base64编码 + Args: + image_path: 图片文件路径 + Returns: + str: base64编码的图片数据 + """ + try: + with open(image_path, 'rb') as f: + image_data = f.read() + return base64.b64encode(image_data).decode('utf-8') + except Exception as e: + logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}") + return None \ No newline at end of file diff --git a/src/plugins/chat/utils_user.py b/src/plugins/chat/utils_user.py index bb8c30948..489eb7a1d 100644 --- a/src/plugins/chat/utils_user.py +++ b/src/plugins/chat/utils_user.py @@ -1,5 +1,6 @@ -from .relationship_manager import relationship_manager from .config import global_config +from .relationship_manager import relationship_manager + def get_user_nickname(user_id: int) -> str: if int(user_id) == int(global_config.BOT_QQ): diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index 7559406f9..001b66207 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -1,4 +1,6 @@ import asyncio +from .config import global_config + class WillingManager: def __init__(self): @@ -34,26 +36,32 @@ class WillingManager: print(f"被重复提及, 当前意愿: {current_willing}") if is_emoji: - current_willing *= 0.15 + current_willing *= 0.1 print(f"表情包, 当前意愿: {current_willing}") - if interested_rate > 0.65: - print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") - current_willing += interested_rate-0.6 + print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") + interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度 + if interested_rate > 0.4: + # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") + current_willing += interested_rate-0.4 - 
self.group_reply_willing[group_id] = min(current_willing, 3.0) + current_willing *= global_config.response_willing_amplifier #放大回复意愿 + # print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}") - reply_probability = max((current_willing - 0.55) * 1.9, 0) + reply_probability = max((current_willing - 0.45) * 2, 0) if group_id not in config.talk_allowed_groups: current_willing = 0 reply_probability = 0 if group_id in config.talk_frequency_down_groups: - reply_probability = reply_probability / 3.5 + reply_probability = reply_probability / global_config.down_frequency_rate reply_probability = min(reply_probability, 1) if reply_probability < 0: reply_probability = 0 + + + self.group_reply_willing[group_id] = min(current_willing, 3.0) return reply_probability def change_reply_willing_sent(self, group_id: int): diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py index d7071985e..d2408e24f 100644 --- a/src/plugins/knowledege/knowledge_library.py +++ b/src/plugins/knowledege/knowledge_library.py @@ -1,8 +1,8 @@ import os import sys -import numpy as np -import requests import time + +import requests from dotenv import load_dotenv # 添加项目根目录到 Python 路径 diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index fad3f5f30..006991bcb 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -1,19 +1,12 @@ # -*- coding: utf-8 -*- import os import sys -import jieba -import networkx as nx -import matplotlib.pyplot as plt -import math -from collections import Counter -import datetime -import random import time + +import jieba +import matplotlib.pyplot as plt +import networkx as nx from dotenv import load_dotenv -import sys -import asyncio -import aiohttp -from typing import Tuple sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from src.common.database import Database # 使用正确的导入语法 diff --git 
a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 4d20d05a9..f88888aa4 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- -import os -import jieba -import networkx as nx -import matplotlib.pyplot as plt -from collections import Counter import datetime +import math import random import time + +import jieba +import networkx as nx + +from ...common.database import Database # 使用正确的导入语法 from ..chat.config import global_config -from ...common.database import Database # 使用正确的导入语法 +from ..chat.utils import ( + calculate_information_content, + cosine_similarity, + get_cloest_chat_from_db, + text_to_vector, +) from ..models.utils_model import LLM_request -import math -from ..chat.utils import calculate_information_content, get_cloest_chat_from_db - - - class Memory_graph: @@ -132,9 +133,17 @@ class Memory_graph: class Hippocampus: def __init__(self,memory_graph:Memory_graph): self.memory_graph = memory_graph - self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5) - self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5) + self.llm_topic_judge = LLM_request(model = global_config.llm_topic_judge,temperature=0.5) + self.llm_summary_by_topic = LLM_request(model = global_config.llm_summary_by_topic,temperature=0.5) + def get_all_node_names(self) -> list: + """获取记忆图中所有节点的名字列表 + + Returns: + list: 包含所有节点名字的列表 + """ + return list(self.memory_graph.G.nodes()) + def calculate_node_hash(self, concept, memory_items): """计算节点的特征值""" if not isinstance(memory_items, list): @@ -171,18 +180,24 @@ class Hippocampus: #获取topics topic_num = self.calculate_topic_num(input_text, compress_rate) - topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num)) + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) # 
修改话题处理逻辑 - print(f"话题: {topics_response[0]}") - topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] - print(f"话题: {topics}") + # 定义需要过滤的关键词 + filter_keywords = ['表情包', '图片', '回复', '聊天记录'] - # 创建所有话题的请求任务 + # 过滤topics + topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] + filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] + + # print(f"原始话题: {topics}") + print(f"过滤后话题: {filtered_topics}") + + # 使用过滤后的话题继续处理 tasks = [] - for topic in topics: + for topic in filtered_topics: topic_what_prompt = self.topic_what(input_text, topic) # 创建异步任务 - task = self.llm_model_summary.generate_response_async(topic_what_prompt) + task = self.llm_summary_by_topic.generate_response_async(topic_what_prompt) tasks.append((topic.strip(), task)) # 等待所有任务完成 @@ -483,6 +498,201 @@ class Hippocampus: prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' return prompt + async def _identify_topics(self, text: str) -> list: + """从文本中识别可能的主题 + + Args: + text: 输入文本 + + Returns: + list: 识别出的主题列表 + """ + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5)) + # print(f"话题: {topics_response[0]}") + topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] + # print(f"话题: {topics}") + + return topics + + def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list: + """查找与给定主题相似的记忆主题 + + Args: + topics: 主题列表 + similarity_threshold: 相似度阈值 + debug_info: 调试信息前缀 + + Returns: + list: (主题, 相似度) 元组列表 + """ + all_memory_topics = self.get_all_node_names() + all_similar_topics = [] + + # 计算每个识别出的主题与记忆主题的相似度 + for topic in topics: + if debug_info: + # 
print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}") + pass + + topic_vector = text_to_vector(topic) + has_similar_topic = False + + for memory_topic in all_memory_topics: + memory_vector = text_to_vector(memory_topic) + # 获取所有唯一词 + all_words = set(topic_vector.keys()) | set(memory_vector.keys()) + # 构建向量 + v1 = [topic_vector.get(word, 0) for word in all_words] + v2 = [memory_vector.get(word, 0) for word in all_words] + # 计算相似度 + similarity = cosine_similarity(v1, v2) + + if similarity >= similarity_threshold: + has_similar_topic = True + if debug_info: + # print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})") + pass + all_similar_topics.append((memory_topic, similarity)) + + if not has_similar_topic and debug_info: + # print(f"\033[1;31m[{debug_info}]\033[0m 没有见过: {topic} ,呃呃") + pass + + return all_similar_topics + + def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list: + """获取相似度最高的主题 + + Args: + similar_topics: (主题, 相似度) 元组列表 + max_topics: 最大主题数量 + + Returns: + list: (主题, 相似度) 元组列表 + """ + seen_topics = set() + top_topics = [] + + for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True): + if topic not in seen_topics and len(top_topics) < max_topics: + seen_topics.add(topic) + top_topics.append((topic, score)) + + return top_topics + + async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: + """计算输入文本对记忆的激活程度""" + print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") + + # 识别主题 + identified_topics = await self._identify_topics(text) + if not identified_topics: + return 0 + + # 查找相似主题 + all_similar_topics = self._find_similar_topics( + identified_topics, + similarity_threshold=similarity_threshold, + debug_info="记忆激活" + ) + + if not all_similar_topics: + return 0 + + # 获取最相关的主题 + top_topics = self._get_top_topics(all_similar_topics, max_topics) + + # 如果只找到一个主题,进行惩罚 + if len(top_topics) 
== 1: + topic, score = top_topics[0] + # 获取主题内容数量并计算惩罚系数 + memory_items = self.memory_graph.G.nodes[topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + penalty = 1.0 / (1 + math.log(content_count + 1)) + + activation = int(score * 50 * penalty) + print(f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}") + return activation + + # 计算关键词匹配率,同时考虑内容数量 + matched_topics = set() + topic_similarities = {} + + for memory_topic, similarity in top_topics: + # 计算内容数量惩罚 + memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + penalty = 1.0 / (1 + math.log(content_count + 1)) + + # 对每个记忆主题,检查它与哪些输入主题相似 + for input_topic in identified_topics: + topic_vector = text_to_vector(input_topic) + memory_vector = text_to_vector(memory_topic) + all_words = set(topic_vector.keys()) | set(memory_vector.keys()) + v1 = [topic_vector.get(word, 0) for word in all_words] + v2 = [memory_vector.get(word, 0) for word in all_words] + sim = cosine_similarity(v1, v2) + if sim >= similarity_threshold: + matched_topics.add(input_topic) + adjusted_sim = sim * penalty + topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim) + print(f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})") + + # 计算主题匹配率和平均相似度 + topic_match = len(matched_topics) / len(identified_topics) + average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0 + + # 计算最终激活值 + activation = int((topic_match + average_similarities) / 2 * 100) + print(f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}") + + return activation + + async 
def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5) -> list: + """根据输入文本获取相关的记忆内容""" + # 识别主题 + identified_topics = await self._identify_topics(text) + + # 查找相似主题 + all_similar_topics = self._find_similar_topics( + identified_topics, + similarity_threshold=similarity_threshold, + debug_info="记忆检索" + ) + + # 获取最相关的主题 + relevant_topics = self._get_top_topics(all_similar_topics, max_topics) + + # 获取相关记忆内容 + relevant_memories = [] + for topic, score in relevant_topics: + # 获取该主题的记忆内容 + first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) + if first_layer: + # 如果记忆条数超过限制,随机选择指定数量的记忆 + if len(first_layer) > max_memory_num/2: + first_layer = random.sample(first_layer, max_memory_num//2) + # 为每条记忆添加来源主题和相似度信息 + for memory in first_layer: + relevant_memories.append({ + 'topic': topic, + 'similarity': score, + 'content': memory + }) + + # 如果记忆数量超过5个,随机选择5个 + # 按相似度排序 + relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) + + if len(relevant_memories) > max_memory_num: + relevant_memories = random.sample(relevant_memories, max_memory_num) + + return relevant_memories + def segment_text(text): seg_text = list(jieba.cut(text)) @@ -490,6 +700,7 @@ def segment_text(text): from nonebot import get_driver + driver = get_driver() config = driver.config diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py index d6aa2f669..3124bc8e4 100644 --- a/src/plugins/memory_system/memory_manual_build.py +++ b/src/plugins/memory_system/memory_manual_build.py @@ -1,22 +1,22 @@ # -*- coding: utf-8 -*- -import sys -import jieba -import networkx as nx -import matplotlib.pyplot as plt -import math -from collections import Counter import datetime -import random -import time +import math import os -from dotenv import load_dotenv -import pymongo -from loguru import logger +import random +import sys +import time +from collections import Counter 
from pathlib import Path -from snownlp import SnowNLP + +import matplotlib.pyplot as plt +import networkx as nx +import pymongo +from dotenv import load_dotenv +from loguru import logger + # from chat.config import global_config sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 -from src.common.database import Database +from src.common.database import Database from src.plugins.memory_system.offline_llm import LLMModel # 获取当前文件的目录 @@ -103,7 +103,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): # 检查当前记录的memorized值 current_memorized = record.get('memorized', 0) if current_memorized > 3: - print(f"消息已读取3次,跳过") + print("消息已读取3次,跳过") return '' # 更新memorized值 @@ -115,7 +115,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): chat_text += record["detailed_plain_text"] return chat_text - print(f"消息已读取3次,跳过") + print("消息已读取3次,跳过") return '' class Memory_graph: @@ -234,16 +234,22 @@ class Hippocampus: async def memory_compress(self, input_text, compress_rate=0.1): print(input_text) - #获取topics topic_num = self.calculate_topic_num(input_text, compress_rate) - topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num)) + topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num)) # 修改话题处理逻辑 + # 定义需要过滤的关键词 + filter_keywords = ['表情包', '图片', '回复', '聊天记录'] + + # 过滤topics topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] - print(f"话题: {topics}") + filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] + + # print(f"原始话题: {topics}") + print(f"过滤后话题: {filtered_topics}") # 创建所有话题的请求任务 tasks = [] - for topic in topics: + for topic in filtered_topics: topic_what_prompt = self.topic_what(input_text, topic) # 创建异步任务 task = self.llm_model_small.generate_response_async(topic_what_prompt) @@ -650,7 +656,22 @@ def 
visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal G = memory_graph.G # 创建一个新图用于可视化 - H = G.copy() + H = G.copy() + + # 过滤掉内容数量小于2的节点 + nodes_to_remove = [] + for node in H.nodes(): + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + if memory_count < 2: + nodes_to_remove.append(node) + + H.remove_nodes_from(nodes_to_remove) + + # 如果没有符合条件的节点,直接返回 + if len(H.nodes()) == 0: + print("没有找到内容数量大于等于2的节点") + return # 计算节点大小和颜色 node_colors = [] @@ -704,7 +725,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal edge_color='gray', width=1.5) # 统一的边宽度 - title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近' + title = '记忆图谱可视化(仅显示内容≥2的节点)\n节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近' plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() diff --git a/src/plugins/memory_system/offline_llm.py b/src/plugins/memory_system/offline_llm.py index 5e877dceb..4a80b3ecd 100644 --- a/src/plugins/memory_system/offline_llm.py +++ b/src/plugins/memory_system/offline_llm.py @@ -1,11 +1,13 @@ -import os -import requests -from typing import Tuple, Union -import time -import aiohttp import asyncio +import os +import time +from typing import Tuple, Union + +import aiohttp +import requests from loguru import logger + class LLMModel: def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): self.model_name = model_name diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 11d7e2b72..c70c26ff9 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -1,11 +1,14 @@ -import aiohttp import asyncio -import requests -import time +import json import re +from datetime import datetime from typing import Tuple, Union -from nonebot import get_driver + +import aiohttp from loguru import logger +from nonebot import 
get_driver + +from ...common.database import Database from ..chat.config import global_config from ..chat.utils_image import compress_base64_image_by_scale @@ -24,397 +27,381 @@ class LLM_request: raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e self.model_name = model["name"] self.params = kwargs + + self.pri_in = model.get("pri_in", 0) + self.pri_out = model.get("pri_out", 0) + + # 获取数据库实例 + self.db = Database.get_instance() + self._init_database() - async def generate_response(self, prompt: str) -> Tuple[str, str]: - """根据输入的提示生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" + def _init_database(self): + """初始化数据库集合""" + try: + # 创建llm_usage集合的索引 + self.db.db.llm_usage.create_index([("timestamp", 1)]) + self.db.db.llm_usage.create_index([("model_name", 1)]) + self.db.db.llm_usage.create_index([("user_id", 1)]) + self.db.db.llm_usage.create_index([("request_type", 1)]) + except Exception as e: + logger.error(f"创建数据库索引失败: {e}") + + def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, + user_id: str = "system", request_type: str = "chat", + endpoint: str = "/chat/completions"): + """记录模型使用情况到数据库 + Args: + prompt_tokens: 输入token数 + completion_tokens: 输出token数 + total_tokens: 总token数 + user_id: 用户ID,默认为system + request_type: 请求类型(chat/embedding/image等) + endpoint: API端点 + """ + try: + usage_data = { + "model_name": self.model_name, + "user_id": user_id, + "request_type": request_type, + "endpoint": endpoint, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "cost": self._calculate_cost(prompt_tokens, completion_tokens), + "status": "success", + "timestamp": datetime.now() + } + self.db.db.llm_usage.insert_one(usage_data) + logger.info( + f"Token使用情况 - 模型: {self.model_name}, " + f"用户: {user_id}, 类型: {request_type}, " + f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " + f"总计: {total_tokens}" + ) + except 
Exception as e: + logger.error(f"记录token使用情况失败: {e}") + + def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: + """计算API调用成本 + 使用模型的pri_in和pri_out价格计算输入和输出的成本 + + Args: + prompt_tokens: 输入token数量 + completion_tokens: 输出token数量 + + Returns: + float: 总成本(元) + """ + # 使用模型的pri_in和pri_out计算成本 + input_cost = (prompt_tokens / 1000000) * self.pri_in + output_cost = (completion_tokens / 1000000) * self.pri_out + return round(input_cost + output_cost, 6) + + async def _execute_request( + self, + endpoint: str, + prompt: str = None, + image_base64: str = None, + payload: dict = None, + retry_policy: dict = None, + response_handler: callable = None, + user_id: str = "system", + request_type: str = "chat" + ): + """统一请求执行入口 + Args: + endpoint: API端点路径 (如 "chat/completions") + prompt: prompt文本 + image_base64: 图片的base64编码 + payload: 请求体数据 + retry_policy: 自定义重试策略 + response_handler: 自定义响应处理器 + user_id: 用户ID + request_type: 请求类型 + """ + # 合并重试策略 + default_retry = { + "max_retries": 3, "base_wait": 15, + "retry_codes": [429, 413, 500, 503], + "abort_codes": [400, 401, 402, 403]} + policy = {**default_retry, **(retry_policy or {})} + + # 常见Error Code Mapping + error_code_mapping = { + 400: "参数不正确", + 401: "API key 错误,认证失败", + 402: "账号余额不足", + 403: "需要实名,或余额不足", + 404: "Not Found", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高" } + api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" + #判断是否为流式 + stream_mode = self.params.get("stream", False) + if self.params.get("stream", False) is True: + logger.info(f"进入流式输出模式,发送请求到URL: {api_url}") + else: + logger.info(f"发送请求到URL: {api_url}") + logger.info(f"使用模型: {self.model_name}") + # 构建请求体 - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params - } + if image_base64: + payload = await self._build_payload(prompt, image_base64) + elif payload is None: + payload = await self._build_payload(prompt) - # 发送请求到完整的chat/completions端点 - api_url = 
f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): + for retry in range(policy["max_retries"]): try: + # 使用上下文管理器处理会话 + headers = await self._build_headers() + #似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 + if stream_mode: + headers["Accept"] = "text/event-stream" + async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") + async with session.post(api_url, headers=headers, json=payload) as response: + # 处理需要重试的状态码 + if response.status in policy["retry_codes"]: + wait_time = policy["base_wait"] * (2 ** retry) + logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") + if response.status == 413: + logger.warning("请求体过大,尝试压缩...") + image_base64 = compress_base64_image_by_scale(image_base64) + payload = await self._build_payload(prompt, image_base64) + elif response.status in [500, 503]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + else: + logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") + await asyncio.sleep(wait_time) continue + elif response.status in policy["abort_codes"]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - if response.status in [500, 503]: - logger.error(f"服务器错误: {response.status}") - raise RuntimeError("服务器负载过高,模型恢复失败QAQ") - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not 
reasoning_content: - think_match = re.search(r'(.*?)', content, re.DOTALL) + response.raise_for_status() + + #将流式输出转化为非流式输出 + if stream_mode: + accumulated_content = "" + async for line_bytes in response.content: + line = line_bytes.decode("utf-8").strip() + if not line: + continue + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + delta = chunk["choices"][0]["delta"] + delta_content = delta.get("content") + if delta_content is None: + delta_content = "" + accumulated_content += delta_content + except Exception as e: + logger.error(f"解析流式输出错误: {e}") + content = accumulated_content + reasoning_content = "" + think_match = re.search(r'(.*?)', content, re.DOTALL) if think_match: reasoning_content = think_match.group(1).strip() content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" + # 构造一个伪result以便调用自定义响应处理器或默认处理器 + result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]} + return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) + else: + result = await response.json() + # 使用自定义处理器或默认处理 + return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) + if retry < policy["max_retries"] - 1: + wait_time = policy["base_wait"] * (2 ** retry) + logger.error(f"请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}") await asyncio.sleep(wait_time) else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) + logger.critical(f"请求失败: {str(e)}") + logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {payload}") raise RuntimeError(f"API请求失败: {str(e)}") logger.error("达到最大重试次数,请求仍然失败") raise RuntimeError("达到最大重试次数,API请求仍然失败") + + async def _transform_parameters(self, params: dict) ->dict: + """ + 根据模型名称转换参数: + - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temprature' 参数, + 并将 'max_tokens' 重命名为 'max_completion_tokens' + """ + # 复制一份参数,避免直接修改原始数据 + new_params = dict(params) + # 定义需要转换的模型列表 + models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] + if self.model_name.lower() in models_needing_transformation: + # 删除 'temprature' 参数(如果存在) + new_params.pop("temperature", None) + # 如果存在 'max_tokens',则重命名为 'max_completion_tokens' + if "max_tokens" in new_params: + new_params["max_completion_tokens"] = new_params.pop("max_tokens") + return new_params - async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """根据输入的提示和图片生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - # 构建请求体 - def build_request_data(img_base64: str): - return { + async def _build_payload(self, prompt: str, image_base64: str = None) -> dict: + """构建请求体""" + # 复制一份参数,避免直接修改 self.params + params_copy = await self._transform_parameters(self.params) + if image_base64: + payload = { "model": self.model_name, "messages": [ { "role": "user", "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{img_base64}" - } - } + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} ] } ], - **self.params + "max_tokens": global_config.max_response_length, + 
**params_copy } + else: + payload = { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, + **params_copy + } + # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 + if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload: + payload["max_completion_tokens"] = payload.pop("max_tokens") + return payload - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL + def _default_response_handler(self, result: dict, user_id: str = "system", + request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple: + """默认响应解析""" + if "choices" in result and result["choices"]: + message = result["choices"][0]["message"] + content = message.get("content", "") + content, reasoning = self._extract_reasoning(content) + reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") + if not reasoning_content: + reasoning_content = message.get("reasoning_content", "") + if not reasoning_content: + reasoning_content = reasoning - max_retries = 3 - base_wait_time = 15 - - current_image_base64 = image_base64 - current_image_base64 = compress_base64_image_by_scale(current_image_base64) - + # 记录token使用情况 + usage = result.get("usage", {}) + if usage: + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + total_tokens = usage.get("total_tokens", 0) + self._record_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + user_id=user_id, + request_type=request_type, + endpoint=endpoint + ) - for retry in range(max_retries): - try: - data = build_request_data(current_image_base64) - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as 
response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue + return content, reasoning_content - elif response.status == 413: - logger.warning("图片太大(413),尝试压缩...") - current_image_base64 = compress_base64_image_by_scale(current_image_base64) - continue - - response.raise_for_status() # 检查其他响应状态 + return "没有返回结果", "" - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" + def _extract_reasoning(self, content: str) -> tuple[str, str]: + """CoT思维链提取""" + match = re.search(r'(?:)?(.*?)', content, re.DOTALL) + content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() + if match: + reasoning = match.group(1).strip() + else: + reasoning = "" + return content, reasoning - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - raise RuntimeError(f"API请求失败: {str(e)}") + async def _build_headers(self, no_key: bool = False) -> dict: + """构建请求头""" + if no_key: + return { + "Authorization": f"Bearer **********", + "Content-Type": "application/json" + } + else: + return { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + # 防止小朋友们截图自己的key - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") + async def generate_response(self, prompt: str) -> Tuple[str, str]: + """根据输入的提示生成模型的异步响应""" - async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt + ) + return content, reasoning_content + + async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: + """根据输入的提示和图片生成模型的异步响应""" + + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt, + image_base64=image_base64 + ) + return content, reasoning_content + + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]: """异步方式根据输入的提示生成模型的响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - # 构建请求体 data = { "model": self.model_name, "messages": [{"role": "user", "content": prompt}], - "temperature": 0.5, - **self.params - } - - # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"Request URL: {api_url}") # 记录请求的 URL - - max_retries = 3 - base_wait_time = 15 - - async with aiohttp.ClientSession() as session: - for retry in range(max_retries): - try: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - 
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - content = result["choices"][0]["message"]["content"] - reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") - await asyncio.sleep(wait_time) - else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" - - logger.error("达到最大重试次数,请求仍然失败") - return "达到最大重试次数,请求仍然失败", "" - - - - def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """同步方法:根据输入的提示和图片生成模型的响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - image_base64=compress_base64_image_by_scale(image_base64) - - # 构建请求体 - data = { - "model": self.model_name, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - } - } - ] - } - ], + "max_tokens": global_config.max_response_length, **self.params } - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + payload=data, + prompt=prompt + ) + return content, reasoning_content - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - 
time.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - raise RuntimeError(f"API请求失败: {str(e)}") - - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") - - def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: - """同步方法:获取文本的embedding向量 - - Args: - text: 需要获取embedding的文本 - model: 使用的模型名称,默认为"BAAI/bge-m3" - - Returns: - list: embedding向量,如果失败则返回None - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": model, - "input": text, - "encoding_format": "float" - } - - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - response.raise_for_status() - - result = response.json() - if 'data' in result and len(result['data']) > 0: - return 
result['data'][0]['embedding'] - return None - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None - - async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: + async def get_embedding(self, text: str) -> Union[list, None]: """异步方法:获取文本的embedding向量 Args: text: 需要获取embedding的文本 - model: 使用的模型名称,默认为"BAAI/bge-m3" Returns: list: embedding向量,如果失败则返回None """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } + def embedding_handler(result): + """处理响应""" + if "data" in result and len(result["data"]) > 0: + return result["data"][0].get("embedding", None) + return None - data = { - "model": model, - "input": text, - "encoding_format": "float" - } + embedding = await self._execute_request( + endpoint="/embeddings", + prompt=text, + payload={ + "model": self.model_name, + "input": text, + "encoding_format": "float" + }, + retry_policy={ + "max_retries": 2, + "base_wait": 6 + }, + response_handler=embedding_handler + ) + return embedding - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() - - result = await response.json() - if 'data' in result and len(result['data']) > 0: - return result['data'][0]['embedding'] - return None - - except 
Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None diff --git a/src/plugins/moods/moods.py b/src/plugins/moods/moods.py new file mode 100644 index 000000000..32b900b0b --- /dev/null +++ b/src/plugins/moods/moods.py @@ -0,0 +1,231 @@ +import math +import threading +import time +from dataclasses import dataclass + +from ..chat.config import global_config + + +@dataclass +class MoodState: + valence: float # 愉悦度 (-1 到 1) + arousal: float # 唤醒度 (0 到 1) + text: str # 心情文本描述 + +class MoodManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + # 确保初始化代码只运行一次 + if self._initialized: + return + + self._initialized = True + + # 初始化心情状态 + self.current_mood = MoodState( + valence=0.0, + arousal=0.5, + text="平静" + ) + + # 从配置文件获取衰减率 + self.decay_rate_valence = 1 - global_config.mood_decay_rate # 愉悦度衰减率 + self.decay_rate_arousal = 1 - global_config.mood_decay_rate # 唤醒度衰减率 + + # 上次更新时间 + self.last_update = time.time() + + # 线程控制 + self._running = False + self._update_thread = None + + # 情绪词映射表 (valence, arousal) + self.emotion_map = { + 'happy': (0.8, 0.6), # 高愉悦度,中等唤醒度 + 'angry': (-0.7, 0.8), # 负愉悦度,高唤醒度 + 'sad': (-0.6, 0.3), # 负愉悦度,低唤醒度 + 'surprised': (0.4, 0.9), # 中等愉悦度,高唤醒度 + 'disgusted': (-0.8, 0.5), # 高负愉悦度,中等唤醒度 + 'fearful': (-0.7, 0.7), # 负愉悦度,高唤醒度 + 'neutral': (0.0, 0.5), # 中性愉悦度,中等唤醒度 + } + + # 情绪文本映射表 + self.mood_text_map = { + # 第一象限:高唤醒,正愉悦 + (0.5, 0.7): "兴奋", + (0.3, 0.8): "快乐", + # 第二象限:高唤醒,负愉悦 + (-0.5, 0.7): "愤怒", + (-0.3, 0.8): "焦虑", + # 第三象限:低唤醒,负愉悦 + (-0.5, 0.3): "悲伤", + (-0.3, 
0.2): "疲倦", + # 第四象限:低唤醒,正愉悦 + (0.5, 0.3): "放松", + (0.3, 0.2): "平静" + } + + @classmethod + def get_instance(cls) -> 'MoodManager': + """获取MoodManager的单例实例""" + if cls._instance is None: + cls._instance = MoodManager() + return cls._instance + + def start_mood_update(self, update_interval: float = 1.0) -> None: + """ + 启动情绪更新线程 + :param update_interval: 更新间隔(秒) + """ + if self._running: + return + + self._running = True + self._update_thread = threading.Thread( + target=self._continuous_mood_update, + args=(update_interval,), + daemon=True + ) + self._update_thread.start() + + def stop_mood_update(self) -> None: + """停止情绪更新线程""" + self._running = False + if self._update_thread and self._update_thread.is_alive(): + self._update_thread.join() + + def _continuous_mood_update(self, update_interval: float) -> None: + """ + 持续更新情绪状态的线程函数 + :param update_interval: 更新间隔(秒) + """ + while self._running: + self._apply_decay() + self._update_mood_text() + time.sleep(update_interval) + + def _apply_decay(self) -> None: + """应用情绪衰减""" + current_time = time.time() + time_diff = current_time - self.last_update + + # 应用衰减公式 + self.current_mood.valence *= math.pow(1 - self.decay_rate_valence, time_diff) + self.current_mood.arousal *= math.pow(1 - self.decay_rate_arousal, time_diff) + + # 确保值在合理范围内 + self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) + self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + + self.last_update = current_time + + def update_mood_from_text(self, text: str, valence_change: float, arousal_change: float) -> None: + """根据输入文本更新情绪状态""" + + self.current_mood.valence += valence_change + self.current_mood.arousal += arousal_change + + # 限制范围 + self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) + self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + + self._update_mood_text() + + def set_mood_text(self, text: str) -> None: + """直接设置心情文本""" + self.current_mood.text = text + + 
def _update_mood_text(self) -> None:
    """Set the mood label to the ``mood_text_map`` entry nearest the current mood.

    Keys of ``self.mood_text_map`` are ``(valence, arousal)`` points. The
    label whose point has the smallest Euclidean distance to the current mood
    is written to ``self.current_mood.text``; on equal distances the entry
    seen first wins. Nothing is written when the map is empty or the chosen
    label is falsy.
    """
    mood = self.current_mood
    best_label = None
    best_distance = float('inf')

    for point, label in self.mood_text_map.items():
        ref_valence, ref_arousal = point
        squared = (mood.valence - ref_valence) ** 2 + (mood.arousal - ref_arousal) ** 2
        candidate = math.sqrt(squared)
        # Strict comparison keeps the earliest entry on ties.
        if candidate < best_distance:
            best_distance, best_label = candidate, label

    if best_label:
        self.current_mood.text = best_label
class LLMStatistics:
    """Background reporter that summarises LLM usage records into a text file.

    Usage documents are read from the ``llm_usage`` collection, aggregated for
    several look-back windows, and written to ``output_file``. While the
    worker thread is running the report is regenerated about once a minute.
    """

    def __init__(self, output_file: str = "llm_statistics.txt"):
        """Initialise the statistics reporter.

        Args:
            output_file: Path of the report file that is rewritten on each run.
        """
        self.db = Database.get_instance()
        self.output_file = output_file
        self.running = False
        self.stats_thread = None

    def start(self):
        """Start the daemon statistics thread; a no-op when already running."""
        if not self.running:
            self.running = True
            self.stats_thread = threading.Thread(target=self._stats_loop)
            self.stats_thread.daemon = True
            self.stats_thread.start()

    def stop(self):
        """Ask the statistics thread to stop and wait for it to finish."""
        self.running = False
        if self.stats_thread:
            self.stats_thread.join()

    @staticmethod
    def _field(doc: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``doc[key]``, using *default* when the key is missing or ``None``.

        ``dict.get(key, default)`` only covers the missing case; a field that
        is present with a null value would otherwise break the arithmetic and
        the sorting done on these values.
        """
        value = doc.get(key)
        return default if value is None else value

    def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
        """Aggregate LLM usage documents whose timestamp is >= *start_time*.

        Args:
            start_time: Inclusive lower bound of the reporting window.

        Returns:
            A dict with request counts, token totals and costs, overall and
            broken down by request type, user and model.
        """
        stats = {
            "total_requests": 0,
            "requests_by_type": defaultdict(int),
            "requests_by_user": defaultdict(int),
            "requests_by_model": defaultdict(int),
            "average_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "costs_by_user": defaultdict(float),
            "costs_by_type": defaultdict(float),
            "costs_by_model": defaultdict(float)
        }

        cursor = self.db.db.llm_usage.find({
            "timestamp": {"$gte": start_time}
        })

        for doc in cursor:
            request_type = self._field(doc, "request_type", "unknown")
            user_id = str(self._field(doc, "user_id", "unknown"))
            model_name = self._field(doc, "model_name", "unknown")
            prompt_tokens = self._field(doc, "prompt_tokens", 0)
            completion_tokens = self._field(doc, "completion_tokens", 0)
            cost = self._field(doc, "cost", 0.0)

            stats["total_requests"] += 1
            stats["requests_by_type"][request_type] += 1
            stats["requests_by_user"][user_id] += 1
            stats["requests_by_model"][model_name] += 1

            stats["total_tokens"] += prompt_tokens + completion_tokens

            stats["total_cost"] += cost
            stats["costs_by_user"][user_id] += cost
            stats["costs_by_type"][request_type] += cost
            stats["costs_by_model"][model_name] += cost

        if stats["total_requests"] > 0:
            stats["average_tokens"] = stats["total_tokens"] / stats["total_requests"]

        return stats

    def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Collect statistics for every reporting window, keyed by window name."""
        now = datetime.now()

        return {
            "all_time": self._collect_statistics_for_period(datetime.min),
            "last_7_days": self._collect_statistics_for_period(now - timedelta(days=7)),
            "last_24_hours": self._collect_statistics_for_period(now - timedelta(days=1)),
            "last_hour": self._collect_statistics_for_period(now - timedelta(hours=1))
        }

    def _format_stats_section(self, stats: Dict[str, Any], title: str) -> str:
        """Render one window's statistics as a titled text section.

        Args:
            stats: Result of ``_collect_statistics_for_period``.
            title: Section heading; it is underlined with ``=`` characters.

        Returns:
            The section text. Token, cost and per-model/per-type lines are
            only included when at least one request was recorded.
        """
        output = []
        output.append(f"\n{title}")
        output.append("=" * len(title))

        output.append(f"总请求数: {stats['total_requests']}")
        if stats['total_requests'] > 0:
            output.append(f"总Token数: {stats['total_tokens']}")
            output.append(f"总花费: ¥{stats['total_cost']:.4f}")

            output.append("\n按模型统计:")
            for model_name, count in sorted(stats["requests_by_model"].items()):
                cost = stats["costs_by_model"][model_name]
                output.append(f"- {model_name}: {count}次 (花费: ¥{cost:.4f})")

            output.append("\n按请求类型统计:")
            for req_type, count in sorted(stats["requests_by_type"].items()):
                cost = stats["costs_by_type"][req_type]
                output.append(f"- {req_type}: {count}次 (花费: ¥{cost:.4f})")

        return "\n".join(output)

    def _save_statistics(self, all_stats: Dict[str, Dict[str, Any]]):
        """Write the full report (all windows) to ``self.output_file``."""
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        output = []
        output.append(f"LLM请求统计报告 (生成时间: {current_time})")
        output.append("=" * 50)

        # One section per reporting window, widest window first.
        sections = [
            ("所有时间统计", "all_time"),
            ("最近7天统计", "last_7_days"),
            ("最近24小时统计", "last_24_hours"),
            ("最近1小时统计", "last_hour")
        ]

        for title, key in sections:
            output.append(self._format_stats_section(all_stats[key], title))

        with open(self.output_file, "w", encoding="utf-8") as f:
            f.write("\n".join(output))

    def _stats_loop(self):
        """Worker loop: regenerate the report, then wait about one minute."""
        while self.running:
            try:
                all_stats = self._collect_all_statistics()
                self._save_statistics(all_stats)
            except Exception as e:
                # Best effort: a failed run must not kill the reporting thread.
                print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")

            # Sleep in one-second slices so stop() takes effect promptly.
            for _ in range(60):
                if not self.running:
                    break
                time.sleep(1)
class ChineseTypoGenerator:
    """Generate plausible Chinese typos from homophones, guided by pinyin and frequency."""

    def __init__(self,
                 error_rate=0.3,
                 min_freq=5,
                 tone_error_rate=0.2,
                 word_replace_rate=0.3,
                 max_freq_diff=200):
        """
        Initialise the typo generator and load its lookup tables.

        Args:
            error_rate: Probability of replacing a single character.
            min_freq: Minimum (normalised) frequency a replacement must have.
            tone_error_rate: Probability of also drawing wrong-tone candidates.
            word_replace_rate: Probability of attempting a whole-word replacement.
            max_freq_diff: Largest allowed frequency drop from original to typo.
        """
        self.error_rate = error_rate
        self.min_freq = min_freq
        self.tone_error_rate = tone_error_rate
        self.word_replace_rate = word_replace_rate
        self.max_freq_diff = max_freq_diff
        # Filled on first use by _load_word_frequency().
        self._word_frequency = None

        # Load lookup data.
        print("正在加载汉字数据库,请稍候...")
        self.pinyin_dict = self._create_pinyin_dict()
        self.char_frequency = self._load_or_create_char_frequency()

    def _load_or_create_char_frequency(self):
        """
        Load the character-frequency table, building and caching it if needed.

        The table is read from ``char_frequency.json`` in the working
        directory when that file exists. Otherwise it is derived from jieba's
        ``dict.txt`` (each word's frequency is added to every Chinese
        character it contains), scaled so the most frequent character is
        1000, and written to the cache file.
        """
        cache_file = Path("char_frequency.json")

        if cache_file.exists():
            with open(cache_file, 'r', encoding='utf-8') as f:
                return json.load(f)

        char_freq = defaultdict(int)
        dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')

        with open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 2:
                    # Blank or malformed line: nothing to count.
                    continue
                word, freq = parts[0], int(parts[1])
                for char in word:
                    if self._is_chinese_char(char):
                        char_freq[char] += freq

        # Normalise to the 0-1000 range.
        max_freq = max(char_freq.values())
        normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}

        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(normalized_freq, f, ensure_ascii=False, indent=2)

        return normalized_freq

    def _load_word_frequency(self):
        """Return jieba's word -> frequency table, parsing ``dict.txt`` only once."""
        cached = getattr(self, '_word_frequency', None)
        if cached is None:
            dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
            cached = {}
            with open(dict_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 2:
                        cached[parts[0]] = float(parts[1])
            self._word_frequency = cached
        return cached

    def _create_pinyin_dict(self):
        """
        Build the mapping from tone-numbered pinyin to the characters read that way.
        """
        # The range end is exclusive, so U+4E00 through U+9FFE are covered.
        chars = [chr(i) for i in range(0x4e00, 0x9fff)]
        pinyin_dict = defaultdict(list)

        for char in chars:
            try:
                py = pinyin(char, style=Style.TONE3)[0][0]
                pinyin_dict[py].append(char)
            except Exception:
                # Best effort: skip characters the pinyin library cannot handle.
                continue

        return pinyin_dict

    def _is_chinese_char(self, char):
        """
        Return True when *char* lies in the range U+4E00..U+9FFF.

        Values that cannot be compared with a string (e.g. ``None``) yield False.
        """
        try:
            return '\u4e00' <= char <= '\u9fff'
        except TypeError:
            return False

    def _get_pinyin(self, sentence):
        """
        Return ``(char, pinyin)`` pairs for every Chinese character in *sentence*.
        """
        result = []
        for char in sentence:
            # Skip whitespace and anything that is not a Chinese character.
            if char.isspace() or not self._is_chinese_char(char):
                continue
            # Pinyin with a trailing tone digit.
            py = pinyin(char, style=Style.TONE3)[0][0]
            result.append((char, py))

        return result

    def _get_similar_tone_pinyin(self, py):
        """
        Return *py* with a different tone.

        Empty input is returned unchanged. Pinyin without a trailing digit
        gets tone 1 appended. A trailing digit outside 1-4 is replaced by a
        random tone from 1-4; otherwise a random different tone is chosen.
        """
        if not py:
            return py

        # No trailing digit: neutral tone or another special case.
        if not py[-1].isdigit():
            return py + '1'

        base = py[:-1]
        tone = int(py[-1])

        # Neutral tone written as 5, or any other invalid tone digit.
        if tone not in [1, 2, 3, 4]:
            return base + str(random.choice([1, 2, 3, 4]))

        possible_tones = [1, 2, 3, 4]
        possible_tones.remove(tone)
        new_tone = random.choice(possible_tones)
        return base + str(new_tone)

    def _calculate_replacement_probability(self, orig_freq, target_freq):
        """
        Return the probability of accepting a replacement, from the frequency gap.

        A more frequent target is always accepted (1.0); a drop larger than
        ``max_freq_diff`` is never accepted (0.0); in between the probability
        is ``exp(-3 * drop / max_freq_diff)``.
        """
        if target_freq > orig_freq:
            return 1.0

        freq_diff = orig_freq - target_freq
        if freq_diff > self.max_freq_diff:
            return 0.0

        return math.exp(-3 * freq_diff / self.max_freq_diff)

    def _get_similar_frequency_chars(self, char, py, num_candidates=5):
        """
        Return up to *num_candidates* homophones of *char*, best candidates first.

        With probability ``tone_error_rate`` characters of a wrong-tone
        reading are considered too. Returns ``None`` when no candidate
        passes the frequency filters.
        """
        homophones = []

        # .get() is used because pinyin_dict is a defaultdict: indexing a
        # missing reading would insert an empty entry on every miss.
        if random.random() < self.tone_error_rate:
            wrong_tone_py = self._get_similar_tone_pinyin(py)
            homophones.extend(self.pinyin_dict.get(wrong_tone_py, []))

        homophones.extend(self.pinyin_dict.get(py, []))

        if not homophones:
            return None

        orig_freq = self.char_frequency.get(char, 0)

        # Drop the original character and anything rarer than min_freq.
        freq_diff = [(h, self.char_frequency.get(h, 0))
                     for h in homophones
                     if h != char and self.char_frequency.get(h, 0) >= self.min_freq]

        if not freq_diff:
            return None

        candidates_with_prob = []
        for h, freq in freq_diff:
            prob = self._calculate_replacement_probability(orig_freq, freq)
            if prob > 0:
                candidates_with_prob.append((h, prob))

        if not candidates_with_prob:
            return None

        candidates_with_prob.sort(key=lambda x: x[1], reverse=True)

        return [candidate for candidate, _ in candidates_with_prob[:num_candidates]]

    def _get_word_pinyin(self, word):
        """
        Return the list of tone-numbered pinyin items for *word*.
        """
        return [py[0] for py in pinyin(word, style=Style.TONE3)]

    def _get_char_pinyins(self, word, word_pinyin):
        """Return exactly one pinyin item per character of *word*.

        The word-level result is reused when it already has one item per
        character. The pinyin library can return a run of non-Chinese
        characters as a single item, which would misalign a ``zip`` with the
        characters; in that case each character is looked up on its own.
        """
        if len(word_pinyin) == len(word):
            return word_pinyin
        return [pinyin(char, style=Style.TONE3)[0][0] for char in word]

    def _segment_sentence(self, sentence):
        """
        Segment *sentence* with jieba and return the list of tokens.
        """
        return list(jieba.cut(sentence))

    def _get_word_homophones(self, word):
        """
        Return up to five dictionary words that sound exactly like *word*.

        Candidates must exist in jieba's dictionary, reach at least 10% of
        the original word's frequency, and have a combined score
        (70% word frequency, 30% mean character frequency) of at least
        ``min_freq``. Results are ordered by that score, best first.
        """
        if len(word) == 1:
            return []

        word_pinyin = self._get_word_pinyin(word)

        # Per-syllable candidate characters; give up if any syllable has none.
        candidates = []
        for py in word_pinyin:
            chars = self.pinyin_dict.get(py, [])
            if not chars:
                return []
            candidates.append(chars)

        import itertools
        all_combinations = itertools.product(*candidates)

        valid_words = self._load_word_frequency()

        original_word_freq = valid_words.get(word, 0)
        min_word_freq = original_word_freq * 0.1

        homophones = []
        for combo in all_combinations:
            new_word = ''.join(combo)
            if new_word != word and new_word in valid_words:
                new_word_freq = valid_words[new_word]
                if new_word_freq >= min_word_freq:
                    char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word)
                    combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
                    if combined_score >= self.min_freq:
                        homophones.append((new_word, combined_score))

        sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
        return [candidate for candidate, _ in sorted_homophones[:5]]

    def _try_replace_char(self, char, py, error_rate):
        """Try to replace one character; return its typo-info tuple or ``None``.

        The tuple is ``(char, typo_char, py, typo_py, orig_freq, typo_freq)``.
        Non-Chinese characters are never replaced.
        """
        if not self._is_chinese_char(char):
            return None
        if random.random() >= error_rate:
            return None

        similar_chars = self._get_similar_frequency_chars(char, py)
        if not similar_chars:
            return None

        typo_char = random.choice(similar_chars)
        typo_freq = self.char_frequency.get(typo_char, 0)
        orig_freq = self.char_frequency.get(char, 0)
        replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
        if random.random() >= replace_prob:
            return None

        typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
        return (char, typo_char, py, typo_py, orig_freq, typo_freq)

    def create_typo_sentence(self, sentence):
        """
        Create a copy of *sentence* containing homophone typos.

        Replacement is attempted at word level first and at character level
        otherwise.

        Args:
            sentence: The Chinese input sentence.

        Returns:
            typo_sentence: The sentence with typos applied.
            typo_info: List of ``(orig, typo, orig_py, typo_py, orig_freq,
                typo_freq)`` tuples, one per replacement.
        """
        result = []
        typo_info = []

        for word in self._segment_sentence(sentence):
            # Tokens without any Chinese character pass through untouched.
            if all(not self._is_chinese_char(c) for c in word):
                result.append(word)
                continue

            word_pinyin = self._get_word_pinyin(word)

            # Whole-word replacement.
            if len(word) > 1 and random.random() < self.word_replace_rate:
                word_homophones = self._get_word_homophones(word)
                if word_homophones:
                    typo_word = random.choice(word_homophones)
                    # Mean character frequency of each word, for reporting.
                    orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
                    typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)

                    result.append(typo_word)
                    typo_info.append((word, typo_word,
                                      ' '.join(word_pinyin),
                                      ' '.join(self._get_word_pinyin(typo_word)),
                                      orig_freq, typo_freq))
                    continue

            # Character-level replacement. Characters inside longer words are
            # replaced less often; for a one-character word the factor is 1.
            char_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
            char_pinyins = self._get_char_pinyins(word, word_pinyin)

            word_result = []
            for char, py in zip(word, char_pinyins):
                typo = self._try_replace_char(char, py, char_error_rate)
                if typo is None:
                    word_result.append(char)
                else:
                    word_result.append(typo[1])
                    typo_info.append(typo)
            result.append(''.join(word_result))

        return ''.join(result), typo_info

    @staticmethod
    def _split_tone(py):
        """Split a pinyin string into ``(base, tone)``; tone is '' without a digit."""
        if py and py[-1].isdigit():
            return py[:-1], py[-1]
        return py, ''

    def format_typo_info(self, typo_info):
        """
        Format typo information as human-readable text.

        Args:
            typo_info: List of tuples as returned by ``create_typo_sentence``.

        Returns:
            One line per replacement, or a fixed message when the list is empty.
        """
        if not typo_info:
            return "未生成错别字"

        result = []
        for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
            # Word-level entries carry space-separated pinyin.
            is_word = ' ' in orig_py
            if is_word:
                error_type = "整词替换"
            else:
                # Compare base and tone separately so neutral-tone pinyin
                # (no trailing digit) is classified correctly.
                orig_base, orig_tone = self._split_tone(orig_py)
                typo_base, typo_tone = self._split_tone(typo_py)
                tone_error = orig_base == typo_base and orig_tone != typo_tone
                error_type = "声调错误" if tone_error else "同音字替换"

            result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
                          f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")

        return "\n".join(result)

    def set_params(self, **kwargs):
        """
        Update generator parameters.

        Settable parameters:
            error_rate: Probability of replacing a single character.
            min_freq: Minimum frequency a replacement must have.
            tone_error_rate: Probability of drawing wrong-tone candidates.
            word_replace_rate: Probability of attempting a whole-word replacement.
            max_freq_diff: Largest allowed frequency drop.

        Any existing attribute name is accepted; unknown names only print a warning.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                print(f"参数 {key} 已设置为 {value}")
            else:
                print(f"警告: 参数 {key} 不存在")
def main():
    """Interactive demo: read a sentence, print its typo version and the timing."""
    # Build the generator with the demo settings.
    generator_options = {
        "error_rate": 0.03,
        "min_freq": 7,
        "tone_error_rate": 0.02,
        "word_replace_rate": 0.3,
    }
    typo_generator = ChineseTypoGenerator(**generator_options)

    # Ask the user for a sentence.
    sentence = input("请输入中文句子:")

    # Only the generation and the printing below are timed.
    started_at = time.time()
    typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)

    print("\n原句:", sentence)
    print("错字版:", typo_sentence)

    # Replacement details are shown only when something was replaced.
    if typo_info:
        print("\n错别字信息:")
        print(typo_generator.format_typo_info(typo_info))

    total_time = time.time() - started_at
    print(f"\n总耗时:{total_time:.2f}秒")


if __name__ == "__main__":
    main()
保持标点符号和空格不变 +错别字生成器 - 基于拼音和字频的中文错别字生成工具 """ from pypinyin import pinyin, Style from collections import defaultdict import json import os -import unicodedata import jieba -import jieba.posseg as pseg from pathlib import Path import random import math import time -def load_or_create_char_frequency(): - """ - 加载或创建汉字频率字典 - """ - cache_file = Path("char_frequency.json") - - # 如果缓存文件存在,直接加载 - if cache_file.exists(): - with open(cache_file, 'r', encoding='utf-8') as f: - return json.load(f) - - # 使用内置的词频文件 - char_freq = defaultdict(int) - dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') - - # 读取jieba的词典文件 - with open(dict_path, 'r', encoding='utf-8') as f: - for line in f: - word, freq = line.strip().split()[:2] - # 对词中的每个字进行频率累加 - for char in word: - if is_chinese_char(char): - char_freq[char] += int(freq) - - # 归一化频率值 - max_freq = max(char_freq.values()) - normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} - - # 保存到缓存文件 - with open(cache_file, 'w', encoding='utf-8') as f: - json.dump(normalized_freq, f, ensure_ascii=False, indent=2) - - return normalized_freq - -# 创建拼音到汉字的映射字典 -def create_pinyin_dict(): - """ - 创建拼音到汉字的映射字典 - """ - # 常用汉字范围 - chars = [chr(i) for i in range(0x4e00, 0x9fff)] - pinyin_dict = defaultdict(list) - - # 为每个汉字建立拼音映射 - for char in chars: - try: - py = pinyin(char, style=Style.TONE3)[0][0] - pinyin_dict[py].append(char) - except Exception: - continue - - return pinyin_dict - -def is_chinese_char(char): - """ - 判断是否为汉字 - """ - try: - return '\u4e00' <= char <= '\u9fff' - except: - return False - -def get_pinyin(sentence): - """ - 将中文句子拆分成单个汉字并获取其拼音 - :param sentence: 输入的中文句子 - :return: 每个汉字及其拼音的列表 - """ - # 将句子拆分成单个字符 - characters = list(sentence) - - # 获取每个字符的拼音 - result = [] - for char in characters: - # 跳过空格和非汉字字符 - if char.isspace() or not is_chinese_char(char): - continue - # 获取拼音(数字声调) - py = pinyin(char, style=Style.TONE3)[0][0] - result.append((char, py)) - - return result - -def 
get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5): - """ - 获取同音字,按照使用频率排序 - """ - homophones = pinyin_dict[py] - # 移除原字并过滤低频字 - if char in homophones: - homophones.remove(char) - - # 过滤掉低频字 - homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq] - - # 按照字频排序 - sorted_homophones = sorted(homophones, - key=lambda x: char_frequency.get(x, 0), - reverse=True) - - # 只返回前10个同音字,避免输出过多 - return sorted_homophones[:10] - -def get_similar_tone_pinyin(py): - """ - 获取相似声调的拼音 - 例如:'ni3' 可能返回 'ni2' 或 'ni4' - 处理特殊情况: - 1. 轻声(如 'de5' 或 'le') - 2. 非数字结尾的拼音 - """ - # 检查拼音是否为空或无效 - if not py or len(py) < 1: - return py +class ChineseTypoGenerator: + def __init__(self, + error_rate=0.3, + min_freq=5, + tone_error_rate=0.2, + word_replace_rate=0.3, + max_freq_diff=200): + """ + 初始化错别字生成器 - # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 - if not py[-1].isdigit(): - # 为非数字结尾的拼音添加数字声调1 - return py + '1' - - base = py[:-1] # 去掉声调 - tone = int(py[-1]) # 获取声调 - - # 处理轻声(通常用5表示)或无效声调 - if tone not in [1, 2, 3, 4]: - return base + str(random.choice([1, 2, 3, 4])) - - # 正常处理声调 - possible_tones = [1, 2, 3, 4] - possible_tones.remove(tone) # 移除原声调 - new_tone = random.choice(possible_tones) # 随机选择一个新声调 - return base + str(new_tone) - -def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200): - """ - 根据频率差计算替换概率 - 频率差越大,概率越低 - :param orig_freq: 原字频率 - :param target_freq: 目标字频率 - :param max_freq_diff: 最大允许的频率差 - :return: 0-1之间的概率值 - """ - if target_freq > orig_freq: - return 1.0 # 如果替换字频率更高,保持原有概率 - - freq_diff = orig_freq - target_freq - if freq_diff > max_freq_diff: - return 0.0 # 频率差太大,不替换 - - # 使用指数衰减函数计算概率 - # 频率差为0时概率为1,频率差为max_freq_diff时概率接近0 - return math.exp(-3 * freq_diff / max_freq_diff) - -def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2): - """ - 获取与给定字频率相近的同音字,可能包含声调错误 - """ - homophones = [] - - # 有20%的概率使用错误声调 - if random.random() < tone_error_rate: - wrong_tone_py = 
get_similar_tone_pinyin(py) - homophones.extend(pinyin_dict[wrong_tone_py]) - - # 添加正确声调的同音字 - homophones.extend(pinyin_dict[py]) - - if not homophones: - return None + 参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + self.error_rate = error_rate + self.min_freq = min_freq + self.tone_error_rate = tone_error_rate + self.word_replace_rate = word_replace_rate + self.max_freq_diff = max_freq_diff - # 获取原字的频率 - orig_freq = char_frequency.get(char, 0) + # 加载数据 + print("正在加载汉字数据库,请稍候...") + self.pinyin_dict = self._create_pinyin_dict() + self.char_frequency = self._load_or_create_char_frequency() - # 计算所有同音字与原字的频率差,并过滤掉低频字 - freq_diff = [(h, char_frequency.get(h, 0)) - for h in homophones - if h != char and char_frequency.get(h, 0) >= min_freq] - - if not freq_diff: - return None - - # 计算每个候选字的替换概率 - candidates_with_prob = [] - for h, freq in freq_diff: - prob = calculate_replacement_probability(orig_freq, freq) - if prob > 0: # 只保留有效概率的候选字 - candidates_with_prob.append((h, prob)) - - if not candidates_with_prob: - return None - - # 根据概率排序 - candidates_with_prob.sort(key=lambda x: x[1], reverse=True) - - # 返回概率最高的几个字 - return [char for char, _ in candidates_with_prob[:num_candidates]] - -def get_word_pinyin(word): - """ - 获取词语的拼音列表 - """ - return [py[0] for py in pinyin(word, style=Style.TONE3)] - -def segment_sentence(sentence): - """ - 使用jieba分词,返回词语列表 - """ - return list(jieba.cut(sentence)) - -def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5): - """ - 获取整个词的同音词,只返回高频的有意义词语 - :param word: 输入词语 - :param pinyin_dict: 拼音字典 - :param char_frequency: 字频字典 - :param min_freq: 最小频率阈值 - :return: 同音词列表 - """ - if len(word) == 1: - return [] + def _load_or_create_char_frequency(self): + """ + 加载或创建汉字频率字典 + """ + cache_file = Path("char_frequency.json") - # 获取词的拼音 - word_pinyin = get_word_pinyin(word) - word_pinyin_str = ''.join(word_pinyin) - - # 创建词语频率字典 - word_freq = 
defaultdict(float) - - # 遍历所有可能的同音字组合 - candidates = [] - for py in word_pinyin: - chars = pinyin_dict.get(py, []) - if not chars: - return [] - candidates.append(chars) - - # 生成所有可能的组合 - import itertools - all_combinations = itertools.product(*candidates) - - # 获取jieba词典和词频信息 - dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') - valid_words = {} # 改用字典存储词语及其频率 - with open(dict_path, 'r', encoding='utf-8') as f: - for line in f: - parts = line.strip().split() - if len(parts) >= 2: - word_text = parts[0] - word_freq = float(parts[1]) # 获取词频 - valid_words[word_text] = word_freq - - # 获取原词的词频作为参考 - original_word_freq = valid_words.get(word, 0) - min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% - - # 过滤和计算频率 - homophones = [] - for combo in all_combinations: - new_word = ''.join(combo) - if new_word != word and new_word in valid_words: - new_word_freq = valid_words[new_word] - # 只保留词频达到阈值的词 - if new_word_freq >= min_word_freq: - # 计算词的平均字频(考虑字频和词频) - char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word) - # 综合评分:结合词频和字频 - combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) - if combined_score >= min_freq: - homophones.append((new_word, combined_score)) - - # 按综合分数排序并限制返回数量 - sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) - return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 - -def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3): - """ - 创建包含同音字错误的句子,支持词语级别和字级别的替换 - 只使用高频的有意义词语进行替换 - """ - result = [] - typo_info = [] - - # 分词 - words = segment_sentence(sentence) - - for word in words: - # 如果是标点符号或空格,直接添加 - if all(not is_chinese_char(c) for c in word): - result.append(word) - continue - - # 获取词语的拼音 - word_pinyin = get_word_pinyin(word) + # 如果缓存文件存在,直接加载 + if cache_file.exists(): + with open(cache_file, 'r', encoding='utf-8') as f: + return json.load(f) - # 尝试整词替换 - if len(word) > 1 and 
random.random() < word_replace_rate: - word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq) - if word_homophones: - typo_word = random.choice(word_homophones) - # 计算词的平均频率 - orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word) - typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word) - - # 添加到结果中 - result.append(typo_word) - typo_info.append((word, typo_word, - ' '.join(word_pinyin), - ' '.join(get_word_pinyin(typo_word)), - orig_freq, typo_freq)) + # 使用内置的词频文件 + char_freq = defaultdict(int) + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + + # 读取jieba的词典文件 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, freq = line.strip().split()[:2] + # 对词中的每个字进行频率累加 + for char in word: + if self._is_chinese_char(char): + char_freq[char] += int(freq) + + # 归一化频率值 + max_freq = max(char_freq.values()) + normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} + + # 保存到缓存文件 + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(normalized_freq, f, ensure_ascii=False, indent=2) + + return normalized_freq + + def _create_pinyin_dict(self): + """ + 创建拼音到汉字的映射字典 + """ + # 常用汉字范围 + chars = [chr(i) for i in range(0x4e00, 0x9fff)] + pinyin_dict = defaultdict(list) + + # 为每个汉字建立拼音映射 + for char in chars: + try: + py = pinyin(char, style=Style.TONE3)[0][0] + pinyin_dict[py].append(char) + except Exception: continue - # 如果不进行整词替换,则进行单字替换 + return pinyin_dict + + def _is_chinese_char(self, char): + """ + 判断是否为汉字 + """ + try: + return '\u4e00' <= char <= '\u9fff' + except: + return False + + def _get_pinyin(self, sentence): + """ + 将中文句子拆分成单个汉字并获取其拼音 + """ + # 将句子拆分成单个字符 + characters = list(sentence) + + # 获取每个字符的拼音 + result = [] + for char in characters: + # 跳过空格和非汉字字符 + if char.isspace() or not self._is_chinese_char(char): + continue + # 获取拼音(数字声调) + py = pinyin(char, style=Style.TONE3)[0][0] + result.append((char, py)) + + return result + 
def _get_similar_tone_pinyin(self, py):
    """
    Return *py* carrying a different tone digit.

    Falsy input comes back unchanged. When the last character is not a
    digit, tone 1 is appended. A digit outside 1-4 is swapped for a random
    tone from 1-4; a valid tone is swapped for a random different one.
    """
    if not py:
        return py

    tone_mark = py[-1]
    if not tone_mark.isdigit():
        # No tone digit (neutral tone or special case): append tone 1.
        return py + '1'

    stem = py[:-1]
    tone = int(tone_mark)
    valid_tones = [1, 2, 3, 4]

    if tone not in valid_tones:
        # Neutral tone written as 5, or an otherwise invalid digit.
        return stem + str(random.choice(valid_tones))

    # Pick any valid tone except the current one.
    alternatives = [t for t in valid_tones if t != tone]
    return stem + str(random.choice(alternatives))
candidates_with_prob[:num_candidates]] + + def _get_word_pinyin(self, word): + """ + 获取词语的拼音列表 + """ + return [py[0] for py in pinyin(word, style=Style.TONE3)] + + def _segment_sentence(self, sentence): + """ + 使用jieba分词,返回词语列表 + """ + return list(jieba.cut(sentence)) + + def _get_word_homophones(self, word): + """ + 获取整个词的同音词,只返回高频的有意义词语 + """ if len(word) == 1: - char = word - py = word_pinyin[0] - if random.random() < error_rate: - similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, - min_freq=min_freq, tone_error_rate=tone_error_rate) - if similar_chars: - typo_char = random.choice(similar_chars) - typo_freq = char_frequency.get(typo_char, 0) - orig_freq = char_frequency.get(char, 0) - replace_prob = calculate_replacement_probability(orig_freq, typo_freq) - if random.random() < replace_prob: - result.append(typo_char) - typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] - typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) - continue - result.append(char) - else: - # 处理多字词的单字替换 - word_result = [] - for i, (char, py) in enumerate(zip(word, word_pinyin)): - # 词中的字替换概率降低 - word_error_rate = error_rate * (0.7 ** (len(word) - 1)) + return [] + + # 获取词的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 遍历所有可能的同音字组合 + candidates = [] + for py in word_pinyin: + chars = self.pinyin_dict.get(py, []) + if not chars: + return [] + candidates.append(chars) + + # 生成所有可能的组合 + import itertools + all_combinations = itertools.product(*candidates) + + # 获取jieba词典和词频信息 + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + valid_words = {} # 改用字典存储词语及其频率 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 2: + word_text = parts[0] + word_freq = float(parts[1]) # 获取词频 + valid_words[word_text] = word_freq + + # 获取原词的词频作为参考 + original_word_freq = valid_words.get(word, 0) + min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% + + # 过滤和计算频率 + 
homophones = [] + for combo in all_combinations: + new_word = ''.join(combo) + if new_word != word and new_word in valid_words: + new_word_freq = valid_words[new_word] + # 只保留词频达到阈值的词 + if new_word_freq >= min_word_freq: + # 计算词的平均字频(考虑字频和词频) + char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word) + # 综合评分:结合词频和字频 + combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) + if combined_score >= self.min_freq: + homophones.append((new_word, combined_score)) + + # 按综合分数排序并限制返回数量 + sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) + return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 + + def create_typo_sentence(self, sentence): + """ + 创建包含同音字错误的句子,支持词语级别和字级别的替换 + + 参数: + sentence: 输入的中文句子 + + 返回: + typo_sentence: 包含错别字的句子 + typo_info: 错别字信息列表 + """ + result = [] + typo_info = [] + + # 分词 + words = self._segment_sentence(sentence) + + for word in words: + # 如果是标点符号或空格,直接添加 + if all(not self._is_chinese_char(c) for c in word): + result.append(word) + continue - if random.random() < word_error_rate: - similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, - min_freq=min_freq, tone_error_rate=tone_error_rate) + # 获取词语的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 尝试整词替换 + if len(word) > 1 and random.random() < self.word_replace_rate: + word_homophones = self._get_word_homophones(word) + if word_homophones: + typo_word = random.choice(word_homophones) + # 计算词的平均频率 + orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word) + typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word) + + # 添加到结果中 + result.append(typo_word) + typo_info.append((word, typo_word, + ' '.join(word_pinyin), + ' '.join(self._get_word_pinyin(typo_word)), + orig_freq, typo_freq)) + continue + + # 如果不进行整词替换,则进行单字替换 + if len(word) == 1: + char = word + py = word_pinyin[0] + if random.random() < self.error_rate: + similar_chars = 
self._get_similar_frequency_chars(char, py) if similar_chars: typo_char = random.choice(similar_chars) - typo_freq = char_frequency.get(typo_char, 0) - orig_freq = char_frequency.get(char, 0) - replace_prob = calculate_replacement_probability(orig_freq, typo_freq) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) if random.random() < replace_prob: - word_result.append(typo_char) + result.append(typo_char) typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) continue - word_result.append(char) - result.append(''.join(word_result)) - - return ''.join(result), typo_info + result.append(char) + else: + # 处理多字词的单字替换 + word_result = [] + for i, (char, py) in enumerate(zip(word, word_pinyin)): + # 词中的字替换概率降低 + word_error_rate = self.error_rate * (0.7 ** (len(word) - 1)) + + if random.random() < word_error_rate: + similar_chars = self._get_similar_frequency_chars(char, py) + if similar_chars: + typo_char = random.choice(similar_chars) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) + if random.random() < replace_prob: + word_result.append(typo_char) + typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] + typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) + continue + word_result.append(char) + result.append(''.join(word_result)) + + return ''.join(result), typo_info -def format_frequency(freq): - """ - 格式化频率显示 - """ - return f"{freq:.2f}" - -def main(): - # 记录开始时间 - start_time = time.time() - - # 首先创建拼音字典和加载字频统计 - print("正在加载汉字数据库,请稍候...") - pinyin_dict = create_pinyin_dict() - char_frequency = load_or_create_char_frequency() - - # 获取用户输入 - sentence = input("请输入中文句子:") - - # 创建包含错别字的句子 - typo_sentence, typo_info = 
create_typo_sentence(sentence, pinyin_dict, char_frequency, - error_rate=0.3, min_freq=5, - tone_error_rate=0.2, word_replace_rate=0.3) - - # 打印结果 - print("\n原句:", sentence) - print("错字版:", typo_sentence) - - if typo_info: - print("\n错别字信息:") + def format_typo_info(self, typo_info): + """ + 格式化错别字信息 + + 参数: + typo_info: 错别字信息列表 + + 返回: + 格式化后的错别字信息字符串 + """ + if not typo_info: + return "未生成错别字" + + result = [] for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info: # 判断是否为词语替换 is_word = ' ' in orig_py @@ -459,25 +380,53 @@ def main(): tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1] error_type = "声调错误" if tone_error else "同音字替换" - print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> " - f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]") + result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> " + f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]") + + return "\n".join(result) - # 获取拼音结果 - result = get_pinyin(sentence) + def set_params(self, **kwargs): + """ + 设置参数 + + 可设置参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + print(f"参数 {key} 已设置为 {value}") + else: + print(f"警告: 参数 {key} 不存在") + +def main(): + # 创建错别字生成器实例 + typo_generator = ChineseTypoGenerator( + error_rate=0.03, + min_freq=7, + tone_error_rate=0.02, + word_replace_rate=0.3 + ) - # 打印完整拼音 - print("\n完整拼音:") - print(" ".join(py for _, py in result)) + # 获取用户输入 + sentence = input("请输入中文句子:") - # 打印词语分析 - print("\n词语分析:") - words = segment_sentence(sentence) - for word in words: - if any(is_chinese_char(c) for c in word): - word_pinyin = get_word_pinyin(word) - print(f"词语:{word}") - print(f"拼音:{' '.join(word_pinyin)}") - print("---") + # 创建包含错别字的句子 + start_time = time.time() + typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence) 
+ + # 打印结果 + print("\n原句:", sentence) + print("错字版:", typo_sentence) + + # 打印错别字信息 + if typo_info: + print("\n错别字信息:") + print(typo_generator.format_typo_info(typo_info)) # 计算并打印总耗时 end_time = time.time() diff --git a/template.env b/template.env index d70bba206..09fe63597 100644 --- a/template.env +++ b/template.env @@ -5,7 +5,7 @@ PORT=8080 PLUGINS=["src2.plugins.chat"] # 默认配置 -MONGODB_HOST=127.0.0.1 +MONGODB_HOST=127.0.0.1 # 如果工作在Docker下,请改成 MONGODB_HOST=mongodb MONGODB_PORT=27017 DATABASE_NAME=MegBot diff --git a/config/auto_format.py b/template/auto_format.py similarity index 99% rename from config/auto_format.py rename to template/auto_format.py index 9bc27da29..d99e29e34 100644 --- a/config/auto_format.py +++ b/template/auto_format.py @@ -1,8 +1,10 @@ -import tomli -import tomli_w +import os import sys from pathlib import Path -import os + +import tomli +import tomli_w + def sync_configs(): # 读取两个配置文件 diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml new file mode 100644 index 000000000..59d053cc1 --- /dev/null +++ b/template/bot_config_template.toml @@ -0,0 +1,143 @@ +[bot] +qq = 123 +nickname = "麦麦" + +[personality] +prompt_personality = [ + "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # 贴吧人格 + "是一个女大学生,你有黑色头发,你会刷小红书", # 小红书人格 + "是一个女大学生,你会刷b站,对ACG文化感兴趣" # b站人格 + ] +personality_1_probability = 0.6 # 第一种人格出现概率 +personality_2_probability = 0.3 # 第二种人格出现概率 +personality_3_probability = 0.1 # 第三种人格出现概率,请确保三个概率相加等于1 +prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + +[message] +min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息 +max_context_size = 15 # 麦麦获得的上文数量 +emoji_chance = 0.2 # 麦麦使用表情包的概率 +thinking_timeout = 10 # 麦麦思考时间 + +response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1 +response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数 +down_frequency_rate = 3.5 # 降低回复频率的群组回复意愿降低系数 +ban_words = [ + # "403","张三" + ] + +[emoji] +check_interval = 120 # 检查表情包的时间间隔 +register_interval = 10 # 注册表情包的时间间隔 +auto_save = true # 自动偷表情包 
+enable_check = false # 是否启用表情包过滤 +check_prompt = "符合公序良俗" # 表情包过滤要求 + +[cq_code] +enable_pic_translate = false + +[response] +model_r1_probability = 0.8 # 麦麦回答时选择主要回复模型1 模型的概率 +model_v3_probability = 0.1 # 麦麦回答时选择次要回复模型2 模型的概率 +model_r1_distill_probability = 0.1 # 麦麦回答时选择次要回复模型3 模型的概率 +max_response_length = 1024 # 麦麦回答的最大token数 + +[memory] +build_memory_interval = 300 # 记忆构建间隔 单位秒 +forget_memory_interval = 300 # 记忆遗忘间隔 单位秒 + +[mood] +mood_update_interval = 1.0 # 情绪更新间隔 单位秒 +mood_decay_rate = 0.95 # 情绪衰减率 +mood_intensity_factor = 1.0 # 情绪强度因子 + +[others] +enable_advance_output = true # 是否启用高级输出 +enable_kuuki_read = true # 是否启用读空气功能 + +[groups] +talk_allowed = [ + 123, + 123, +] #可以回复消息的群 +talk_frequency_down = [] #降低回复频率的群 +ban_user_id = [] #禁止回复消息的QQ号 + + +#V3 +#name = "deepseek-chat" +#base_url = "DEEP_SEEK_BASE_URL" +#key = "DEEP_SEEK_KEY" + +#R1 +#name = "deepseek-reasoner" +#base_url = "DEEP_SEEK_BASE_URL" +#key = "DEEP_SEEK_KEY" + +#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env.prod自定义的宏,使用自定义模型则选择定位相似的模型自己填写 + +#推理模型: + +[model.llm_reasoning] #回复模型1 主要回复模型 +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +pri_in = 0 #模型的输入价格(非必填,可以记录消耗) +pri_out = 0 #模型的输出价格(非必填,可以记录消耗) + +[model.llm_reasoning_minor] #回复模型3 次要回复模型 +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +#非推理模型 + +[model.llm_normal] #V3 回复模型2 次要回复模型 +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal_minor] #V2.5 +name = "deepseek-ai/DeepSeek-V2.5" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_emotion_judge] #主题判断 0.7/m +name = "Qwen/Qwen2.5-14B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_topic_judge] #主题判断:建议使用qwen2.5 7b +name = "Pro/Qwen/Qwen2.5-7B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_summary_by_topic] #建议使用qwen2.5 32b 
及以上 +name = "Qwen/Qwen2.5-32B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +pri_in = 0 +pri_out = 0 + +[model.moderation] #内容审核 未启用 +name = "" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +pri_in = 0 +pri_out = 0 + +# 识图模型 + +[model.vlm] #图像识别 0.35/m +name = "Pro/Qwen/Qwen2-VL-7B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + + + +#嵌入模型 + +[model.embedding] #嵌入 +name = "BAAI/bge-m3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY"