diff --git a/.envrc b/.envrc
new file mode 100644
index 000000000..8392d159f
--- /dev/null
+++ b/.envrc
@@ -0,0 +1 @@
+use flake
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 51a11d8c2..38deb3666 100644
--- a/.gitignore
+++ b/.gitignore
@@ -188,3 +188,10 @@ cython_debug/
 
 # jieba
 jieba.cache
+
+
+# vscode
+/.vscode
+
+# direnv
+/.direnv
\ No newline at end of file
diff --git a/README.md b/README.md
index 7bfa465ae..04cfc0772 100644
--- a/README.md
+++ b/README.md
@@ -13,16 +13,19 @@
 
 **🍔麦麦是一个基于大语言模型的智能QQ群聊机器人**
 
-- 🤖 基于 nonebot2 框架开发
-- 🧠 LLM 提供对话能力
-- 💾 MongoDB 提供数据持久化支持
-- 🐧 NapCat 作为QQ协议端支持
+- 基于 nonebot2 框架开发
+- LLM 提供对话能力
+- MongoDB 提供数据持久化支持
+- NapCat 作为QQ协议端支持
+
+**最新版本: v0.5.***
+麦麦演示视频
+👆 点击观看麦麦演示视频 👆
+
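> 编者补充:结合上面的组件说明,下面给出一个极简的部署自检示意(非仓库代码,假设 MongoDB 按下文默认配置跑在 127.0.0.1:27017,pymongo 为仓库既有依赖):

```python
# 示意脚本(非仓库代码):先确认本机 MongoDB 已就绪,再继续部署麦麦
from pymongo import MongoClient
from pymongo.errors import PyMongoError

client = MongoClient("127.0.0.1", 27017, serverSelectionTimeoutMS=3000)
try:
    client.admin.command("ping")  # ping 成功说明 MongoDB 服务可用
    print("MongoDB 连接正常")
except PyMongoError as e:
    print(f"MongoDB 未就绪,请先安装并启动服务: {e}")
```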
@@ -31,12 +34,31 @@
 
 > - 文档未完善,有问题可以提交 Issue 或者 Discussion
 > - QQ机器人存在被限制风险,请自行了解,谨慎使用
 > - 由于持续迭代,可能存在一些已知或未知的bug
+> - 由于开发中,可能消耗较多token
 
-**交流群**: 766798517(仅用于开发和建议相关讨论)不建议在群内询问部署问题,我不一定有空回复,会优先写文档和代码
+**交流群**: 766798517(仅用于开发和建议相关讨论)不一定有空回复,但大家可以自行交流部署问题,我会优先写文档和代码
 
-## 📚 文档
+
+## 📚 文档 ⬇️ 快速开始使用麦麦 ⬇️
+
+### 部署方式
+
+如果你不知道Docker是什么,建议寻找相关教程或使用手动部署
+
+- [🐳 Docker部署指南](docs/docker_deploy.md)
+
+- [📦 手动部署指南](docs/manual_deploy.md)
+
+### 配置说明
+- [🎀 新手配置指南](docs/installation_cute.md) - 通俗易懂的配置教程,适合初次使用的猫娘
+- [⚙️ 标准配置指南](docs/installation_standard.md) - 简明专业的配置说明,适合有经验的用户
+
+## 了解麦麦
+
-- [安装与配置指南](docs/installation.md) - 详细的部署和配置说明
 - [项目架构说明](docs/doc1.md) - 项目结构和核心功能实现细节
 
 ## 🎯 功能介绍
 
@@ -69,6 +91,12 @@
 
 ## 开发计划TODO:LIST
+
+规划主线
+0.6.0:记忆系统更新
+0.7.0:麦麦RunTime
+
+
 - 人格功能:WIP
 - 群氛围功能:WIP
 - 图片发送,转发功能:WIP
@@ -86,10 +114,21 @@
 - 改进表情包发送逻辑
 - 自动生成的回复逻辑,例如自生成的回复方向,回复风格
 - 采用截断生成加快麦麦的反应速度
-- 改进发送消息的触发:
+- 改进发送消息的触发
+
+
+## 设计理念
+
+- **千石可乐说:**
+- 这个项目最初只是为了给牛牛bot添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。不过,目的并不是写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。
+- 程序的功能设计理念基于一个核心的原则:"最像而不是好"
+- 主打一个陪伴
+- 如果人类真的需要一个AI来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的helpful assistant,而是一个会犯错的,拥有自己感知和想法的"生命形式"。
+- 代码会保持开源和开放,但个人希望MaiMBot的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体,才更能让你感觉到它的自主性,而不是把它看作一个对话机器。
+
 ## 📌 注意事项
 
-纯编程外行,面向cursor编程,很多代码史一样多多包涵
+SengokuCola纯编程外行,面向cursor编程,很多代码史一样多多包涵
 
 > ⚠️ **警告**:本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本人观点和立场。
 
diff --git a/bot.py b/bot.py
index 50c8cfaa4..1bebfe553 100644
--- a/bot.py
+++ b/bot.py
@@ -8,7 +8,7 @@
 from loguru import logger
 from colorama import init, Fore
 init()
-text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
+text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
 rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
 rainbow_text = ""
 for i, char in enumerate(text):
@@ -17,11 +17,11 @@
 print(rainbow_text)
 '''彩蛋'''
 # 初次启动检测
-if not os.path.exists("config/bot_config.toml") or not os.path.exists(".env"):
-    logger.info("检测到bot_config.toml不存在,正在从模板复制")
+if not os.path.exists("config/bot_config.toml"):
+    logger.warning("检测到bot_config.toml不存在,正在从模板复制")
     import shutil
-    shutil.copy("config/bot_config_template.toml", "config/bot_config.toml")
+    shutil.copy("templete/bot_config_template.toml", "config/bot_config.toml")
     logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动")
 
 # 初始化.env 默认ENVIRONMENT=prod
diff --git a/docs/doc1.md b/docs/doc1.md
index 34de628ed..158136b9c 100644
--- a/docs/doc1.md
+++ b/docs/doc1.md
@@ -83,7 +83,6 @@
 14. **`topic_identifier.py`**:
     - 识别消息中的主题,帮助机器人理解用户的意图。
-    - 使用多种方法(LLM、jieba、snownlp)进行主题识别。
 
 15. **`utils.py`** 和 **`utils_*.py`** 系列文件:
     - 存放各种工具函数,提供辅助功能以支持其他模块。
diff --git a/docs/docker_deploy.md b/docs/docker_deploy.md
new file mode 100644
index 000000000..c9b069309
--- /dev/null
+++ b/docs/docker_deploy.md
@@ -0,0 +1,24 @@
+# 🐳 Docker 部署指南
+
+## 部署步骤(推荐,但不一定是最新)
+
+1. 获取配置文件:
+```bash
+wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
+```
+
+2. 启动服务:
+```bash
+NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
+```
+
+3. 修改配置后重启:
+```bash
+NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
+```
+
+## ⚠️ 注意事项
+
+- 目前部署方案仍在测试中,可能存在未知问题
+- 配置文件中的API密钥请妥善保管,不要泄露
+- 建议先在测试环境中运行,确认无误后再部署到生产环境
\ No newline at end of file
diff --git a/docs/installation.md b/docs/installation.md
deleted file mode 100644
index c988eb7c9..000000000
--- a/docs/installation.md
+++ /dev/null
@@ -1,145 +0,0 @@
-# 🔧 安装与配置指南
-
-## 部署方式
-
-如果你不知道Docker是什么,建议寻找相关教程或使用手动部署
-
-### 🐳 Docker部署(推荐,但不一定是最新)
-
-1. 获取配置文件:
-```bash
-wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
-```
-
-2. 启动服务:
-```bash
-NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
-```
-
-3. 修改配置后重启:
-```bash
-NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
-```
-
-### 📦 手动部署
-
-1. **环境准备**
-```bash
-# 创建虚拟环境(推荐)
-python -m venv venv
-venv\\Scripts\\activate  # Windows
-# 安装依赖
-pip install -r requirements.txt
-```
-
-2. **配置MongoDB**
-- 安装并启动MongoDB服务
-- 默认连接本地27017端口
-
-3. 
**配置NapCat** -- 安装并登录NapCat -- 添加反向WS:`ws://localhost:8080/onebot/v11/ws` - -4. **配置文件设置** -- 修改环境配置文件:`.env.prod` -- 修改机器人配置文件:`bot_config.toml` - -5. **启动麦麦机器人** -- 打开命令行,cd到对应路径 -```bash -nb run -``` - -6. **其他组件** -- `run_thingking.bat`: 启动可视化推理界面(未完善) - -- ~~`knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库~~ -- 直接运行 knowledge.py生成知识库 - -## ⚙️ 配置说明 - -### 环境配置 (.env.prod) -```ini -# API配置,你可以在这里定义你的密钥和base_url -# 你可以选择定义其他服务商提供的KEY,完全可以自定义 -SILICONFLOW_KEY=your_key -SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ -DEEP_SEEK_KEY=your_key -DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 - -# 服务配置,如果你不知道这是什么,保持默认 -HOST=127.0.0.1 -PORT=8080 - -# 数据库配置,如果你不知道这是什么,保持默认 -MONGODB_HOST=127.0.0.1 -MONGODB_PORT=27017 -DATABASE_NAME=MegBot -``` - -### 机器人配置 (bot_config.toml) -```toml -[bot] -qq = "你的机器人QQ号" -nickname = "麦麦" - -[message] -min_text_length = 2 -max_context_size = 15 -emoji_chance = 0.2 - -[emoji] -check_interval = 120 -register_interval = 10 - -[cq_code] -enable_pic_translate = false - -[response] -#现已移除deepseek或硅基流动选项,可以直接切换分别配置任意模型 -model_r1_probability = 0.8 #推理模型权重 -model_v3_probability = 0.1 #非推理模型权重 -model_r1_distill_probability = 0.1 - -[memory] -build_memory_interval = 300 - -[others] -enable_advance_output = true # 是否启用详细日志输出 - -[groups] -talk_allowed = [] # 允许回复的群号列表 -talk_frequency_down = [] # 降低回复频率的群号列表 -ban_user_id = [] # 禁止回复的用户QQ号列表 - -[model.llm_reasoning] -name = "Pro/deepseek-ai/DeepSeek-R1" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_reasoning_minor] -name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal] -name = "Pro/deepseek-ai/DeepSeek-V3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.llm_normal_minor] -name = "deepseek-ai/DeepSeek-V2.5" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" - -[model.vlm] -name = "deepseek-ai/deepseek-vl2" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY" -``` - -## ⚠️ 注意事项 - -- 目前部署方案仍在测试中,可能存在未知问题 -- 配置文件中的API密钥请妥善保管,不要泄露 -- 建议先在测试环境中运行,确认无误后再部署到生产环境 \ No newline at end of file diff --git a/docs/installation_cute.md b/docs/installation_cute.md new file mode 100644 index 000000000..278cbfe20 --- /dev/null +++ b/docs/installation_cute.md @@ -0,0 +1,215 @@ +# 🔧 配置指南 喵~ + +## 👋 你好呀! + +让咱来告诉你我们要做什么喵: +1. 我们要一起设置一个可爱的AI机器人 +2. 这个机器人可以在QQ上陪你聊天玩耍哦 +3. 需要设置两个文件才能让机器人工作呢 + +## 📝 需要设置的文件喵 + +要设置这两个文件才能让机器人跑起来哦: +1. `.env.prod` - 这个文件告诉机器人要用哪些AI服务呢 +2. `bot_config.toml` - 这个文件教机器人怎么和你聊天喵 + +## 🔑 密钥和域名的对应关系 + +想象一下,你要进入一个游乐园,需要: +1. 知道游乐园的地址(这就是域名 base_url) +2. 
有入场的门票(这就是密钥 key) + +在 `.env.prod` 文件里,我们定义了三个游乐园的地址和门票喵: +```ini +# 硅基流动游乐园 +SILICONFLOW_KEY=your_key # 硅基流动的门票 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ # 硅基流动的地址 + +# DeepSeek游乐园 +DEEP_SEEK_KEY=your_key # DeepSeek的门票 +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 # DeepSeek的地址 + +# ChatAnyWhere游乐园 +CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere的门票 +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere的地址 +``` + +然后在 `bot_config.toml` 里,机器人会用这些门票和地址去游乐园玩耍: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" # 告诉机器人:去硅基流动游乐园玩 +key = "SILICONFLOW_KEY" # 用硅基流动的门票进去 + +[model.llm_normal] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" # 还是去硅基流动游乐园 +key = "SILICONFLOW_KEY" # 用同一张门票就可以啦 +``` + +### 🎪 举个例子喵: + +如果你想用DeepSeek官方的服务,就要这样改: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "DEEP_SEEK_BASE_URL" # 改成去DeepSeek游乐园 +key = "DEEP_SEEK_KEY" # 用DeepSeek的门票 + +[model.llm_normal] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "DEEP_SEEK_BASE_URL" # 也去DeepSeek游乐园 +key = "DEEP_SEEK_KEY" # 用同一张DeepSeek门票 +``` + +### 🎯 简单来说: +- `.env.prod` 文件就像是你的票夹,存放着各个游乐园的门票和地址 +- `bot_config.toml` 就是告诉机器人:用哪张票去哪个游乐园玩 +- 所有模型都可以用同一个游乐园的票,也可以去不同的游乐园玩耍 +- 如果用硅基流动的服务,就保持默认配置不用改呢~ + +记住:门票(key)要保管好,不能给别人看哦,不然别人就可以用你的票去玩了喵! + +## ---让我们开始吧--- + +### 第一个文件:环境配置 (.env.prod) + +这个文件就像是机器人的"身份证"呢,告诉它要用哪些AI服务喵~ + +```ini +# 这些是AI服务的密钥,就像是魔法钥匙一样呢 +# 要把 your_key 换成真正的密钥才行喵 +# 比如说:SILICONFLOW_KEY=sk-123456789abcdef +SILICONFLOW_KEY=your_key +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ +DEEP_SEEK_KEY=your_key +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +CHAT_ANY_WHERE_KEY=your_key +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 + +# 如果你不知道这是什么,那么下面这些不用改,保持原样就好啦 +HOST=127.0.0.1 +PORT=8080 + +# 这些是数据库设置,一般也不用改呢 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot +MONGODB_USERNAME = "" # 如果数据库需要用户名,就在这里填写喵 +MONGODB_PASSWORD = "" # 如果数据库需要密码,就在这里填写呢 +MONGODB_AUTH_SOURCE = "" # 数据库认证源,一般不用改哦 + +# 插件设置喵 +PLUGINS=["src2.plugins.chat"] # 这里是机器人的插件列表呢 +``` + +### 第二个文件:机器人配置 (bot_config.toml) + +这个文件就像是教机器人"如何说话"的魔法书呢! 
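(编者补充的小示意喵:机器人读取这本"魔法书"用的就是标准的 TOML 格式。下面这个小脚本不是仓库里的代码,只是假设用仓库依赖里已有的 tomli 库做个自查,改完配置后可以确认文件能被正确读出来呢~)

```python
# 示意脚本(非仓库代码):检查 bot_config.toml 能否被正确解析
import tomli  # 仓库依赖中已有的 TOML 解析库

with open("config/bot_config.toml", "rb") as f:  # tomli 要求以二进制模式打开
    cfg = tomli.load(f)

print("机器人QQ号:", cfg["bot"]["qq"])
print("机器人昵称:", cfg["bot"]["nickname"])
```

能顺利打印出QQ号和昵称,就说明格式没写坏,可以继续往下看完整的配置文件喵~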
+ +```toml +[bot] +qq = "把这里改成你的机器人QQ号喵" # 填写你的机器人QQ号 +nickname = "麦麦" # 机器人的名字,你可以改成你喜欢的任何名字哦 + +[personality] +# 这里可以设置机器人的性格呢,让它更有趣一些喵 +prompt_personality = [ + "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # 贴吧风格的性格 + "是一个女大学生,你有黑色头发,你会刷小红书" # 小红书风格的性格 +] +prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + +[message] +min_text_length = 2 # 机器人每次至少要说几个字呢 +max_context_size = 15 # 机器人能记住多少条消息喵 +emoji_chance = 0.2 # 机器人使用表情的概率哦(0.2就是20%的机会呢) +ban_words = ["脏话", "不文明用语"] # 在这里填写不让机器人说的词 + +[emoji] +auto_save = true # 是否自动保存看到的表情包呢 +enable_check = false # 是否要检查表情包是不是合适的喵 +check_prompt = "符合公序良俗" # 检查表情包的标准呢 + +[groups] +talk_allowed = [123456, 789012] # 比如:让机器人在群123456和789012里说话 +talk_frequency_down = [345678] # 比如:在群345678里少说点话 +ban_user_id = [111222] # 比如:不回复QQ号为111222的人的消息 + +[others] +enable_advance_output = true # 是否要显示更多的运行信息呢 +enable_kuuki_read = true # 让机器人能够"察言观色"喵 + +# 模型配置部分的详细说明喵~ + + +#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成在.env.prod自己指定的密钥和域名,使用自定义模型则选择定位相似的模型自己填写 + +[model.llm_reasoning] #推理模型R1,用来理解和思考的喵 +name = "Pro/deepseek-ai/DeepSeek-R1" # 模型名字 +# name = "Qwen/QwQ-32B" # 如果想用千问模型,可以把上面那行注释掉,用这个呢 +base_url = "SILICONFLOW_BASE_URL" # 使用在.env.prod里设置的服务地址 +key = "SILICONFLOW_KEY" # 使用在.env.prod里设置的密钥 + +[model.llm_reasoning_minor] #R1蒸馏模型,是个轻量版的推理模型喵 +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal] #V3模型,用来日常聊天的喵 +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal_minor] #V2.5模型,是V3的前代版本呢 +name = "deepseek-ai/DeepSeek-V2.5" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.vlm] #图像识别模型,让机器人能看懂图片喵 +name = "deepseek-ai/deepseek-vl2" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.embedding] #嵌入模型,帮助机器人理解文本的相似度呢 +name = "BAAI/bge-m3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +# 如果选择了llm方式提取主题,就用这个模型配置喵 +[topic.llm_topic] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +``` + +## 💡 模型配置说明喵 + +1. **关于模型服务**: + - 如果你用硅基流动的服务,这些配置都不用改呢 + - 如果用DeepSeek官方API,要把base_url和key改成你在.env.prod里设置的值喵 + - 如果要用自定义模型,选择一个相似功能的模型配置来改呢 + +2. **主要模型功能**: + - `llm_reasoning`: 负责思考和推理的大脑喵 + - `llm_normal`: 负责日常聊天的嘴巴呢 + - `vlm`: 负责看图片的眼睛哦 + - `embedding`: 负责理解文字含义的理解力喵 + - `topic`: 负责理解对话主题的能力呢 + +## 🌟 小提示 +- 如果你刚开始使用,建议保持默认配置呢 +- 不同的模型有不同的特长,可以根据需要调整它们的使用比例哦 + +## 🌟 小贴士喵 +- 记得要好好保管密钥(key)哦,不要告诉别人呢 +- 配置文件要小心修改,改错了机器人可能就不能和你玩了喵 +- 如果想让机器人更聪明,可以调整 personality 里的设置呢 +- 不想让机器人说某些话,就把那些词放在 ban_words 里面喵 +- QQ群号和QQ号都要用数字填写,不要加引号哦(除了机器人自己的QQ号) + +## ⚠️ 注意事项 +- 这个机器人还在测试中呢,可能会有一些小问题喵 +- 如果不知道怎么改某个设置,就保持原样不要动它哦~ +- 记得要先有AI服务的密钥,不然机器人就不能和你说话了呢 +- 修改完配置后要重启机器人才能生效喵~ \ No newline at end of file diff --git a/docs/installation_standard.md b/docs/installation_standard.md new file mode 100644 index 000000000..6e4920220 --- /dev/null +++ b/docs/installation_standard.md @@ -0,0 +1,154 @@ +# 🔧 配置指南 + +## 简介 + +本项目需要配置两个主要文件: +1. `.env.prod` - 配置API服务和系统环境 +2. 
`bot_config.toml` - 配置机器人行为和模型 + +## API配置说明 + +`.env.prod`和`bot_config.toml`中的API配置关系如下: + +### 在.env.prod中定义API凭证: +```ini +# API凭证配置 +SILICONFLOW_KEY=your_key # 硅基流动API密钥 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ # 硅基流动API地址 + +DEEP_SEEK_KEY=your_key # DeepSeek API密钥 +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 # DeepSeek API地址 + +CHAT_ANY_WHERE_KEY=your_key # ChatAnyWhere API密钥 +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 # ChatAnyWhere API地址 +``` + +### 在bot_config.toml中引用API凭证: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" # 引用.env.prod中定义的地址 +key = "SILICONFLOW_KEY" # 引用.env.prod中定义的密钥 +``` + +如需切换到其他API服务,只需修改引用: +```toml +[model.llm_reasoning] +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "DEEP_SEEK_BASE_URL" # 切换为DeepSeek服务 +key = "DEEP_SEEK_KEY" # 使用DeepSeek密钥 +``` + +## 配置文件详解 + +### 环境配置文件 (.env.prod) +```ini +# API配置 +SILICONFLOW_KEY=your_key +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ +DEEP_SEEK_KEY=your_key +DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +CHAT_ANY_WHERE_KEY=your_key +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 + +# 服务配置 +HOST=127.0.0.1 +PORT=8080 + +# 数据库配置 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot +MONGODB_USERNAME = "" # 数据库用户名 +MONGODB_PASSWORD = "" # 数据库密码 +MONGODB_AUTH_SOURCE = "" # 认证数据库 + +# 插件配置 +PLUGINS=["src2.plugins.chat"] +``` + +### 机器人配置文件 (bot_config.toml) +```toml +[bot] +qq = "机器人QQ号" # 必填 +nickname = "麦麦" # 机器人昵称 + +[personality] +prompt_personality = [ + "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", + "是一个女大学生,你有黑色头发,你会刷小红书" +] +prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书" + +[message] +min_text_length = 2 # 最小回复长度 +max_context_size = 15 # 上下文记忆条数 +emoji_chance = 0.2 # 表情使用概率 +ban_words = [] # 禁用词列表 + +[emoji] +auto_save = true # 自动保存表情 +enable_check = false # 启用表情审核 +check_prompt = "符合公序良俗" + +[groups] +talk_allowed = [] # 允许对话的群号 +talk_frequency_down = [] # 降低回复频率的群号 +ban_user_id = [] # 禁止回复的用户QQ号 + +[others] +enable_advance_output = true # 启用详细日志 +enable_kuuki_read = true # 启用场景理解 + +# 模型配置 +[model.llm_reasoning] # 推理模型 +name = "Pro/deepseek-ai/DeepSeek-R1" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_reasoning_minor] # 轻量推理模型 +name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal] # 对话模型 +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_normal_minor] # 备用对话模型 +name = "deepseek-ai/DeepSeek-V2.5" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.vlm] # 图像识别模型 +name = "deepseek-ai/deepseek-vl2" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.embedding] # 文本向量模型 +name = "BAAI/bge-m3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + + +[topic.llm_topic] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +``` + +## 注意事项 + +1. API密钥安全: + - 妥善保管API密钥 + - 不要将含有密钥的配置文件上传至公开仓库 + +2. 配置修改: + - 修改配置后需重启服务 + - 使用默认服务(硅基流动)时无需修改模型配置 + - QQ号和群号使用数字格式(机器人QQ号除外) + +3. 其他说明: + - 项目处于测试阶段,可能存在未知问题 + - 建议初次使用保持默认配置 \ No newline at end of file diff --git a/docs/manual_deploy.md b/docs/manual_deploy.md new file mode 100644 index 000000000..6d53beb4e --- /dev/null +++ b/docs/manual_deploy.md @@ -0,0 +1,100 @@ +# 📦 如何手动部署MaiMbot麦麦? + +## 你需要什么? 
- 一台电脑,能够上网的那种
+
+- 一个QQ小号(QQ框架的使用可能导致qq被风控,严重(小概率)可能会导致账号封禁,强烈不推荐使用大号)
+
+- 可用的大模型API
+
+- 一个AI助手,网上随便搜一家打开来用都行,可以帮你解决一些不懂的问题
+
+## 你需要知道什么?
+
+- 如何正确向AI助手提问,来学习新知识
+
+- Python是什么
+
+- Python的虚拟环境是什么?如何创建虚拟环境
+
+- 命令行是什么
+
+- 数据库是什么?如何安装并启动MongoDB
+
+- 如何运行一个QQ机器人,以及NapCat框架是什么
+
+## 如果准备好了,就可以开始部署了
+
+### 1️⃣ **首先,我们需要安装正确版本的Python**
+
+在创建虚拟环境之前,请确保你的电脑上安装了Python 3.9及以上版本。如果没有,可以按以下步骤安装:
+
+1. 访问Python官网下载页面:https://www.python.org/downloads/release/python-3913/
+2. 下载Windows安装程序 (64-bit): `python-3.9.13-amd64.exe`
+3. 运行安装程序,并确保勾选"Add Python 3.9 to PATH"选项
+4. 点击"Install Now"开始安装
+
+或者使用PowerShell自动下载安装(需要管理员权限):
+```powershell
+# 下载并安装Python 3.9.13
+$pythonUrl = "https://www.python.org/ftp/python/3.9.13/python-3.9.13-amd64.exe"
+$pythonInstaller = "$env:TEMP\python-3.9.13-amd64.exe"
+Invoke-WebRequest -Uri $pythonUrl -OutFile $pythonInstaller
+Start-Process -Wait -FilePath $pythonInstaller -ArgumentList "/quiet", "InstallAllUsers=0", "PrependPath=1" -Verb RunAs
+```
+
+### 2️⃣ **创建Python虚拟环境来运行程序**
+
+ 你可以选择使用以下两种方法之一来创建Python环境:
+
+```bash
+# ---方法1:使用venv(Python自带)
+# 在命令行中创建虚拟环境(环境名为maimbot)
+# 这会让你在运行命令的目录下创建一个虚拟环境
+# 请确保你已通过cd命令前往到了对应路径,不然之后你可能找不到你的python环境
+python -m venv maimbot
+
+maimbot\Scripts\activate
+
+# 安装依赖
+pip install -r requirements.txt
+```
+```bash
+# ---方法2:使用conda
+# 创建一个新的conda环境(环境名为maimbot)
+# Python版本为3.9
+conda create -n maimbot python=3.9
+
+# 激活环境
+conda activate maimbot
+
+# 安装依赖
+pip install -r requirements.txt
+```
+
+### 3️⃣ **然后你需要启动MongoDB数据库,来存储信息**
+- 安装并启动MongoDB服务
+- 默认连接本地27017端口
+
+### 4️⃣ **配置NapCat,让麦麦bot与qq取得联系**
+- 安装并登录NapCat(用你的qq小号)
+- 添加反向WS:`ws://localhost:8080/onebot/v11/ws`
+
+### 5️⃣ **配置文件设置,让麦麦Bot正常工作**
+- 修改环境配置文件:`.env.prod`
+- 修改机器人配置文件:`bot_config.toml`
+
+### 6️⃣ **启动麦麦机器人**
+- 打开命令行,cd到对应路径
+```bash
+nb run
+```
+- 或者cd到对应路径后
+```bash
+python bot.py
+```
+
+### 7️⃣ **其他组件(可选)**
+- `run_thingking.bat`: 启动可视化推理界面(未完善)
+- 直接运行 knowledge.py生成知识库
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 000000000..dd215f1c6
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,61 @@
+{
+  "nodes": {
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1741196730,
+        "narHash": "sha256-0Sj6ZKjCpQMfWnN0NURqRCQn2ob7YtXTAOTwCuz7fkA=",
+        "owner": "NixOS",
+        "repo": "nixpkgs",
+        "rev": "48913d8f9127ea6530a2a2f1bd4daa1b8685d8a3",
+        "type": "github"
+      },
+      "original": {
+        "owner": "NixOS",
+        "ref": "nixos-24.11",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "flake-utils": "flake-utils",
+        "nixpkgs": "nixpkgs"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 000000000..54737d640
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,61 @@
+{
+  description = "MaiMBot Nix Dev Env";
+  # 本配置仅方便用于开发,但是因为 nb-cli 上游打包中并未包含 nonebot2,因此目前本配置并不能用于运行和调试
+
+  inputs 
= { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { + inherit system; + }; + + pythonEnv = pkgs.python3.withPackages ( + ps: with ps; [ + pymongo + python-dotenv + pydantic + jieba + openai + aiohttp + requests + urllib3 + numpy + pandas + matplotlib + networkx + python-dateutil + APScheduler + loguru + tomli + customtkinter + colorama + pypinyin + pillow + setuptools + ] + ); + in + { + devShell = pkgs.mkShell { + buildInputs = [ + pythonEnv + pkgs.nb-cli + ]; + + shellHook = '' + ''; + }; + } + ); +} diff --git a/llm_statistics.txt b/llm_statistics.txt new file mode 100644 index 000000000..338158ef8 --- /dev/null +++ b/llm_statistics.txt @@ -0,0 +1,74 @@ +LLM请求统计报告 (生成时间: 2025-03-07 20:38:57) +================================================== + +所有时间统计 +====== +总请求数: 858 +总Token数: 285415 +总花费: ¥0.3309 + +按模型统计: +- Pro/Qwen/Qwen2-VL-7B-Instruct: 67次 (花费: ¥0.0272) +- Pro/Qwen/Qwen2.5-7B-Instruct: 646次 (花费: ¥0.0718) +- Pro/deepseek-ai/DeepSeek-V3: 9次 (花费: ¥0.0193) +- Qwen/QwQ-32B: 29次 (花费: ¥0.1246) +- Qwen/Qwen2.5-32B-Instruct: 55次 (花费: ¥0.0771) +- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B: 3次 (花费: ¥0.0067) +- deepseek-ai/DeepSeek-V2.5: 49次 (花费: ¥0.0043) + +按请求类型统计: +- chat: 858次 (花费: ¥0.3309) + +最近7天统计 +====== +总请求数: 858 +总Token数: 285415 +总花费: ¥0.3309 + +按模型统计: +- Pro/Qwen/Qwen2-VL-7B-Instruct: 67次 (花费: ¥0.0272) +- Pro/Qwen/Qwen2.5-7B-Instruct: 646次 (花费: ¥0.0718) +- Pro/deepseek-ai/DeepSeek-V3: 9次 (花费: ¥0.0193) +- Qwen/QwQ-32B: 29次 (花费: ¥0.1246) +- Qwen/Qwen2.5-32B-Instruct: 55次 (花费: ¥0.0771) +- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B: 3次 (花费: ¥0.0067) +- deepseek-ai/DeepSeek-V2.5: 49次 (花费: ¥0.0043) + +按请求类型统计: +- chat: 858次 (花费: ¥0.3309) + +最近24小时统计 +======== +总请求数: 858 +总Token数: 285415 +总花费: ¥0.3309 + +按模型统计: +- Pro/Qwen/Qwen2-VL-7B-Instruct: 67次 (花费: ¥0.0272) +- Pro/Qwen/Qwen2.5-7B-Instruct: 646次 (花费: ¥0.0718) +- Pro/deepseek-ai/DeepSeek-V3: 9次 (花费: ¥0.0193) +- Qwen/QwQ-32B: 29次 (花费: ¥0.1246) +- Qwen/Qwen2.5-32B-Instruct: 55次 (花费: ¥0.0771) +- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B: 3次 (花费: ¥0.0067) +- deepseek-ai/DeepSeek-V2.5: 49次 (花费: ¥0.0043) + +按请求类型统计: +- chat: 858次 (花费: ¥0.3309) + +最近1小时统计 +======= +总请求数: 858 +总Token数: 285415 +总花费: ¥0.3309 + +按模型统计: +- Pro/Qwen/Qwen2-VL-7B-Instruct: 67次 (花费: ¥0.0272) +- Pro/Qwen/Qwen2.5-7B-Instruct: 646次 (花费: ¥0.0718) +- Pro/deepseek-ai/DeepSeek-V3: 9次 (花费: ¥0.0193) +- Qwen/QwQ-32B: 29次 (花费: ¥0.1246) +- Qwen/Qwen2.5-32B-Instruct: 55次 (花费: ¥0.0771) +- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B: 3次 (花费: ¥0.0067) +- deepseek-ai/DeepSeek-V2.5: 49次 (花费: ¥0.0043) + +按请求类型统计: +- chat: 858次 (花费: ¥0.3309) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 49c102dc6..4f969682f 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run_maimai.bat b/run_maimai.bat index ff00cc5c1..3a099fd7f 100644 --- a/run_maimai.bat +++ b/run_maimai.bat @@ -1,5 +1,5 @@ chcp 65001 -call conda activate niuniu +call conda activate maimbot cd . REM 执行nb run命令 diff --git a/run_windows.bat b/run_windows.bat new file mode 100644 index 000000000..920069318 --- /dev/null +++ b/run_windows.bat @@ -0,0 +1,67 @@ +@echo off +setlocal enabledelayedexpansion +chcp 65001 + +REM 修正路径获取逻辑 +cd /d "%~dp0" || ( + echo 错误:切换目录失败 + exit /b 1 +) + +if not exist "venv\" ( + echo 正在初始化虚拟环境... 
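+    REM 下面先确认 PATH 中能找到 python,再从 "python --version" 的输出中
+    REM 解析出版本号,并按 主版本.次版本 校验是否满足 3.9 及以上;
+    REM 注意:在括号代码块内读取 errorlevel 要用延迟扩展的 !errorlevel!,
+    REM 否则取到的是进入代码块之前的旧值。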
+
+    where python >nul 2>&1
+    if !errorlevel! neq 0 (
+        echo 未找到Python解释器
+        exit /b 1
+    )
+
+    for /f "tokens=2" %%a in ('python --version 2^>^&1') do set version=%%a
+    for /f "tokens=1,2 delims=." %%b in ("!version!") do (
+        set major=%%b
+        set minor=%%c
+    )
+
+    if !major! lss 3 (
+        echo 需要Python大于等于3.9,当前版本 !version!
+        exit /b 1
+    )
+
+    if !major! equ 3 if !minor! lss 9 (
+        echo 需要Python大于等于3.9,当前版本 !version!
+        exit /b 1
+    )
+
+    echo 正在安装virtualenv...
+    python -m pip install virtualenv || (
+        echo virtualenv安装失败
+        exit /b 1
+    )
+
+    echo 正在创建虚拟环境...
+    python -m virtualenv venv || (
+        echo 虚拟环境创建失败
+        exit /b 1
+    )
+
+    call venv\Scripts\activate.bat
+
+    echo 正在安装依赖...
+    pip install -r requirements.txt
+) else (
+    call venv\Scripts\activate.bat
+)
+
+echo 当前代理设置:
+echo HTTP_PROXY=%HTTP_PROXY%
+echo HTTPS_PROXY=%HTTPS_PROXY%
+
+set HTTP_PROXY=
+set HTTPS_PROXY=
+echo 代理已取消。
+
+set no_proxy=0.0.0.0/32
+
+call nb run
+pause
\ No newline at end of file
diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py
index ab99f6477..f7da8ba96 100644
--- a/src/plugins/chat/__init__.py
+++ b/src/plugins/chat/__init__.py
@@ -14,7 +14,13 @@ from nonebot.rule import to_me
 from .bot import chat_bot
 from .emoji_manager import emoji_manager
 import time
+from ..utils.statistic import LLMStatistics
+# 创建LLM统计实例
+llm_stats = LLMStatistics("llm_statistics.txt")
+
+# 添加标志变量
+_message_manager_started = False
 
 # 获取驱动器
 driver = get_driver()
@@ -55,6 +61,10 @@ scheduler = require("nonebot_plugin_apscheduler").scheduler
 
 @driver.on_startup
 async def start_background_tasks():
     """启动后台任务"""
+    # 启动LLM统计
+    llm_stats.start()
+    print("\033[1;32m[初始化]\033[0m LLM统计功能已启动")
+
     # 只启动表情包管理任务
     asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
     await bot_schedule.initialize()
@@ -70,18 +80,20 @@ async def init_relationships():
 
 @driver.on_bot_connect
 async def _(bot: Bot):
     """Bot连接成功时的处理"""
+    global _message_manager_started
     print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
     await willing_manager.ensure_started()
-
     message_sender.set_bot(bot)
     print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
-    asyncio.create_task(message_manager.start_processor())
-    print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
+
+    if not _message_manager_started:
+        asyncio.create_task(message_manager.start_processor())
+        _message_manager_started = True
+        print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
 
     asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
     print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
-    # 启动消息发送控制任务
 
 @group_msg.handle()
 async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
@@ -90,7 +102,7 @@
 # 添加build_memory定时任务
 @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
 async def build_memory_task():
-    """每30秒执行一次记忆构建"""
+    """每build_memory_interval秒执行一次记忆构建"""
     print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
     start_time = time.time()
     await hippocampus.operation_build_memory(chat_size=20)
diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py
index e3525b3bb..4306c0f9d 100644
--- a/src/plugins/chat/bot.py
+++ b/src/plugins/chat/bot.py
@@ -15,7 +15,7 @@
 from .message import Message_Thinking # 导入 Message_Thinking 类
 from .relationship_manager import 
relationship_manager from .willing_manager import willing_manager # 导入意愿管理器 from .utils import is_mentioned_bot_in_txt, calculate_typing_time -from ..memory_system.memory import memory_graph +from ..memory_system.memory import memory_graph,hippocampus from loguru import logger class ChatBot: @@ -58,6 +58,7 @@ class ChatBot: plain_text=event.get_plaintext(), reply_message=event.reply, ) + await message.initialize() # 过滤词 for word in global_config.ban_words: @@ -70,24 +71,12 @@ class ChatBot: - topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) - - - # topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text) - # topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text) - # topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text) - logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") - - all_num = 0 - interested_num = 0 - if topic: - for current_topic in topic: - all_num += 1 - first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) - if first_layer_items: - interested_num += 1 - print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象") - interested_rate = interested_num / all_num if all_num > 0 else 0 + # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) + topic = '' + interested_rate = 0 + interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text)/100 + print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n") + # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") await self.storage.store_message(message, topic[0] if topic else None) @@ -119,14 +108,9 @@ class ChatBot: willing_manager.change_reply_willing_sent(thinking_message.group_id) - response, emotion = await self.gpt.generate_response(message) - - # if response is None: - # thinking_message.interupt=True + response,raw_content = await self.gpt.generate_response(message) if response: - # print(f"\033[1;32m[思考结束]\033[0m 思考结束,已得到回复,开始回复") - # 找到并删除对应的thinking消息 container = message_manager.get_container(event.group_id) thinking_message = None # 找到message,删除 @@ -134,7 +118,7 @@ class ChatBot: if isinstance(msg, Message_Thinking) and msg.message_id == think_id: thinking_message = msg container.messages.remove(msg) - print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除") + # print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除") break #记录开始思考的时间,避免从思考到回复的时间太久 @@ -144,6 +128,7 @@ class ChatBot: accu_typing_time = 0 # print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器") + mark_head = False for msg in response: # print(f"\033[1;32m[回复内容]\033[0m {msg}") #通过时间改变时间戳 @@ -164,15 +149,20 @@ class ChatBot: thinking_start_time=thinking_start_time, #记录了思考开始的时间 reply_message_id=message.message_id ) + await bot_message.initialize() + if not mark_head: + bot_message.is_head = True + mark_head = True message_set.add_message(bot_message) #message_set 可以直接加入 message_manager - print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") + # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") message_manager.add_message(message_set) bot_response_time = tinking_time_point + if random() < global_config.emoji_chance: - emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) + emoji_path,discription = await emoji_manager.get_emoji_for_text(response) if emoji_path: emoji_cq = CQCode.create_emoji_cq(emoji_path) @@ -188,6 +178,7 @@ class ChatBot: raw_message=emoji_cq, plain_text=emoji_cq, 
processed_plain_text=emoji_cq, + detailed_plain_text=discription, user_nickname=global_config.BOT_NICKNAME, group_name=message.group_name, time=bot_response_time, @@ -196,9 +187,16 @@ class ChatBot: thinking_start_time=thinking_start_time, # reply_message_id=message.message_id ) + await bot_message.initialize() message_manager.add_message(bot_message) + emotion = await self.gpt._get_emotion_tags(raw_content) + print(f"为 '{response}' 获取到的情感标签为:{emotion}") + valuedict={ + 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25 + } + await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) - willing_manager.change_reply_willing_after_sent(event.group_id) + # willing_manager.change_reply_willing_after_sent(event.group_id) # 创建全局ChatBot实例 chat_bot = ChatBot() \ No newline at end of file diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index be599f48a..5c3c0b27a 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -30,23 +30,26 @@ class BotConfig: forget_memory_interval: int = 300 # 记忆遗忘间隔(秒) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) + EMOJI_SAVE: bool = True # 偷表情包 + EMOJI_CHECK: bool = False #是否开启过滤 + EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 ban_words = set() + + max_response_length: int = 1024 # 最大回复长度 # 模型配置 llm_reasoning: Dict[str, str] = field(default_factory=lambda: {}) llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {}) llm_normal: Dict[str, str] = field(default_factory=lambda: {}) llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {}) + llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {}) + llm_summary_by_topic: Dict[str, str] = field(default_factory=lambda: {}) + llm_emotion_judge: Dict[str, str] = field(default_factory=lambda: {}) embedding: Dict[str, str] = field(default_factory=lambda: {}) vlm: Dict[str, str] = field(default_factory=lambda: {}) + moderation: Dict[str, str] = field(default_factory=lambda: {}) - # 主题提取配置 - topic_extract: str = 'snownlp' # 只支持jieba,snownlp,llm - llm_topic_extract: Dict[str, str] = field(default_factory=lambda: {}) - - API_USING: str = "siliconflow" # 使用的API - API_PAID: bool = False # 是否使用付费API MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率 MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率 MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率 @@ -93,6 +96,9 @@ class BotConfig: emoji_config = toml_dict["emoji"] config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) + config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT) + config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE) + config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK) if "cq_code" in toml_dict: cq_code_config = toml_dict["cq_code"] @@ -110,8 +116,7 @@ class BotConfig: config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY) config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY) config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY) - config.API_USING = response_config.get("api_using", config.API_USING) - config.API_PAID = response_config.get("api_paid", config.API_PAID) + 
config.max_response_length = response_config.get("max_response_length", config.max_response_length) # 加载模型配置 if "model" in toml_dict: @@ -125,10 +130,18 @@ class BotConfig: if "llm_normal" in model_config: config.llm_normal = model_config["llm_normal"] - config.llm_topic_extract = config.llm_normal if "llm_normal_minor" in model_config: config.llm_normal_minor = model_config["llm_normal_minor"] + + if "llm_topic_judge" in model_config: + config.llm_topic_judge = model_config["llm_topic_judge"] + + if "llm_summary_by_topic" in model_config: + config.llm_summary_by_topic = model_config["llm_summary_by_topic"] + + if "llm_emotion_judge" in model_config: + config.llm_emotion_judge = model_config["llm_emotion_judge"] if "vlm" in model_config: config.vlm = model_config["vlm"] @@ -136,14 +149,8 @@ class BotConfig: if "embedding" in model_config: config.embedding = model_config["embedding"] - if 'topic' in toml_dict: - topic_config=toml_dict['topic'] - if 'topic_extract' in topic_config: - config.topic_extract=topic_config.get('topic_extract',config.topic_extract) - logger.info(f"载入自定义主题提取为{config.topic_extract}") - if config.topic_extract=='llm' and 'llm_topic' in topic_config: - config.llm_topic_extract=topic_config['llm_topic'] - logger.info(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}") + if "moderation" in model_config: + config.moderation = model_config["moderation"] # 消息配置 if "message" in toml_dict: @@ -178,13 +185,13 @@ class BotConfig: bot_config_floder_path = BotConfig.get_config_dir() print(f"正在品鉴配置文件目录: {bot_config_floder_path}") -bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml") -if not os.path.exists(bot_config_path): +bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") +if os.path.exists(bot_config_path): # 如果开发环境配置文件不存在,则使用默认配置文件 - bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") + print(f"异常的新鲜,异常的美味: {bot_config_path}") logger.info("使用bot配置文件") else: - logger.info("已找到开发bot配置文件") + logger.info("没有找到美味") global_config = BotConfig.load_config(config_path=bot_config_path) diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index ab14148b5..4427ecb4a 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -10,12 +10,12 @@ from nonebot.adapters.onebot.v11 import Bot from .config import global_config import time import asyncio -from .utils_image import storage_image,storage_emoji +from .utils_image import storage_image, storage_emoji from .utils_user import get_user_nickname from ..models.utils_model import LLM_request from .mapper import emojimapper -#解析各种CQ码 -#包含CQ码类 +# 解析各种CQ码 +# 包含CQ码类 import urllib3 from urllib3.util import create_urllib3_context from nonebot import get_driver @@ -28,6 +28,7 @@ ctx = create_urllib3_context() ctx.load_default_certs() ctx.set_ciphers("AES128-GCM-SHA256") + class TencentSSLAdapter(requests.adapters.HTTPAdapter): def __init__(self, ssl_context=None, **kwargs): self.ssl_context = ssl_context @@ -38,6 +39,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter): num_pools=connections, maxsize=maxsize, block=block, ssl_context=self.ssl_context) + @dataclass class CQCode: """ @@ -65,15 +67,15 @@ class CQCode: """初始化LLM实例""" self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) - def translate(self): + async def translate(self): """根据CQ码类型进行相应的翻译处理""" if self.type == 'text': self.translated_plain_text = self.params.get('text', '') elif self.type == 'image': if self.params.get('sub_type') == '0': - 
self.translated_plain_text = self.translate_image() + self.translated_plain_text = await self.translate_image() else: - self.translated_plain_text = self.translate_emoji() + self.translated_plain_text = await self.translate_emoji() elif self.type == 'at': user_nickname = get_user_nickname(self.params.get('qq', '')) if user_nickname: @@ -81,13 +83,13 @@ class CQCode: else: self.translated_plain_text = f"@某人" elif self.type == 'reply': - self.translated_plain_text = self.translate_reply() + self.translated_plain_text = await self.translate_reply() elif self.type == 'face': face_id = self.params.get('id', '') # self.translated_plain_text = f"[表情{face_id}]" self.translated_plain_text = f"[{emojimapper.get(int(face_id), "表情")}]" elif self.type == 'forward': - self.translated_plain_text = self.translate_forward() + self.translated_plain_text = await self.translate_forward() else: self.translated_plain_text = f"[{self.type}]" @@ -134,7 +136,7 @@ class CQCode: # 腾讯服务器特殊状态码处理 if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: return None - + if response.status_code != 200: raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") @@ -158,8 +160,8 @@ class CQCode: return None return None - - def translate_emoji(self) -> str: + + async def translate_emoji(self) -> str: """处理表情包类型的CQ码""" if 'url' not in self.params: return '[表情包]' @@ -168,50 +170,51 @@ class CQCode: # 将 base64 字符串转换为字节类型 image_bytes = base64.b64decode(base64_str) storage_emoji(image_bytes) - return self.get_emoji_description(base64_str) + return await self.get_emoji_description(base64_str) else: return '[表情包]' - - - def translate_image(self) -> str: + + async def translate_image(self) -> str: """处理图片类型的CQ码,区分普通图片和表情包""" - #没有url,直接返回默认文本 + # 没有url,直接返回默认文本 if 'url' not in self.params: return '[图片]' base64_str = self.get_img() if base64_str: image_bytes = base64.b64decode(base64_str) storage_image(image_bytes) - return self.get_image_description(base64_str) + return await self.get_image_description(base64_str) else: return '[图片]' - def get_emoji_description(self, image_base64: str) -> str: + async def get_emoji_description(self, image_base64: str) -> str: """调用AI接口获取表情包描述""" try: prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[表情包:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[表情包]" - def get_image_description(self, image_base64: str) -> str: + async def get_image_description(self, image_base64: str) -> str: """调用AI接口获取普通图片描述""" try: prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[图片:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[图片]" - - def translate_forward(self) -> str: + + async def translate_forward(self) -> str: """处理转发消息""" try: if 'content' not in self.params: return '[转发消息]' - + # 解析content内容(需要先反转义) content = self.unescape(self.params['content']) # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") @@ -222,17 +225,17 @@ class CQCode: except ValueError as e: 
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") return '[转发消息]' - + # 处理每条消息 formatted_messages = [] for msg in messages: sender = msg.get('sender', {}) nickname = sender.get('card') or sender.get('nickname', '未知用户') - + # 获取消息内容并使用Message类处理 raw_message = msg.get('raw_message', '') message_array = msg.get('message', []) - + if message_array and isinstance(message_array, list): # 检查是否包含嵌套的转发消息 for message_part in message_array: @@ -250,6 +253,7 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' @@ -264,23 +268,24 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' - + formatted_msg = f"{nickname}: {content}" formatted_messages.append(formatted_msg) - + # 合并所有消息 combined_messages = '\n'.join(formatted_messages) print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") return f"[转发消息:\n{combined_messages}]" - + except Exception as e: print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") return '[转发消息]' - def translate_reply(self) -> str: + async def translate_reply(self) -> str: """处理回复类型的CQ码""" # 创建Message对象 @@ -288,7 +293,7 @@ class CQCode: if self.reply_message == None: # print(f"\033[1;31m[错误]\033[0m 回复消息为空") return '[回复某人消息]' - + if self.reply_message.sender.user_id: message_obj = Message( user_id=self.reply_message.sender.user_id, @@ -296,6 +301,7 @@ class CQCode: raw_message=str(self.reply_message.message), group_id=self.group_id ) + await message_obj.initialize() if message_obj.user_id == global_config.BOT_QQ: return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" else: @@ -309,9 +315,9 @@ class CQCode: def unescape(text: str) -> str: """反转义CQ码中的特殊字符""" return text.replace(',', ',') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace('&', '&') + .replace('[', '[') \ + .replace(']', ']') \ + .replace('&', '&') @staticmethod def create_emoji_cq(file_path: str) -> str: @@ -326,15 +332,16 @@ class CQCode: abs_path = os.path.abspath(file_path) # 转义特殊字符 escaped_path = abs_path.replace('&', '&') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace(',', ',') + .replace('[', '[') \ + .replace(']', ']') \ + .replace(',', ',') # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" - + + class CQCode_tool: @staticmethod - def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: + async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: """ 将CQ码字典转换为CQCode对象 @@ -353,7 +360,7 @@ class CQCode_tool: params['text'] = cq_code.get('data', {}).get('text', '') else: params = cq_code.get('data', {}) - + instance = CQCode( type=cq_type, params=params, @@ -361,11 +368,11 @@ class CQCode_tool: user_id=0, reply_message=reply ) - + # 进行翻译处理 - instance.translate() + await instance.translate() return instance - + @staticmethod def create_reply_cq(message_id: int) -> str: """ @@ -376,6 +383,6 @@ class CQCode_tool: 回复CQ码字符串 """ return f"[CQ:reply,id={message_id}]" - - + + cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 2311b2459..432d11753 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -14,10 +14,14 @@ import asyncio import time from PIL import Image import io +from loguru import logger +import traceback from nonebot import 
get_driver from ..chat.config import global_config from ..models.utils_model import LLM_request +from ..chat.utils_image import image_path_to_base64 +from ..chat.utils import get_embedding driver = get_driver() config = driver.config @@ -27,16 +31,6 @@ class EmojiManager: _instance = None EMOJI_DIR = "data/emoji" # 表情包存储目录 - EMOTION_KEYWORDS = { - 'happy': ['开心', '快乐', '高兴', '欢喜', '笑', '喜悦', '兴奋', '愉快', '乐', '好'], - 'angry': ['生气', '愤怒', '恼火', '不爽', '火大', '怒', '气愤', '恼怒', '发火', '不满'], - 'sad': ['伤心', '难过', '悲伤', '痛苦', '哭', '忧伤', '悲痛', '哀伤', '委屈', '失落'], - 'surprised': ['惊讶', '震惊', '吃惊', '意外', '惊', '诧异', '惊奇', '惊喜', '不敢相信', '目瞪口呆'], - 'disgusted': ['恶心', '讨厌', '厌恶', '反感', '嫌弃', '恶', '嫌恶', '憎恶', '不喜欢', '烦'], - 'fearful': ['害怕', '恐惧', '惊恐', '担心', '怕', '惊吓', '惊慌', '畏惧', '胆怯', '惧'], - 'neutral': ['普通', '一般', '还行', '正常', '平静', '平淡', '一般般', '凑合', '还好', '就这样'] - } - def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) @@ -47,7 +41,8 @@ class EmojiManager: def __init__(self): self.db = Database.get_instance() self._scan_task = None - self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50) + self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) + self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,temperature=0.8) #更高的温度,更少的token(后续可以根据情绪来调整温度) def _ensure_emoji_dir(self): """确保表情存储目录存在""" @@ -64,7 +59,7 @@ class EmojiManager: # 启动时执行一次完整性检查 self.check_emoji_file_integrity() except Exception as e: - print(f"\033[1;31m[错误]\033[0m 初始化表情管理器失败: {str(e)}") + logger.error(f"初始化表情管理器失败: {str(e)}") def _ensure_db(self): """确保数据库已初始化""" @@ -74,9 +69,20 @@ class EmojiManager: raise RuntimeError("EmojiManager not initialized") def _ensure_emoji_collection(self): - """确保emoji集合存在并创建索引""" + """确保emoji集合存在并创建索引 + + 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 + + 索引的作用是加快数据库查询速度: + - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 + - tags字段的普通索引: 加快按标签搜索表情包的速度 + - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 + + 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 + """ if 'emoji' not in self.db.db.list_collection_names(): self.db.db.create_collection('emoji') + self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('tags', 1)]) self.db.db.emoji.create_index([('filename', 1)], unique=True) @@ -89,228 +95,128 @@ class EmojiManager: {'$inc': {'usage_count': 1}} ) except Exception as e: - print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}") + logger.error(f"记录表情使用失败: {str(e)}") - async def _get_emotion_from_text(self, text: str) -> List[str]: - """从文本中识别情感关键词 - Args: - text: 输入文本 - Returns: - List[str]: 匹配到的情感标签列表 - """ - try: - prompt = f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。' - - content, _ = await self.llm.generate_response(prompt) - emotion = content.strip().lower() - - if emotion in self.EMOTION_KEYWORDS: - print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}") - return [emotion] - - return ['neutral'] - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}") - return ['neutral'] - - async def get_emoji_for_emotion(self, emotion_tag: str) -> Optional[str]: - try: - self._ensure_db() - - # 构建查询条件:标签匹配任一情感 - query = {'tags': {'$in': emotion_tag}} - - # print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") - - try: - # 随机获取一个匹配的表情 - emoji = self.db.db.emoji.aggregate([ - {'$match': query}, - {'$sample': {'size': 1}} - ]).next() - print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") - if emoji and 'path' in emoji: - # 更新使用次数 - 
self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - # 如果没有匹配的表情,从所有表情中随机选择一个 - print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") - try: - emoji = self.db.db.emoji.aggregate([ - {'$sample': {'size': 1}} - ]).next() - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") - return None - - return None - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") - return None - - async def get_emoji_for_text(self, text: str) -> Optional[str]: """根据文本内容获取相关表情包 Args: text: 输入文本 Returns: Optional[str]: 表情包文件路径,如果没有找到则返回None + + + 可不可以通过 配置文件中的指令 来自定义使用表情包的逻辑? + 我觉得可行 + """ try: self._ensure_db() - # 获取情感标签 - emotions = await self._get_emotion_from_text(text) - print("为 ‘"+ str(text) + "’ 获取到的情感标签为:" + str(emotions)) - if not emotions: - return None - - # 构建查询条件:标签匹配任一情感 - query = {'tags': {'$in': emotions}} - print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") - print(f"\033[1;34m[调试]\033[0m 匹配到的情感: {emotions}") + # 获取文本的embedding + text_for_search= await self._get_kimoji_for_text(text) + if not text_for_search: + logger.error("无法获取文本的情绪") + return None + text_embedding = await get_embedding(text_for_search) + if not text_embedding: + logger.error("无法获取文本的embedding") + return None try: - # 随机获取一个匹配的表情 - emoji = self.db.db.emoji.aggregate([ - {'$match': query}, - {'$sample': {'size': 1}} - ]).next() - print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") - if emoji and 'path' in emoji: + # 获取所有表情包 + all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1})) + + if not all_emojis: + logger.warning("数据库中没有任何表情包") + return None + + # 计算余弦相似度并排序 + def cosine_similarity(v1, v2): + if not v1 or not v2: + return 0 + dot_product = sum(a * b for a, b in zip(v1, v2)) + norm_v1 = sum(a * a for a in v1) ** 0.5 + norm_v2 = sum(b * b for b in v2) ** 0.5 + if norm_v1 == 0 or norm_v2 == 0: + return 0 + return dot_product / (norm_v1 * norm_v2) + + # 计算所有表情包与输入文本的相似度 + emoji_similarities = [ + (emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) + for emoji in all_emojis + ] + + # 按相似度降序排序 + emoji_similarities.sort(key=lambda x: x[1], reverse=True) + + # 获取前3个最相似的表情包 + top_3_emojis = emoji_similarities[:3] + + if not top_3_emojis: + logger.warning("未找到匹配的表情包") + return None + + # 从前3个中随机选择一个 + selected_emoji, similarity = random.choice(top_3_emojis) + + if selected_emoji and 'path' in selected_emoji: # 更新使用次数 self.db.db.emoji.update_one( - {'_id': emoji['_id']}, + {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) - return emoji['path'] - except StopIteration: - # 如果没有匹配的表情,从所有表情中随机选择一个 - print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") - try: - emoji = self.db.db.emoji.aggregate([ - {'$sample': {'size': 1}} - ]).next() - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") - return None + logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})") + # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 + return selected_emoji['path'],"[ %s ]" % selected_emoji.get('discription', '无描述') + + except Exception as search_error: + logger.error(f"搜索表情包失败: {str(search_error)}") + return None return None except Exception 
as e: - print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") + logger.error(f"获取表情包失败: {str(e)}") return None - async def _get_emoji_tag(self, image_base64: str) -> str: + async def _get_emoji_discription(self, image_base64: str) -> str: """获取表情包的标签""" try: - prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好' + prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' - content, _ = await self.llm.generate_response_for_image(prompt, image_base64) - tag_result = content.strip().lower() - - valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"] - for tag_match in valid_tags: - if tag_match in tag_result or tag_match == tag_result: - return tag_match - print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过") + content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) + logger.debug(f"输出描述: {content}") + return content except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}") - return "skip" - - print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral") - return "skip" # 默认标签 - - async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]: - """压缩图片并返回base64编码 - Args: - image_path: 图片文件路径 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - Optional[str]: 成功返回base64编码的图片数据,失败返回None - """ - try: - file_size = os.path.getsize(image_path) - if file_size <= target_size: - # 如果文件已经小于目标大小,直接读取并返回base64 - with open(image_path, 'rb') as f: - return base64.b64encode(f.read()).decode('utf-8') - - # 打开图片 - with Image.open(image_path) as img: - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / file_size) ** 0.5) - - # 计算新的尺寸 - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - for frame_idx in range(img.n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) - frames.append(new_frame) - - # 保存到缓冲区 - frames[0].save( - output_buffer, - format='GIF', - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get('duration', 100), - loop=img.info.get('loop', 0) - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): - resized_img.save(output_buffer, format='PNG', optimize=True) - else: - resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - print(f"\033[1;32m[成功]\033[0m 压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})") - - return base64.b64encode(compressed_data).decode('utf-8') - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}") + logger.error(f"获取标签失败: {str(e)}") return None + + async def _check_emoji(self, image_base64: str) -> str: + try: + prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' + content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) + logger.debug(f"输出描述: {content}") + return content + + except Exception as e: + logger.error(f"获取标签失败: {str(e)}") + return None + + async def _get_kimoji_for_text(self, text:str): + try: + prompt = 
f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' + + content, _ = await self.llm_emotion_judge.generate_response_async(prompt) + logger.info(f"输出描述: {content}") + return content + + except Exception as e: + logger.error(f"获取标签失败: {str(e)}") + return None + async def scan_new_emojis(self): """扫描新的表情包""" try: @@ -329,34 +235,43 @@ class EmojiManager: continue # 压缩图片并获取base64编码 - image_base64 = await self._compress_image(image_path) + image_base64 = image_path_to_base64(image_path) if image_base64 is None: os.remove(image_path) continue - # 获取表情包的情感标签 - tag = await self._get_emoji_tag(image_base64) - if not tag == "skip": + # 获取表情包的描述 + discription = await self._get_emoji_discription(image_base64) + if global_config.EMOJI_CHECK: + check = await self._check_emoji(image_base64) + if '是' not in check: + os.remove(image_path) + logger.info(f"描述: {discription}") + logger.info(f"其不满足过滤规则,被剔除 {check}") + continue + logger.info(f"check通过 {check}") + embedding = await get_embedding(discription) + if discription is not None: # 准备数据库记录 emoji_record = { 'filename': filename, 'path': image_path, - 'tags': [tag], + 'embedding':embedding, + 'discription': discription, 'timestamp': int(time.time()) } # 保存到数据库 self.db.db['emoji'].insert_one(emoji_record) - print(f"\033[1;32m[成功]\033[0m 注册新表情包: {filename}") - print(f"标签: {tag}") + logger.success(f"注册新表情包: {filename}") + logger.info(f"描述: {discription}") else: - print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}") + logger.warning(f"跳过表情包: {filename}") except Exception as e: - print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}") - import traceback - print(traceback.format_exc()) - + logger.error(f"扫描表情包失败: {str(e)}") + logger.error(traceback.format_exc()) + async def _periodic_scan(self, interval_MINS: int = 10): """定期扫描新表情包""" while True: @@ -364,6 +279,7 @@ class EmojiManager: await self.scan_new_emojis() await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 + def check_emoji_file_integrity(self): """检查表情包文件完整性 如果文件已被删除,则从数据库中移除对应记录 @@ -378,44 +294,42 @@ class EmojiManager: for emoji in all_emojis: try: if 'path' not in emoji: - print(f"\033[1;33m[提示]\033[0m 发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") + logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") + self.db.db.emoji.delete_one({'_id': emoji['_id']}) + removed_count += 1 + continue + + if 'embedding' not in emoji: + logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") self.db.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue # 检查文件是否存在 if not os.path.exists(emoji['path']): - print(f"\033[1;33m[提示]\033[0m 表情包文件已被删除: {emoji['path']}") + logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) if result.deleted_count > 0: - print(f"\033[1;32m[成功]\033[0m 成功删除数据库记录: {emoji['_id']}") + logger.success(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 else: - print(f"\033[1;31m[错误]\033[0m 删除数据库记录失败: {emoji['_id']}") + logger.error(f"删除数据库记录失败: {emoji['_id']}") except Exception as item_error: - print(f"\033[1;31m[错误]\033[0m 处理表情包记录时出错: {str(item_error)}") + logger.error(f"处理表情包记录时出错: {str(item_error)}") continue # 验证清理结果 remaining_count = self.db.db.emoji.count_documents({}) if removed_count > 0: - print(f"\033[1;32m[成功]\033[0m 已清理 {removed_count} 个失效的表情包记录") - print(f"\033[1;34m[统计]\033[0m 清理前总数: {total_count} | 清理后总数: {remaining_count}") - # print(f"\033[1;34m[统计]\033[0m 应删除数量: {removed_count} 
| 实际删除数量: {total_count - remaining_count}") - # 执行数据库压缩 - try: - self.db.db.command({"compact": "emoji"}) - print(f"\033[1;32m[成功]\033[0m 数据库集合压缩完成") - except Exception as compact_error: - print(f"\033[1;31m[错误]\033[0m 数据库压缩失败: {str(compact_error)}") + logger.success(f"已清理 {removed_count} 个失效的表情包记录") + logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") else: - print(f"\033[1;36m[表情包]\033[0m 已检查 {total_count} 个表情包记录") + logger.info(f"已检查 {total_count} 个表情包记录") except Exception as e: - print(f"\033[1;31m[错误]\033[0m 检查表情包完整性失败: {str(e)}") - import traceback - print(f"\033[1;31m[错误追踪]\033[0m\n{traceback.format_exc()}") + logger.error(f"检查表情包完整性失败: {str(e)}") + logger.error(traceback.format_exc()) async def start_periodic_check(self, interval_MINS: int = 120): while True: diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 04f2e73ad..a19a222c2 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -24,6 +24,7 @@ class ResponseGenerator: self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000) self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000) + self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000) self.db = Database.get_instance() self.current_model_type = 'r1' # 默认使用 R1 @@ -44,19 +45,15 @@ class ResponseGenerator: print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") model_response = await self._generate_response_with_model(message, current_model) + raw_content=model_response if model_response: print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') - model_response, emotion = await self._process_response(model_response) + model_response = await self._process_response(model_response) if model_response: - print(f"为 '{model_response}' 获取到的情感标签为:{emotion}") - valuedict={ - 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25 - } - await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) - return model_response, emotion - return None, [] + return model_response ,raw_content + return None,raw_content async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]: """使用指定的模型生成回复""" @@ -67,10 +64,11 @@ class ResponseGenerator: # 获取关系值 relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0 if relationship_value != 0.0: - print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + # print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + pass # 构建prompt - prompt, prompt_check = prompt_builder._build_prompt( + prompt, prompt_check = await prompt_builder._build_prompt( message_txt=message.processed_plain_text, sender_name=sender_name, relationship_value=relationship_value, @@ -142,7 +140,7 @@ class ResponseGenerator: 内容:{content} 输出: ''' - content, _ = await self.model_v3.generate_response(prompt) + content, _ = await self.model_v25.generate_response(prompt) content=content.strip() if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']: return [content] @@ -158,10 +156,9 @@ class ResponseGenerator: if not content: return None, [] - emotion_tags = await 
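# Editor's note — check_emoji_file_integrity() above now also evicts records that
# predate the new embedding field. A self-contained sketch of the same cleanup
# pass, assuming a pymongo collection: records lacking 'path'/'embedding', or
# whose backing file has vanished, are deleted.
import os

def clean_emoji_records(collection) -> int:
    removed = 0
    for doc in list(collection.find()):
        missing_field = 'path' not in doc or 'embedding' not in doc
        file_gone = not missing_field and not os.path.exists(doc['path'])
        if missing_field or file_gone:
            collection.delete_one({'_id': doc['_id']})
            removed += 1
    return removed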
self._get_emotion_tags(content) processed_response = process_llm_response(content) - return processed_response, emotion_tags + return processed_response class InitiativeMessageGenerate: diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index d6e400e15..a39cf293f 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -27,58 +27,66 @@ class Message: """消息数据类""" message_id: int = None time: float = None - + group_id: int = None - group_name: str = None # 群名称 - + group_name: str = None # 群名称 + user_id: int = None user_nickname: str = None # 用户昵称 - user_cardname: str=None # 用户群昵称 - - raw_message: str = None # 原始消息,包含未解析的cq码 - plain_text: str = None # 纯文本 - + user_cardname: str = None # 用户群昵称 + + raw_message: str = None # 原始消息,包含未解析的cq码 + plain_text: str = None # 纯文本 + + reply_message: Dict = None # 存储 回复的 源消息 + + # 延迟初始化字段 + _initialized: bool = False message_segments: List[Dict] = None # 存储解析后的消息片段 processed_plain_text: str = None # 用于存储处理后的plain_text detailed_plain_text: str = None # 用于存储详细可读文本 - - reply_message: Dict = None # 存储 回复的 源消息 - - is_emoji: bool = False # 是否是表情包 - has_emoji: bool = False # 是否包含表情包 - - translate_cq: bool = True # 是否翻译cq码 - - def __post_init__(self): - if self.time is None: - self.time = int(time.time()) - - if not self.group_name: - self.group_name = get_groupname(self.group_id) - - if not self.user_nickname: - self.user_nickname = get_user_nickname(self.user_id) - - if not self.user_cardname: - self.user_cardname=get_user_cardname(self.user_id) - - if not self.processed_plain_text: - if self.raw_message: - self.message_segments = self.parse_message_segments(str(self.raw_message)) + + # 状态标志 + is_emoji: bool = False + has_emoji: bool = False + translate_cq: bool = True + + async def initialize(self): + """显式异步初始化方法(必须调用)""" + if self._initialized: + return + + # 异步获取补充信息 + self.group_name = self.group_name or get_groupname(self.group_id) + self.user_nickname = self.user_nickname or get_user_nickname(self.user_id) + self.user_cardname = self.user_cardname or get_user_cardname(self.user_id) + + # 消息解析 + if self.raw_message: + if not isinstance(self,Message_Sending): + self.message_segments = await self.parse_message_segments(self.raw_message) self.processed_plain_text = ' '.join( seg.translated_plain_text for seg in self.message_segments ) - #将详细翻译为详细可读文本 + + # 构建详细文本 + if self.time is None: + self.time = int(time.time()) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) - try: - name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" - except: - name = self.user_nickname or f"用户{self.user_id}" - content = self.processed_plain_text - self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" + name = ( + f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" + if self.user_cardname + else f"{self.user_nickname or f'用户{self.user_id}'}" + ) + if isinstance(self,Message_Sending) and self.is_emoji: + self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n" + else: + self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n" + + self._initialized = True - def parse_message_segments(self, message: str) -> List[CQCode]: + async def parse_message_segments(self, message: str) -> List[CQCode]: """ 将消息解析为片段列表,包括纯文本和CQ码 返回的列表中每个元素都是字典,包含: @@ -136,7 +144,7 @@ class Message: #翻译作为字典的CQ码 for _code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) + message_obj = 
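# Editor's note — the message.py hunk replaces __post_init__ with an explicit
# awaitable initialize(), because __post_init__ runs synchronously and cannot
# await helpers such as the async CQ-code parsing. A stripped-down sketch of the
# pattern; fetch_nickname is a hypothetical stand-in for an async lookup.
import asyncio
from dataclasses import dataclass

async def fetch_nickname(user_id: int) -> str:
    await asyncio.sleep(0)              # pretend this hits a DB or API
    return f"user-{user_id}"

@dataclass
class DemoMessage:
    user_id: int
    user_nickname: str = None
    _initialized: bool = False

    async def initialize(self):
        if self._initialized:           # idempotent: safe to call twice
            return
        self.user_nickname = self.user_nickname or await fetch_nickname(self.user_id)
        self._initialized = True

async def _demo():
    msg = DemoMessage(user_id=42)
    await msg.initialize()              # must be awaited before the object is used
    print(msg.user_nickname)

asyncio.run(_demo())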
await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) trans_list.append(message_obj) return trans_list @@ -169,6 +177,8 @@ class Message_Sending(Message): reply_message_id: int = None # 存储 回复的 源消息ID + is_head: bool = False # 是否是头部消息 + def update_thinking_time(self): self.thinking_time = round(time.time(), 2) - self.thinking_start_time return self.thinking_time diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index 970fd3682..3e30b3cbe 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -103,7 +103,7 @@ class MessageContainer: def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: """添加消息到队列""" - print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") + # print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") if isinstance(message, MessageSet): for single_message in message.messages: self.messages.append(single_message) @@ -156,26 +156,21 @@ class MessageManager: #最早的对象,可能是思考消息,也可能是发送消息 message_earliest = container.get_earliest_message() #一个message_thinking or message_sending - #一个月后删了 - if not message_earliest: - print(f"\033[1;34m[BUG,如果出现这个,说明有BUG,3月4日留]\033[0m ") - return - #如果是思考消息 if isinstance(message_earliest, Message_Thinking): #优先等待这条消息 message_earliest.update_thinking_time() thinking_time = message_earliest.thinking_time - print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒") + if thinking_time % 10 == 0: + print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒") else:# 如果不是message_thinking就只能是message_sending print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") #直接发,等什么呢 - if message_earliest.update_thinking_time() < 30: - await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) - else: + if message_earliest.is_head and message_earliest.update_thinking_time() >30: await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id) - - #移除消息 + else: + await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) + #移除消息 if message_earliest.is_emoji: message_earliest.processed_plain_text = "[表情包]" await self.storage.store_message(message_earliest, None) @@ -192,10 +187,11 @@ class MessageManager: try: #发送 - if msg.update_thinking_time() < 30: - await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) - else: + if msg.is_head and msg.update_thinking_time() >30: await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) + else: + await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) + #如果是表情包,则替换为"[表情包]" if msg.is_emoji: diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index ba22a403d..3b7894f56 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -6,9 +6,11 @@ from .utils import get_embedding, combine_messages, get_recent_group_detailed_pl from ...common.database import Database from .config import global_config from .topic_identifier import topic_identifier -from ..memory_system.memory import memory_graph +from ..memory_system.memory import memory_graph,hippocampus from random import choice - +import numpy as np +import jieba +from collections import Counter class PromptBuilder: def __init__(self): @@ -16,7 +18,9 @@ class 
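# Editor's note — message_sender.py now quote-replies only the head message of a
# MessageSet, and only when its thinking time exceeded 30 seconds (so the original
# context may have scrolled past). The branch condition in isolation:
def should_quote(is_head: bool, thinking_time: float, threshold: float = 30.0) -> bool:
    return is_head and thinking_time > threshold

assert should_quote(True, 45.0) is True
assert should_quote(False, 45.0) is False   # non-head messages are sent plainly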
PromptBuilder: self.activate_messages = '' self.db = Database.get_instance() - def _build_prompt(self, + + + async def _build_prompt(self, message_txt: str, sender_name: str = "某人", relationship_value: float = 0.0, @@ -31,60 +35,7 @@ class PromptBuilder: Returns: str: 构建好的prompt - """ - - - memory_prompt = '' - start_time = time.time() # 记录开始时间 - # topic = await topic_identifier.identify_topic_llm(message_txt) - topic = topic_identifier.identify_topic_snownlp(message_txt) - - # print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}") - - all_first_layer_items = [] # 存储所有第一层记忆 - all_second_layer_items = {} # 用字典存储每个topic的第二层记忆 - overlapping_second_layer = set() # 存储重叠的第二层记忆 - - if topic: - # 遍历所有topic - for current_topic in topic: - first_layer_items, second_layer_items = memory_graph.get_related_item(current_topic, depth=2) - # if first_layer_items: - # print(f"\033[1;32m[前额叶]\033[0m 主题 '{current_topic}' 的第一层记忆: {first_layer_items}") - - # 记录第一层数据 - all_first_layer_items.extend(first_layer_items) - - # 记录第二层数据 - all_second_layer_items[current_topic] = second_layer_items - - # 检查是否有重叠的第二层数据 - for other_topic, other_second_layer in all_second_layer_items.items(): - if other_topic != current_topic: - # 找到重叠的记忆 - overlap = set(second_layer_items) & set(other_second_layer) - if overlap: - # print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}") - overlapping_second_layer.update(overlap) - - selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else [] - selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else [] - - # 合并并去重 - all_memories = list(set(selected_first_layer + selected_second_layer)) - if all_memories: - print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}") - random_item = " ".join(all_memories) - memory_prompt = f"看到这些聊天,你想起来{random_item}\n" - else: - memory_prompt = "" # 如果没有记忆,则返回空字符串 - - end_time = time.time() # 记录结束时间 - print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") # 输出耗时 - - - - + """ #先禁用关系 if 0 > 30: relation_prompt = "关系特别特别好,你很喜欢喜欢他" @@ -109,25 +60,52 @@ class PromptBuilder: prompt_info = '' promt_info_prompt = '' - prompt_info = self.get_prompt_info(message_txt,threshold=0.5) + prompt_info = await self.get_prompt_info(message_txt,threshold=0.5) if prompt_info: prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' - # promt_info_prompt = '你有一些[知识],在上面可以参考。' end_time = time.time() print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒") - # print(f"\033[1;34m[调试]\033[0m 获取知识库内容结果: {prompt_info}") - - # print(f"\033[1;34m[调试信息]\033[0m 正在构建聊天上下文") - + # 获取聊天上下文 chat_talking_prompt = '' if group_id: chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" - # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") + + + # 使用新的记忆获取方法 + memory_prompt = '' + start_time = time.time() + + # 调用 hippocampus 的 get_relevant_memories 方法 + relevant_memories = await hippocampus.get_relevant_memories( + text=message_txt, + max_topics=5, + similarity_threshold=0.4, + max_memory_num=5 + ) + + if relevant_memories: + # 格式化记忆内容 + memory_items = [] + for memory in relevant_memories: + 
memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}") + + memory_prompt = f"看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n" + + # 打印调试信息 + print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:") + for memory in relevant_memories: + print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}") + + end_time = time.time() + print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒") + + + #激活prompt构建 activate_prompt = '' activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2}。" @@ -162,29 +140,19 @@ class PromptBuilder: if random.random() < 0.01: prompt_ger += '你喜欢用文言文' - #额外信息要求 - extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' - - + extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' #合并prompt prompt = "" prompt += f"{prompt_info}\n" prompt += f"{prompt_date}\n" prompt += f"{chat_talking_prompt}\n" - - # prompt += f"{memory_prompt}\n" - - # prompt += f"{activate_prompt}\n" prompt += f"{prompt_personality}\n" prompt += f"{prompt_ger}\n" prompt += f"{extra_info}\n" - - '''读空气prompt处理''' - activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" prompt_personality_check = '' extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。" @@ -247,10 +215,10 @@ class PromptBuilder: return prompt_for_initiative - def get_prompt_info(self,message:str,threshold:float): + async def get_prompt_info(self,message:str,threshold:float): related_info = '' print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - embedding = get_embedding(message) + embedding = await get_embedding(message) related_info += self.get_info_from_db(embedding,threshold=threshold) return related_info diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index 812d4e321..6579d15ac 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -4,7 +4,6 @@ from .message import Message import jieba from nonebot import get_driver from .config import global_config -from snownlp import SnowNLP from ..models.utils_model import LLM_request driver = get_driver() @@ -12,10 +11,8 @@ config = driver.config class TopicIdentifier: def __init__(self): - self.llm_client = LLM_request(model=global_config.llm_topic_extract) - self.select=global_config.topic_extract + self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge) - async def identify_topic_llm(self, text: str) -> Optional[List[str]]: """识别消息主题,返回主题列表""" @@ -26,7 +23,7 @@ class TopicIdentifier: 消息内容:{text}""" # 使用 LLM_request 类进行请求 - topic, _ = await self.llm_client.generate_response(prompt) + topic, _ = await self.llm_topic_judge.generate_response(prompt) if not topic: print(f"\033[1;31m[错误]\033[0m LLM API 返回为空") @@ -42,25 +39,4 @@ class TopicIdentifier: print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}") return topic_list if topic_list else None - def identify_topic_snownlp(self, text: str) -> Optional[List[str]]: - """使用 SnowNLP 进行主题识别 - - Args: - text (str): 需要识别主题的文本 - - Returns: - Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None - """ - if not text or len(text.strip()) == 0: - 
return None - - try: - s = SnowNLP(text) - # 提取前3个关键词作为主题 - keywords = s.keywords(5) - return keywords if keywords else None - except Exception as e: - print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}") - return None - topic_identifier = TopicIdentifier() \ No newline at end of file diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index aa16268ef..db5a35384 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -11,6 +11,9 @@ from collections import Counter import math from nonebot import get_driver from ..models.utils_model import LLM_request +import aiohttp +import jieba +from ..utils.typo_generator import ChineseTypoGenerator driver = get_driver() config = driver.config @@ -30,16 +33,18 @@ def combine_messages(messages: List[Message]) -> str: time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time)) name = message.user_nickname or f"用户{message.user_id}" content = message.processed_plain_text or message.plain_text - + result += f"[{time_str}] {name}: {content}\n" - + return result -def db_message_to_str (message_dict: Dict) -> str: + +def db_message_to_str(message_dict: Dict) -> str: print(f"message_dict: {message_dict}") time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) try: - name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", "")) + name = "[(%s)%s]%s" % ( + message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", "")) except: name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" content = message_dict.get("processed_plain_text", "") @@ -56,6 +61,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool: return True return False + def is_mentioned_bot_in_txt(message: str) -> bool: """检查消息是否提到了机器人""" keywords = [global_config.BOT_NICKNAME] @@ -64,10 +70,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool: return True return False -def get_embedding(text): + +async def get_embedding(text): """获取文本的embedding向量""" llm = LLM_request(model=global_config.embedding) - return llm.get_embedding_sync(text) + # return llm.get_embedding_sync(text) + return await llm.get_embedding(text) + def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) @@ -75,52 +84,55 @@ def cosine_similarity(v1, v2): norm2 = np.linalg.norm(v2) return dot_product / (norm1 * norm2) + def calculate_information_content(text): """计算文本的信息量(熵)""" char_count = Counter(text) total_chars = len(text) - + entropy = 0 for count in char_count.values(): probability = count / total_chars entropy -= probability * math.log2(probability) - + return entropy + def get_cloest_chat_from_db(db, length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数""" chat_text = '' closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) - - if closest_record and closest_record.get('memorized', 0) < 4: + + if closest_record and closest_record.get('memorized', 0) < 4: closest_time = closest_record['time'] group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_records = list(db.db.messages.find( {"time": {"$gt": closest_time}, "group_id": group_id} ).sort('time', 1).limit(length)) - + # 更新每条消息的memorized属性 for record in chat_records: # 检查当前记录的memorized值 current_memorized = record.get('memorized', 0) - if current_memorized > 3: + if current_memorized > 3: # print(f"消息已读取3次,跳过") return '' - + # 更新memorized值 db.db.messages.update_one( {"_id": record["_id"]}, {"$set": 
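# Editor's note — calculate_information_content() in utils.py is character-level
# Shannon entropy. Worked example: 'aab' has distribution {a: 2/3, b: 1/3}, so
# H = -(2/3·log2(2/3) + 1/3·log2(1/3)) ≈ 0.918 bits. The same computation:
import math
from collections import Counter

def char_entropy(text: str) -> float:
    counts, total = Counter(text), len(text)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

assert abs(char_entropy('aab') - 0.9183) < 1e-3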
{"memorized": current_memorized + 1}} ) - + chat_text += record["detailed_plain_text"] - + return chat_text - print(f"消息已读取3次,跳过") + # print(f"消息已读取3次,跳过") return '' -def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: + +async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 Args: @@ -132,7 +144,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: list: Message对象列表,按时间正序排列 """ - # 从数据库获取最近消息 + # 从数据库获取最近消息 recent_messages = list(db.db.messages.find( {"group_id": group_id}, # { @@ -147,7 +159,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: if not recent_messages: return [] - + # 转换为 Message对象列表 from .message import Message message_objects = [] @@ -162,16 +174,18 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: processed_plain_text=msg_data.get("processed_text", ""), group_id=group_id ) + await msg.initialize() message_objects.append(msg) except KeyError: print("[WARNING] 数据库中存在无效的消息") continue - + # 按时间正序排列 message_objects.reverse() return message_objects -def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,combine = False): + +def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False): recent_messages = list(db.db.messages.find( {"group_id": group_id}, { @@ -185,16 +199,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb if not recent_messages: return [] - + message_detailed_plain_text = '' message_detailed_plain_text_list = [] - + # 反转消息列表,使最新的消息在最后 recent_messages.reverse() - + if combine: for msg_db_data in recent_messages: - message_detailed_plain_text+=str(msg_db_data["detailed_plain_text"]) + message_detailed_plain_text += str(msg_db_data["detailed_plain_text"]) return message_detailed_plain_text else: for msg_db_data in recent_messages: @@ -202,7 +216,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb return message_detailed_plain_text_list - def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: """将文本分割成句子,但保持书名号中的内容完整 Args: @@ -222,30 +235,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: split_strength = 0.7 else: split_strength = 0.9 - #先移除换行符 + # 先移除换行符 # print(f"split_strength: {split_strength}") - + # print(f"处理前的文本: {text}") - + # 统一将英文逗号转换为中文逗号 text = text.replace(',', ',') text = text.replace('\n', ' ') - + # print(f"处理前的文本: {text}") - + text_no_1 = '' for letter in text: # print(f"当前字符: {letter}") - if letter in ['!','!','?','?']: + if letter in ['!', '!', '?', '?']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < split_strength: letter = '' - if letter in ['。','…']: + if letter in ['。', '…']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < 1 - split_strength: letter = '' text_no_1 += letter - + # 对每个逗号单独判断是否分割 sentences = [text_no_1] new_sentences = [] @@ -274,84 +287,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: sentences_done = [] for sentence in sentences: sentence = sentence.rstrip(',,') - if random.random() < split_strength*0.5: + if random.random() < split_strength * 0.5: sentence = sentence.replace(',', '').replace(',', '') elif random.random() < split_strength: sentence = sentence.replace(',', ' ').replace(',', ' ') sentences_done.append(sentence) - + print(f"处理后的句子: {sentences_done}") return sentences_done -# 常见的错别字映射 -TYPO_DICT = { - '的': '地得', - '了': '咯啦勒', - 
'吗': '嘛麻', - '吧': '八把罢', - '是': '事', - '在': '再在', - '和': '合', - '有': '又', - '我': '沃窝喔', - '你': '泥尼拟', - '他': '它她塔祂', - '们': '门', - '啊': '阿哇', - '呢': '呐捏', - '都': '豆读毒', - '很': '狠', - '会': '回汇', - '去': '趣取曲', - '做': '作坐', - '想': '相像', - '说': '说税睡', - '看': '砍堪刊', - '来': '来莱赖', - '好': '号毫豪', - '给': '给既继', - '过': '锅果裹', - '能': '嫩', - '为': '位未', - '什': '甚深伸', - '么': '末麽嘛', - '话': '话花划', - '知': '织直值', - '道': '到', - '听': '听停挺', - '见': '见件建', - '觉': '觉脚搅', - '得': '得德锝', - '着': '着找招', - '像': '向象想', - '等': '等灯登', - '谢': '谢写卸', - '对': '对队', - '里': '里理鲤', - '啦': '啦拉喇', - '吃': '吃持迟', - '哦': '哦喔噢', - '呀': '呀压', - '要': '药', - '太': '太抬台', - '快': '块', - '点': '店', - '以': '以已', - '因': '因应', - '啥': '啥沙傻', - '行': '行型形', - '哈': '哈蛤铪', - '嘿': '嘿黑嗨', - '嗯': '嗯恩摁', - '哎': '哎爱埃', - '呜': '呜屋污', - '喂': '喂位未', - '嘛': '嘛麻马', - '嗨': '嗨害亥', - '哇': '哇娃蛙', - '咦': '咦意易', - '嘻': '嘻西希' -} + def random_remove_punctuation(text: str) -> str: """随机处理标点符号,模拟人类打字习惯 @@ -364,7 +309,7 @@ def random_remove_punctuation(text: str) -> str: """ result = '' text_len = len(text) - + for i, char in enumerate(text): if char == '。' and i == text_len - 1: # 结尾的句号 if random.random() > 0.4: # 80%概率删除结尾句号 @@ -379,32 +324,30 @@ def random_remove_punctuation(text: str) -> str: result += char return result -def add_typos(text: str) -> str: - TYPO_RATE = 0.02 # 控制错别字出现的概率(2%) - result = "" - for char in text: - if char in TYPO_DICT and random.random() < TYPO_RATE: - # 从可能的错别字中随机选择一个 - typos = TYPO_DICT[char] - result += random.choice(typos) - else: - result += char - return result + def process_llm_response(text: str) -> List[str]: # processed_response = process_text_with_typos(content) - if len(text) > 200: - print(f"回复过长 ({len(text)} 字符),返回默认回复") - return ['懒得说'] + if len(text) > 300: + print(f"回复过长 ({len(text)} 字符),返回默认回复") + return ['懒得说'] # 处理长消息 - sentences = split_into_sentences_w_remove_punctuation(add_typos(text)) + typo_generator = ChineseTypoGenerator( + error_rate=0.03, + min_freq=7, + tone_error_rate=0.2, + word_replace_rate=0.02 + ) + typoed_text = typo_generator.create_typo_sentence(text)[0] + sentences = split_into_sentences_w_remove_punctuation(typoed_text) # 检查分割后的消息数量是否过多(超过3条) - if len(sentences) > 3: + if len(sentences) > 4: print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") return [f'{global_config.BOT_NICKNAME}不知道哦'] - + return sentences + def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float: """ 计算输入字符串所需的时间,中文和英文字符有不同的输入时间 @@ -417,7 +360,46 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_ if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符 total_time += chinese_time else: # 其他字符(如英文) - total_time += english_time + total_time += english_time return total_time +def cosine_similarity(v1, v2): + """计算余弦相似度""" + dot_product = np.dot(v1, v2) + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + if norm1 == 0 or norm2 == 0: + return 0 + return dot_product / (norm1 * norm2) + + +def text_to_vector(text): + """将文本转换为词频向量""" + # 分词 + words = jieba.lcut(text) + # 统计词频 + word_freq = Counter(words) + return word_freq + + +def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: + """使用简单的余弦相似度计算文本相似度""" + # 将输入文本转换为词频向量 + text_vector = text_to_vector(text) + + # 计算每个主题的相似度 + similarities = [] + for topic in topics: + topic_vector = text_to_vector(topic) + # 获取所有唯一词 + all_words = set(text_vector.keys()) | set(topic_vector.keys()) + # 构建向量 + v1 = [text_vector.get(word, 0) for word in all_words] + v2 = [topic_vector.get(word, 0) for 
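# Editor's note — calculate_typing_time() charges 0.2 s per CJK character and
# 0.1 s per other character, so '你好ok' costs 0.2 + 0.2 + 0.1 + 0.1 = 0.6 s.
# Compact version:
def typing_time(s: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
    return sum(chinese_time if '\u4e00' <= ch <= '\u9fff' else english_time
               for ch in s)

assert abs(typing_time('你好ok') - 0.6) < 1e-9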
word in all_words] + # 计算相似度 + similarity = cosine_similarity(v1, v2) + similarities.append((topic, similarity)) + + # 按相似度降序排序并返回前k个 + return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 922ab5228..eff788868 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -4,6 +4,7 @@ import hashlib import time import os from ...common.database import Database +from ..chat.config import global_config import zlib # 用于 CRC32 import base64 from nonebot import get_driver @@ -143,6 +144,8 @@ def storage_emoji(image_data: bytes) -> bytes: Returns: bytes: 原始图片数据 """ + if not global_config.EMOJI_SAVE: + return image_data try: # 使用 CRC32 计算哈希值 hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') @@ -227,7 +230,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 image_data = base64.b64decode(base64_data) # 如果已经小于目标大小,直接返回原图 - if len(image_data) <= target_size: + if len(image_data) <= 2*1024*1024: return base64_data # 将字节数据转换为图片对象 @@ -252,7 +255,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 for frame_idx in range(img.n_frames): img.seek(frame_idx) new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) + new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折 frames.append(new_frame) # 保存到缓冲区 @@ -286,4 +289,19 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 logger.error(f"压缩图片失败: {str(e)}") import traceback logger.error(traceback.format_exc()) - return base64_data \ No newline at end of file + return base64_data + +def image_path_to_base64(image_path: str) -> str: + """将图片路径转换为base64编码 + Args: + image_path: 图片文件路径 + Returns: + str: base64编码的图片数据 + """ + try: + with open(image_path, 'rb') as f: + image_data = f.read() + return base64.b64encode(image_data).decode('utf-8') + except Exception as e: + logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}") + return None \ No newline at end of file diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index 7559406f9..16a0570e2 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -34,16 +34,16 @@ class WillingManager: print(f"被重复提及, 当前意愿: {current_willing}") if is_emoji: - current_willing *= 0.15 + current_willing *= 0.1 print(f"表情包, 当前意愿: {current_willing}") - if interested_rate > 0.65: + if interested_rate > 0.4: print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") - current_willing += interested_rate-0.6 + current_willing += interested_rate-0.1 self.group_reply_willing[group_id] = min(current_willing, 3.0) - reply_probability = max((current_willing - 0.55) * 1.9, 0) + reply_probability = max((current_willing - 0.45) * 2, 0) if group_id not in config.talk_allowed_groups: current_willing = 0 reply_probability = 0 diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 4d20d05a9..44f5eb713 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -11,7 +11,7 @@ from ..chat.config import global_config from ...common.database import Database # 使用正确的导入语法 from ..models.utils_model import LLM_request import math -from ..chat.utils import calculate_information_content, get_cloest_chat_from_db +from ..chat.utils import calculate_information_content, get_cloest_chat_from_db 
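# Editor's note — willing_manager.py retunes the reply curve to
# p = max((willing − 0.45) × 2, 0): silent below a willing of 0.45, then linear.
# Since willing is capped at 3.0, anything from 0.95 up is effectively a
# guaranteed reply. A quick table:
def reply_probability(willing: float) -> float:
    return max((willing - 0.45) * 2, 0)

for w in (0.3, 0.45, 0.7, 0.95):
    print(f'willing={w:.2f} -> p={reply_probability(w):.2f}')  # 0.00, 0.00, 0.50, 1.00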
,text_to_vector,cosine_similarity @@ -132,9 +132,17 @@ class Memory_graph: class Hippocampus: def __init__(self,memory_graph:Memory_graph): self.memory_graph = memory_graph - self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5) - self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5) + self.llm_topic_judge = LLM_request(model = global_config.llm_topic_judge,temperature=0.5) + self.llm_summary_by_topic = LLM_request(model = global_config.llm_summary_by_topic,temperature=0.5) + def get_all_node_names(self) -> list: + """获取记忆图中所有节点的名字列表 + + Returns: + list: 包含所有节点名字的列表 + """ + return list(self.memory_graph.G.nodes()) + def calculate_node_hash(self, concept, memory_items): """计算节点的特征值""" if not isinstance(memory_items, list): @@ -171,18 +179,24 @@ class Hippocampus: #获取topics topic_num = self.calculate_topic_num(input_text, compress_rate) - topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num)) + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) # 修改话题处理逻辑 - print(f"话题: {topics_response[0]}") - topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] - print(f"话题: {topics}") + # 定义需要过滤的关键词 + filter_keywords = ['表情包', '图片', '回复', '聊天记录'] - # 创建所有话题的请求任务 + # 过滤topics + topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] + filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] + + # print(f"原始话题: {topics}") + print(f"过滤后话题: {filtered_topics}") + + # 使用过滤后的话题继续处理 tasks = [] - for topic in topics: + for topic in filtered_topics: topic_what_prompt = self.topic_what(input_text, topic) # 创建异步任务 - task = self.llm_model_summary.generate_response_async(topic_what_prompt) + task = self.llm_summary_by_topic.generate_response_async(topic_what_prompt) tasks.append((topic.strip(), task)) # 等待所有任务完成 @@ -483,6 +497,198 @@ class Hippocampus: prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' return prompt + async def _identify_topics(self, text: str) -> list: + """从文本中识别可能的主题 + + Args: + text: 输入文本 + + Returns: + list: 识别出的主题列表 + """ + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5)) + # print(f"话题: {topics_response[0]}") + topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] + # print(f"话题: {topics}") + + return topics + + def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list: + """查找与给定主题相似的记忆主题 + + Args: + topics: 主题列表 + similarity_threshold: 相似度阈值 + debug_info: 调试信息前缀 + + Returns: + list: (主题, 相似度) 元组列表 + """ + all_memory_topics = self.get_all_node_names() + all_similar_topics = [] + + # 计算每个识别出的主题与记忆主题的相似度 + for topic in topics: + if debug_info: + print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}") + + topic_vector = text_to_vector(topic) + has_similar_topic = False + + for memory_topic in all_memory_topics: + memory_vector = text_to_vector(memory_topic) + # 获取所有唯一词 + all_words = set(topic_vector.keys()) | set(memory_vector.keys()) + # 构建向量 + v1 = [topic_vector.get(word, 0) for word in all_words] + v2 = [memory_vector.get(word, 0) for word in all_words] + # 计算相似度 + similarity 
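# Editor's note — memory_compress() now normalizes the separator variants the
# LLM may emit (,、空格) and drops meta-topics such as 表情包/图片/回复/聊天记录
# before summarization. That normalization step on its own:
FILTER_KEYWORDS = ['表情包', '图片', '回复', '聊天记录']

def parse_topics(raw: str) -> list:
    unified = raw.replace(',', ',').replace('、', ',').replace(' ', ',')
    topics = [t.strip() for t in unified.split(',') if t.strip()]
    return [t for t in topics if not any(k in t for k in FILTER_KEYWORDS)]

assert parse_topics('游戏、表情包,天气 考试') == ['游戏', '天气', '考试']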
= cosine_similarity(v1, v2) + + if similarity >= similarity_threshold: + has_similar_topic = True + if debug_info: + print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})") + all_similar_topics.append((memory_topic, similarity)) + + if not has_similar_topic and debug_info: + print(f"\033[1;31m[{debug_info}]\033[0m 没有见过: {topic} ,呃呃") + + return all_similar_topics + + def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list: + """获取相似度最高的主题 + + Args: + similar_topics: (主题, 相似度) 元组列表 + max_topics: 最大主题数量 + + Returns: + list: (主题, 相似度) 元组列表 + """ + seen_topics = set() + top_topics = [] + + for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True): + if topic not in seen_topics and len(top_topics) < max_topics: + seen_topics.add(topic) + top_topics.append((topic, score)) + + return top_topics + + async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: + """计算输入文本对记忆的激活程度""" + print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") + + # 识别主题 + identified_topics = await self._identify_topics(text) + if not identified_topics: + return 0 + + # 查找相似主题 + all_similar_topics = self._find_similar_topics( + identified_topics, + similarity_threshold=similarity_threshold, + debug_info="记忆激活" + ) + + if not all_similar_topics: + return 0 + + # 获取最相关的主题 + top_topics = self._get_top_topics(all_similar_topics, max_topics) + + # 如果只找到一个主题,进行惩罚 + if len(top_topics) == 1: + topic, score = top_topics[0] + # 获取主题内容数量并计算惩罚系数 + memory_items = self.memory_graph.G.nodes[topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + penalty = 1.0 / (1 + math.log(content_count + 1)) + + activation = int(score * 50 * penalty) + print(f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}") + return activation + + # 计算关键词匹配率,同时考虑内容数量 + matched_topics = set() + topic_similarities = {} + + for memory_topic, similarity in top_topics: + # 计算内容数量惩罚 + memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + content_count = len(memory_items) + penalty = 1.0 / (1 + math.log(content_count + 1)) + + # 对每个记忆主题,检查它与哪些输入主题相似 + for input_topic in identified_topics: + topic_vector = text_to_vector(input_topic) + memory_vector = text_to_vector(memory_topic) + all_words = set(topic_vector.keys()) | set(memory_vector.keys()) + v1 = [topic_vector.get(word, 0) for word in all_words] + v2 = [memory_vector.get(word, 0) for word in all_words] + sim = cosine_similarity(v1, v2) + if sim >= similarity_threshold: + matched_topics.add(input_topic) + adjusted_sim = sim * penalty + topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim) + print(f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})") + + # 计算主题匹配率和平均相似度 + topic_match = len(matched_topics) / len(identified_topics) + average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0 + + # 计算最终激活值 + activation = int((topic_match + average_similarities) / 2 * 100) + print(f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}") + + return activation + + async def 
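# Editor's note — memory_activate_value() damps topics that already hold many
# memories with penalty = 1 / (1 + ln(count + 1)); a single matched topic
# contributes int(similarity × 50 × penalty). E.g. similarity 0.8 with 6 stored
# items: penalty = 1/(1 + ln 7) ≈ 0.339, activation = int(40 × 0.339) = 13.
import math

def single_topic_activation(similarity: float, content_count: int) -> int:
    penalty = 1.0 / (1 + math.log(content_count + 1))
    return int(similarity * 50 * penalty)

assert single_topic_activation(0.8, 6) == 13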
get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5) -> list: + """根据输入文本获取相关的记忆内容""" + # 识别主题 + identified_topics = await self._identify_topics(text) + + # 查找相似主题 + all_similar_topics = self._find_similar_topics( + identified_topics, + similarity_threshold=similarity_threshold, + debug_info="记忆检索" + ) + + # 获取最相关的主题 + relevant_topics = self._get_top_topics(all_similar_topics, max_topics) + + # 获取相关记忆内容 + relevant_memories = [] + for topic, score in relevant_topics: + # 获取该主题的记忆内容 + first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) + if first_layer: + # 如果记忆条数超过限制,随机选择指定数量的记忆 + if len(first_layer) > max_memory_num/2: + first_layer = random.sample(first_layer, max_memory_num//2) + # 为每条记忆添加来源主题和相似度信息 + for memory in first_layer: + relevant_memories.append({ + 'topic': topic, + 'similarity': score, + 'content': memory + }) + + # 如果记忆数量超过5个,随机选择5个 + # 按相似度排序 + relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) + + if len(relevant_memories) > max_memory_num: + relevant_memories = random.sample(relevant_memories, max_memory_num) + + return relevant_memories + def segment_text(text): seg_text = list(jieba.cut(text)) diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py index d6aa2f669..e99485655 100644 --- a/src/plugins/memory_system/memory_manual_build.py +++ b/src/plugins/memory_system/memory_manual_build.py @@ -13,7 +13,6 @@ from dotenv import load_dotenv import pymongo from loguru import logger from pathlib import Path -from snownlp import SnowNLP # from chat.config import global_config sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 from src.common.database import Database @@ -234,16 +233,22 @@ class Hippocampus: async def memory_compress(self, input_text, compress_rate=0.1): print(input_text) - #获取topics topic_num = self.calculate_topic_num(input_text, compress_rate) - topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num)) + topics_response = self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num)) # 修改话题处理逻辑 + # 定义需要过滤的关键词 + filter_keywords = ['表情包', '图片', '回复', '聊天记录'] + + # 过滤topics topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] - print(f"话题: {topics}") + filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] + + # print(f"原始话题: {topics}") + print(f"过滤后话题: {filtered_topics}") # 创建所有话题的请求任务 tasks = [] - for topic in topics: + for topic in filtered_topics: topic_what_prompt = self.topic_what(input_text, topic) # 创建异步任务 task = self.llm_model_small.generate_response_async(topic_what_prompt) @@ -650,7 +655,22 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal G = memory_graph.G # 创建一个新图用于可视化 - H = G.copy() + H = G.copy() + + # 过滤掉内容数量小于2的节点 + nodes_to_remove = [] + for node in H.nodes(): + memory_items = H.nodes[node].get('memory_items', []) + memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) + if memory_count < 2: + nodes_to_remove.append(node) + + H.remove_nodes_from(nodes_to_remove) + + # 如果没有符合条件的节点,直接返回 + if len(H.nodes()) == 0: + print("没有找到内容数量大于等于2的节点") + return # 计算节点大小和颜色 node_colors = [] @@ -704,7 +724,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal 
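# Editor's note — visualize_graph_lite() in memory_manual_build.py now hides
# nodes holding fewer than two memory items before drawing. The filtering step
# alone, assuming networkx:
import networkx as nx

def filter_sparse_nodes(G: nx.Graph, min_items: int = 2) -> nx.Graph:
    H = G.copy()
    to_remove = []
    for node in H.nodes():
        items = H.nodes[node].get('memory_items', [])
        count = len(items) if isinstance(items, list) else (1 if items else 0)
        if count < min_items:
            to_remove.append(node)
    H.remove_nodes_from(to_remove)
    return H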
edge_color='gray', width=1.5) # 统一的边宽度 - title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近' + title = '记忆图谱可视化(仅显示内容≥2的节点)\n节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近' plt.title(title, fontsize=16, fontfamily='SimHei') plt.show() diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 11d7e2b72..a471bd72d 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -8,6 +8,8 @@ from nonebot import get_driver from loguru import logger from ..chat.config import global_config from ..chat.utils_image import compress_base64_image_by_scale +from datetime import datetime +from ...common.database import Database driver = get_driver() config = driver.config @@ -24,397 +26,311 @@ class LLM_request: raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e self.model_name = model["name"] self.params = kwargs + + self.pri_in = model.get("pri_in", 0) + self.pri_out = model.get("pri_out", 0) + + # 获取数据库实例 + self.db = Database.get_instance() + self._init_database() - async def generate_response(self, prompt: str) -> Tuple[str, str]: - """根据输入的提示生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" + def _init_database(self): + """初始化数据库集合""" + try: + # 创建llm_usage集合的索引 + self.db.db.llm_usage.create_index([("timestamp", 1)]) + self.db.db.llm_usage.create_index([("model_name", 1)]) + self.db.db.llm_usage.create_index([("user_id", 1)]) + self.db.db.llm_usage.create_index([("request_type", 1)]) + except Exception as e: + logger.error(f"创建数据库索引失败: {e}") + + def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, + user_id: str = "system", request_type: str = "chat", + endpoint: str = "/chat/completions"): + """记录模型使用情况到数据库 + Args: + prompt_tokens: 输入token数 + completion_tokens: 输出token数 + total_tokens: 总token数 + user_id: 用户ID,默认为system + request_type: 请求类型(chat/embedding/image等) + endpoint: API端点 + """ + try: + usage_data = { + "model_name": self.model_name, + "user_id": user_id, + "request_type": request_type, + "endpoint": endpoint, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "cost": self._calculate_cost(prompt_tokens, completion_tokens), + "status": "success", + "timestamp": datetime.now() + } + self.db.db.llm_usage.insert_one(usage_data) + logger.info( + f"Token使用情况 - 模型: {self.model_name}, " + f"用户: {user_id}, 类型: {request_type}, " + f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " + f"总计: {total_tokens}" + ) + except Exception as e: + logger.error(f"记录token使用情况失败: {e}") + + def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: + """计算API调用成本 + 使用模型的pri_in和pri_out价格计算输入和输出的成本 + + Args: + prompt_tokens: 输入token数量 + completion_tokens: 输出token数量 + + Returns: + float: 总成本(元) + """ + # 使用模型的pri_in和pri_out计算成本 + input_cost = (prompt_tokens / 1000000) * self.pri_in + output_cost = (completion_tokens / 1000000) * self.pri_out + return round(input_cost + output_cost, 6) + + async def _execute_request( + self, + endpoint: str, + prompt: str = None, + image_base64: str = None, + payload: dict = None, + retry_policy: dict = None, + response_handler: callable = None, + user_id: str = "system", + request_type: str = "chat" + ): + """统一请求执行入口 + Args: + endpoint: API端点路径 (如 "chat/completions") + prompt: prompt文本 + image_base64: 图片的base64编码 + payload: 请求体数据 + retry_policy: 自定义重试策略 + response_handler: 自定义响应处理器 + user_id: 用户ID + request_type: 
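# Editor's note — _calculate_cost() prices tokens per million via the model's
# pri_in/pri_out fields and rounds to 6 decimals. With hypothetical prices
# pri_in=2 and pri_out=8 (¥ per 1M tokens), 1500 prompt + 500 completion tokens
# cost 0.003 + 0.004 = ¥0.007:
def calculate_cost(prompt_tokens: int, completion_tokens: int,
                   pri_in: float, pri_out: float) -> float:
    input_cost = (prompt_tokens / 1_000_000) * pri_in
    output_cost = (completion_tokens / 1_000_000) * pri_out
    return round(input_cost + output_cost, 6)

assert calculate_cost(1500, 500, pri_in=2, pri_out=8) == 0.007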
请求类型 + """ + # 合并重试策略 + default_retry = { + "max_retries": 3, "base_wait": 15, + "retry_codes": [429, 413, 500, 503], + "abort_codes": [400, 401, 402, 403]} + policy = {**default_retry, **(retry_policy or {})} + + # 常见Error Code Mapping + error_code_mapping = { + 400: "参数不正确", + 401: "API key 错误,认证失败", + 402: "账号余额不足", + 403: "需要实名,或余额不足", + 404: "Not Found", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高" } + api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" + logger.info(f"发送请求到URL: {api_url}") + logger.info(f"使用模型: {self.model_name}") + # 构建请求体 - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params - } + if image_base64: + payload = await self._build_payload(prompt, image_base64) + elif payload is None: + payload = await self._build_payload(prompt) - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): + for retry in range(policy["max_retries"]): try: + # 使用上下文管理器处理会话 + headers = await self._build_headers() + async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") + async with session.post(api_url, headers=headers, json=payload) as response: + # 处理需要重试的状态码 + if response.status in policy["retry_codes"]: + wait_time = policy["base_wait"] * (2 ** retry) + logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") + if response.status == 413: + logger.warning("请求体过大,尝试压缩...") + image_base64 = compress_base64_image_by_scale(image_base64) + payload = await self._build_payload(prompt, image_base64) + elif response.status in [500, 503]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + else: + logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") + await asyncio.sleep(wait_time) continue - - if response.status in [500, 503]: - logger.error(f"服务器错误: {response.status}") - raise RuntimeError("服务器负载过高,模型恢复失败QAQ") - - response.raise_for_status() # 检查其他响应状态 + elif response.status in policy["abort_codes"]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") + response.raise_for_status() result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" + + # 使用自定义处理器或默认处理 + return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) + if retry < policy["max_retries"] - 1: + wait_time = policy["base_wait"] * (2 ** retry) + logger.error(f"请求失败,等待{wait_time}秒后重试... 
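# Editor's note — _execute_request() unifies retry handling: 429/413 back off
# exponentially with wait = base_wait × 2^attempt (the diff additionally
# recompresses the image on 413 and escalates 500/503 as server failures), while
# auth/quota errors (400/401/402/403) abort at once. A simplified sketch that
# retries everything in RETRY_CODES, with the network call stubbed out:
import asyncio

RETRY_CODES = {429, 413, 500, 503}
ABORT_CODES = {400, 401, 402, 403}

async def request_with_backoff(fetch, max_retries: int = 3, base_wait: int = 15):
    # fetch() -> (status, body); stand-in for the aiohttp POST
    for attempt in range(max_retries):
        status, body = await fetch()
        if status in ABORT_CODES:
            raise RuntimeError(f'non-retryable status {status}')
        if status in RETRY_CODES:
            await asyncio.sleep(base_wait * (2 ** attempt))  # 15 s, 30 s, 60 s…
            continue
        return body
    raise RuntimeError('达到最大重试次数,API请求仍然失败')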
错误: {str(e)}") await asyncio.sleep(wait_time) else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) + logger.critical(f"请求失败: {str(e)}") + logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}") raise RuntimeError(f"API请求失败: {str(e)}") logger.error("达到最大重试次数,请求仍然失败") raise RuntimeError("达到最大重试次数,API请求仍然失败") - async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """根据输入的提示和图片生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - # 构建请求体 - def build_request_data(img_base64: str): + async def _build_payload(self, prompt: str, image_base64: str = None) -> dict: + """构建请求体""" + if image_base64: return { "model": self.model_name, "messages": [ { "role": "user", "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{img_base64}" - } - } + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} ] } ], + "max_tokens": global_config.max_response_length, + **self.params + } + else: + return { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, **self.params } - - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL + def _default_response_handler(self, result: dict, user_id: str = "system", + request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple: + """默认响应解析""" + if "choices" in result and result["choices"]: + message = result["choices"][0]["message"] + content = message.get("content", "") + content, reasoning = self._extract_reasoning(content) + reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") + if not reasoning_content: + reasoning_content = message.get("reasoning_content", "") + if not reasoning_content: + reasoning_content = reasoning - max_retries = 3 - base_wait_time = 15 - - current_image_base64 = image_base64 - current_image_base64 = compress_base64_image_by_scale(current_image_base64) - + # 记录token使用情况 + usage = result.get("usage", {}) + if usage: + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + total_tokens = usage.get("total_tokens", 0) + self._record_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + user_id=user_id, + request_type=request_type, + endpoint=endpoint + ) - for retry in range(max_retries): - try: - data = build_request_data(current_image_base64) - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue + return content, reasoning_content - elif response.status == 413: - logger.warning("图片太大(413),尝试压缩...") - current_image_base64 = compress_base64_image_by_scale(current_image_base64) - continue - - response.raise_for_status() # 检查其他响应状态 + return "没有返回结果", "" - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - 
think_match = re.search(r'(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" + def _extract_reasoning(self, content: str) -> tuple[str, str]: + """CoT思维链提取""" + match = re.search(r'(?:)?(.*?)', content, re.DOTALL) + content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() + if match: + reasoning = match.group(1).strip() + else: + reasoning = "" + return content, reasoning - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - raise RuntimeError(f"API请求失败: {str(e)}") - - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") - - async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: - """异步方式根据输入的提示生成模型的响应""" - headers = { + async def _build_headers(self) -> dict: + """构建请求头""" + return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } - + + async def generate_response(self, prompt: str) -> Tuple[str, str]: + """根据输入的提示生成模型的异步响应""" + + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt + ) + return content, reasoning_content + + async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: + """根据输入的提示和图片生成模型的异步响应""" + + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt, + image_base64=image_base64 + ) + return content, reasoning_content + + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]: + """异步方式根据输入的提示生成模型的响应""" # 构建请求体 data = { "model": self.model_name, "messages": [{"role": "user", "content": prompt}], - "temperature": 0.5, - **self.params - } - - # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"Request URL: {api_url}") # 记录请求的 URL - - max_retries = 3 - base_wait_time = 15 - - async with aiohttp.ClientSession() as session: - for retry in range(max_retries): - try: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - content = result["choices"][0]["message"]["content"] - reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 
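# Editor's note — _extract_reasoning() splits a chain-of-thought span from the
# visible reply. The regex tag literals appear to have been eaten by HTML
# rendering in this diff; assuming the conventional <think>…</think> markers,
# a working version looks like:
import re

def extract_reasoning(content: str) -> tuple:
    match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
    reasoning = match.group(1).strip() if match else ''
    cleaned = re.sub(r'(?:<think>)?.*?</think>', '', content,
                     flags=re.DOTALL, count=1).strip()
    return cleaned, reasoning

assert extract_reasoning('<think>先想一想</think>你好') == ('你好', '先想一想')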
错误: {str(e)}") - await asyncio.sleep(wait_time) - else: - logger.error(f"请求失败: {str(e)}") - return f"请求失败: {str(e)}", "" - - logger.error("达到最大重试次数,请求仍然失败") - return "达到最大重试次数,请求仍然失败", "" - - - - def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """同步方法:根据输入的提示和图片生成模型的响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - image_base64=compress_base64_image_by_scale(image_base64) - - # 构建请求体 - data = { - "model": self.model_name, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - } - } - ] - } - ], + "max_tokens": global_config.max_response_length, **self.params } - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + payload=data, + prompt=prompt + ) + return content, reasoning_content - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - raise RuntimeError(f"API请求失败: {str(e)}") - - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") - - def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: - """同步方法:获取文本的embedding向量 - - Args: - text: 需要获取embedding的文本 - model: 使用的模型名称,默认为"BAAI/bge-m3" - - Returns: - list: embedding向量,如果失败则返回None - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": model, - "input": text, - "encoding_format": "float" - } - - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - response.raise_for_status() - - result = response.json() - if 'data' in result and len(result['data']) > 0: - return result['data'][0]['embedding'] - return None - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None - - async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: + async def get_embedding(self, text: str) -> Union[list, None]: """异步方法:获取文本的embedding向量 Args: text: 需要获取embedding的文本 - model: 使用的模型名称,默认为"BAAI/bge-m3" Returns: list: embedding向量,如果失败则返回None """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } + def embedding_handler(result): + """处理响应""" + if "data" in result and len(result["data"]) > 0: + return result["data"][0].get("embedding", None) + return None - data = { - "model": model, - "input": text, - "encoding_format": "float" - } + embedding = await self._execute_request( + endpoint="/embeddings", + prompt=text, + payload={ + "model": self.model_name, + "input": text, + "encoding_format": "float" + }, + retry_policy={ + "max_retries": 2, + "base_wait": 6 + }, + response_handler=embedding_handler + ) + return embedding - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() - - result = await response.json() - if 'data' in result and len(result['data']) > 0: - return result['data'][0]['embedding'] - return None - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py new file mode 100644 index 000000000..093ace539 --- /dev/null +++ b/src/plugins/utils/statistic.py @@ -0,0 +1,162 @@ +from typing import Dict, List, Any +import time +import threading +import json +from datetime import datetime, timedelta +from collections import defaultdict +from ...common.database import Database + +class LLMStatistics: + def __init__(self, output_file: str = "llm_statistics.txt"): + """初始化LLM统计类 + + Args: + output_file: 统计结果输出文件路径 + """ + self.db = Database.get_instance() + self.output_file = output_file + self.running = False + self.stats_thread = None + + def start(self): + """启动统计线程""" + if not self.running: + self.running = True + self.stats_thread = threading.Thread(target=self._stats_loop) + self.stats_thread.daemon = True + self.stats_thread.start() + + def stop(self): + """停止统计线程""" + self.running = False + if self.stats_thread: + self.stats_thread.join() + + def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]: + """收集指定时间段的LLM请求统计数据 + + Args: + start_time: 统计开始时间 + """ + stats = { + "total_requests": 0, + "requests_by_type": defaultdict(int), + "requests_by_user": defaultdict(int), + "requests_by_model": defaultdict(int), + "average_tokens": 0, + "total_tokens": 0, + "total_cost": 0.0, + "costs_by_user": defaultdict(float), + "costs_by_type": defaultdict(float), + "costs_by_model": defaultdict(float) + } + + cursor = self.db.db.llm_usage.find({ + "timestamp": {"$gte": start_time} + }) + + total_requests = 0 + + for doc in cursor: + stats["total_requests"] += 1 + request_type = doc.get("request_type", "unknown") + user_id = str(doc.get("user_id", "unknown")) + model_name = doc.get("model_name", "unknown") + + stats["requests_by_type"][request_type] += 1 + stats["requests_by_user"][user_id] += 1 + stats["requests_by_model"][model_name] += 1 + + prompt_tokens = doc.get("prompt_tokens", 0) + completion_tokens = doc.get("completion_tokens", 0) + stats["total_tokens"] += prompt_tokens + completion_tokens + + cost = doc.get("cost", 0.0) + stats["total_cost"] += cost + stats["costs_by_user"][user_id] += cost + stats["costs_by_type"][request_type] += cost + stats["costs_by_model"][model_name] += cost + + total_requests += 1 + + if total_requests > 0: + stats["average_tokens"] = stats["total_tokens"] / total_requests + + return stats + + def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]: + """收集所有时间范围的统计数据""" + now = datetime.now() + + return { + "all_time": self._collect_statistics_for_period(datetime.min), + "last_7_days": self._collect_statistics_for_period(now - timedelta(days=7)), + "last_24_hours": self._collect_statistics_for_period(now - timedelta(days=1)), + "last_hour": self._collect_statistics_for_period(now - timedelta(hours=1)) + } + + def _format_stats_section(self, stats: Dict[str, Any], title: str) -> str: + """格式化统计部分的输出 + + Args: + stats: 统计数据 + title: 部分标题 + """ + output = [] + output.append(f"\n{title}") + output.append("=" * len(title)) + + output.append(f"总请求数: {stats['total_requests']}") + if stats['total_requests'] > 0: + output.append(f"总Token数: {stats['total_tokens']}") + output.append(f"总花费: ¥{stats['total_cost']:.4f}") + + output.append("\n按模型统计:") + for model_name, count in sorted(stats["requests_by_model"].items()): + cost 
= stats["costs_by_model"][model_name] + output.append(f"- {model_name}: {count}次 (花费: ¥{cost:.4f})") + + output.append("\n按请求类型统计:") + for req_type, count in sorted(stats["requests_by_type"].items()): + cost = stats["costs_by_type"][req_type] + output.append(f"- {req_type}: {count}次 (花费: ¥{cost:.4f})") + + return "\n".join(output) + + def _save_statistics(self, all_stats: Dict[str, Dict[str, Any]]): + """将统计结果保存到文件""" + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + output = [] + output.append(f"LLM请求统计报告 (生成时间: {current_time})") + output.append("=" * 50) + + # 添加各个时间段的统计 + sections = [ + ("所有时间统计", "all_time"), + ("最近7天统计", "last_7_days"), + ("最近24小时统计", "last_24_hours"), + ("最近1小时统计", "last_hour") + ] + + for title, key in sections: + output.append(self._format_stats_section(all_stats[key], title)) + + # 写入文件 + with open(self.output_file, "w", encoding="utf-8") as f: + f.write("\n".join(output)) + + def _stats_loop(self): + """统计循环,每1分钟运行一次""" + while self.running: + try: + all_stats = self._collect_all_statistics() + self._save_statistics(all_stats) + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}") + + # 等待1分钟 + for _ in range(60): + if not self.running: + break + time.sleep(1) diff --git a/src/plugins/utils/typo_generator.py b/src/plugins/utils/typo_generator.py new file mode 100644 index 000000000..16834200f --- /dev/null +++ b/src/plugins/utils/typo_generator.py @@ -0,0 +1,437 @@ +""" +错别字生成器 - 基于拼音和字频的中文错别字生成工具 +""" + +from pypinyin import pinyin, Style +from collections import defaultdict +import json +import os +import jieba +from pathlib import Path +import random +import math +import time + +class ChineseTypoGenerator: + def __init__(self, + error_rate=0.3, + min_freq=5, + tone_error_rate=0.2, + word_replace_rate=0.3, + max_freq_diff=200): + """ + 初始化错别字生成器 + + 参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + self.error_rate = error_rate + self.min_freq = min_freq + self.tone_error_rate = tone_error_rate + self.word_replace_rate = word_replace_rate + self.max_freq_diff = max_freq_diff + + # 加载数据 + print("正在加载汉字数据库,请稍候...") + self.pinyin_dict = self._create_pinyin_dict() + self.char_frequency = self._load_or_create_char_frequency() + + def _load_or_create_char_frequency(self): + """ + 加载或创建汉字频率字典 + """ + cache_file = Path("char_frequency.json") + + # 如果缓存文件存在,直接加载 + if cache_file.exists(): + with open(cache_file, 'r', encoding='utf-8') as f: + return json.load(f) + + # 使用内置的词频文件 + char_freq = defaultdict(int) + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + + # 读取jieba的词典文件 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, freq = line.strip().split()[:2] + # 对词中的每个字进行频率累加 + for char in word: + if self._is_chinese_char(char): + char_freq[char] += int(freq) + + # 归一化频率值 + max_freq = max(char_freq.values()) + normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} + + # 保存到缓存文件 + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(normalized_freq, f, ensure_ascii=False, indent=2) + + return normalized_freq + + def _create_pinyin_dict(self): + """ + 创建拼音到汉字的映射字典 + """ + # 常用汉字范围 + chars = [chr(i) for i in range(0x4e00, 0x9fff)] + pinyin_dict = defaultdict(list) + + # 为每个汉字建立拼音映射 + for char in chars: + try: + py = pinyin(char, style=Style.TONE3)[0][0] + pinyin_dict[py].append(char) + except Exception: + continue + + return pinyin_dict + + def _is_chinese_char(self, char): + 
""" + 判断是否为汉字 + """ + try: + return '\u4e00' <= char <= '\u9fff' + except: + return False + + def _get_pinyin(self, sentence): + """ + 将中文句子拆分成单个汉字并获取其拼音 + """ + # 将句子拆分成单个字符 + characters = list(sentence) + + # 获取每个字符的拼音 + result = [] + for char in characters: + # 跳过空格和非汉字字符 + if char.isspace() or not self._is_chinese_char(char): + continue + # 获取拼音(数字声调) + py = pinyin(char, style=Style.TONE3)[0][0] + result.append((char, py)) + + return result + + def _get_similar_tone_pinyin(self, py): + """ + 获取相似声调的拼音 + """ + # 检查拼音是否为空或无效 + if not py or len(py) < 1: + return py + + # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 + if not py[-1].isdigit(): + # 为非数字结尾的拼音添加数字声调1 + return py + '1' + + base = py[:-1] # 去掉声调 + tone = int(py[-1]) # 获取声调 + + # 处理轻声(通常用5表示)或无效声调 + if tone not in [1, 2, 3, 4]: + return base + str(random.choice([1, 2, 3, 4])) + + # 正常处理声调 + possible_tones = [1, 2, 3, 4] + possible_tones.remove(tone) # 移除原声调 + new_tone = random.choice(possible_tones) # 随机选择一个新声调 + return base + str(new_tone) + + def _calculate_replacement_probability(self, orig_freq, target_freq): + """ + 根据频率差计算替换概率 + """ + if target_freq > orig_freq: + return 1.0 # 如果替换字频率更高,保持原有概率 + + freq_diff = orig_freq - target_freq + if freq_diff > self.max_freq_diff: + return 0.0 # 频率差太大,不替换 + + # 使用指数衰减函数计算概率 + # 频率差为0时概率为1,频率差为max_freq_diff时概率接近0 + return math.exp(-3 * freq_diff / self.max_freq_diff) + + def _get_similar_frequency_chars(self, char, py, num_candidates=5): + """ + 获取与给定字频率相近的同音字,可能包含声调错误 + """ + homophones = [] + + # 有一定概率使用错误声调 + if random.random() < self.tone_error_rate: + wrong_tone_py = self._get_similar_tone_pinyin(py) + homophones.extend(self.pinyin_dict[wrong_tone_py]) + + # 添加正确声调的同音字 + homophones.extend(self.pinyin_dict[py]) + + if not homophones: + return None + + # 获取原字的频率 + orig_freq = self.char_frequency.get(char, 0) + + # 计算所有同音字与原字的频率差,并过滤掉低频字 + freq_diff = [(h, self.char_frequency.get(h, 0)) + for h in homophones + if h != char and self.char_frequency.get(h, 0) >= self.min_freq] + + if not freq_diff: + return None + + # 计算每个候选字的替换概率 + candidates_with_prob = [] + for h, freq in freq_diff: + prob = self._calculate_replacement_probability(orig_freq, freq) + if prob > 0: # 只保留有效概率的候选字 + candidates_with_prob.append((h, prob)) + + if not candidates_with_prob: + return None + + # 根据概率排序 + candidates_with_prob.sort(key=lambda x: x[1], reverse=True) + + # 返回概率最高的几个字 + return [char for char, _ in candidates_with_prob[:num_candidates]] + + def _get_word_pinyin(self, word): + """ + 获取词语的拼音列表 + """ + return [py[0] for py in pinyin(word, style=Style.TONE3)] + + def _segment_sentence(self, sentence): + """ + 使用jieba分词,返回词语列表 + """ + return list(jieba.cut(sentence)) + + def _get_word_homophones(self, word): + """ + 获取整个词的同音词,只返回高频的有意义词语 + """ + if len(word) == 1: + return [] + + # 获取词的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 遍历所有可能的同音字组合 + candidates = [] + for py in word_pinyin: + chars = self.pinyin_dict.get(py, []) + if not chars: + return [] + candidates.append(chars) + + # 生成所有可能的组合 + import itertools + all_combinations = itertools.product(*candidates) + + # 获取jieba词典和词频信息 + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + valid_words = {} # 改用字典存储词语及其频率 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 2: + word_text = parts[0] + word_freq = float(parts[1]) # 获取词频 + valid_words[word_text] = word_freq + + # 获取原词的词频作为参考 + original_word_freq = valid_words.get(word, 0) + min_word_freq = original_word_freq * 0.1 # 
设置最小词频为原词频的10% + + # 过滤和计算频率 + homophones = [] + for combo in all_combinations: + new_word = ''.join(combo) + if new_word != word and new_word in valid_words: + new_word_freq = valid_words[new_word] + # 只保留词频达到阈值的词 + if new_word_freq >= min_word_freq: + # 计算词的平均字频(考虑字频和词频) + char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word) + # 综合评分:结合词频和字频 + combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) + if combined_score >= self.min_freq: + homophones.append((new_word, combined_score)) + + # 按综合分数排序并限制返回数量 + sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) + return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 + + def create_typo_sentence(self, sentence): + """ + 创建包含同音字错误的句子,支持词语级别和字级别的替换 + + 参数: + sentence: 输入的中文句子 + + 返回: + typo_sentence: 包含错别字的句子 + typo_info: 错别字信息列表 + """ + result = [] + typo_info = [] + + # 分词 + words = self._segment_sentence(sentence) + + for word in words: + # 如果是标点符号或空格,直接添加 + if all(not self._is_chinese_char(c) for c in word): + result.append(word) + continue + + # 获取词语的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 尝试整词替换 + if len(word) > 1 and random.random() < self.word_replace_rate: + word_homophones = self._get_word_homophones(word) + if word_homophones: + typo_word = random.choice(word_homophones) + # 计算词的平均频率 + orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word) + typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word) + + # 添加到结果中 + result.append(typo_word) + typo_info.append((word, typo_word, + ' '.join(word_pinyin), + ' '.join(self._get_word_pinyin(typo_word)), + orig_freq, typo_freq)) + continue + + # 如果不进行整词替换,则进行单字替换 + if len(word) == 1: + char = word + py = word_pinyin[0] + if random.random() < self.error_rate: + similar_chars = self._get_similar_frequency_chars(char, py) + if similar_chars: + typo_char = random.choice(similar_chars) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) + if random.random() < replace_prob: + result.append(typo_char) + typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] + typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) + continue + result.append(char) + else: + # 处理多字词的单字替换 + word_result = [] + for i, (char, py) in enumerate(zip(word, word_pinyin)): + # 词中的字替换概率降低 + word_error_rate = self.error_rate * (0.7 ** (len(word) - 1)) + + if random.random() < word_error_rate: + similar_chars = self._get_similar_frequency_chars(char, py) + if similar_chars: + typo_char = random.choice(similar_chars) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) + if random.random() < replace_prob: + word_result.append(typo_char) + typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] + typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) + continue + word_result.append(char) + result.append(''.join(word_result)) + + return ''.join(result), typo_info + + def format_typo_info(self, typo_info): + """ + 格式化错别字信息 + + 参数: + typo_info: 错别字信息列表 + + 返回: + 格式化后的错别字信息字符串 + """ + if not typo_info: + return "未生成错别字" + + result = [] + for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info: + # 判断是否为词语替换 + is_word = ' ' in orig_py + if is_word: + error_type = "整词替换" + else: + tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] 
!= typo_py[-1] + error_type = "声调错误" if tone_error else "同音字替换" + + result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> " + f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]") + + return "\n".join(result) + + def set_params(self, **kwargs): + """ + 设置参数 + + 可设置参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + print(f"参数 {key} 已设置为 {value}") + else: + print(f"警告: 参数 {key} 不存在") + +def main(): + # 创建错别字生成器实例 + typo_generator = ChineseTypoGenerator( + error_rate=0.03, + min_freq=7, + tone_error_rate=0.02, + word_replace_rate=0.3 + ) + + # 获取用户输入 + sentence = input("请输入中文句子:") + + # 创建包含错别字的句子 + start_time = time.time() + typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence) + + # 打印结果 + print("\n原句:", sentence) + print("错字版:", typo_sentence) + + # 打印错别字信息 + if typo_info: + print("\n错别字信息:") + print(typo_generator.format_typo_info(typo_info)) + + # 计算并打印总耗时 + end_time = time.time() + total_time = end_time - start_time + print(f"\n总耗时:{total_time:.2f}秒") + +if __name__ == "__main__": + main()
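The replacement gate used throughout the generator is `_calculate_replacement_probability`: a candidate that is `freq_diff` points less frequent than the original character survives with probability `exp(-3 * freq_diff / max_freq_diff)`. A quick standalone check of how fast that gate closes (the helper below mirrors the method shown in the diff; the sample frequencies are arbitrary):

```python
import math

def replacement_probability(orig_freq: float, target_freq: float,
                            max_freq_diff: float = 200) -> float:
    # Same rule as _calculate_replacement_probability in the generator
    if target_freq > orig_freq:
        return 1.0
    freq_diff = orig_freq - target_freq
    if freq_diff > max_freq_diff:
        return 0.0
    return math.exp(-3 * freq_diff / max_freq_diff)

for diff in (0, 50, 100, 200):
    print(diff, round(replacement_probability(300, 300 - diff), 3))
# 0 1.0 / 50 0.472 / 100 0.223 / 200 0.05
```

An equally common homophone always qualifies, one 100 points rarer passes about 22% of the time, and one at the 200-point cutoff only 5%, which is what keeps the generated typos plausible rather than obscure.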
diff --git a/src/test/typo.py b/src/test/typo.py index c452589ce..16834200f 100644 --- a/src/test/typo.py +++ b/src/test/typo.py @@ -1,455 +1,376 @@ """ -错别字生成器 - 流程说明 - -整体替换逻辑: -1. 数据准备 - - 加载字频词典:使用jieba词典计算汉字使用频率 - - 创建拼音映射:建立拼音到汉字的映射关系 - - 加载词频信息:从jieba词典获取词语使用频率 - -2. 分词处理 - - 使用jieba将输入句子分词 - - 区分单字词和多字词 - - 保留标点符号和空格 - -3. 词语级别替换(针对多字词) - - 触发条件:词长>1 且 随机概率<0.3 - - 替换流程: - a. 获取词语拼音 - b. 生成所有可能的同音字组合 - c. 过滤条件: - - 必须是jieba词典中的有效词 - - 词频必须达到原词频的10%以上 - - 综合评分(词频70%+字频30%)必须达到阈值 - d. 按综合评分排序,选择最合适的替换词 - -4. 字级别替换(针对单字词或未进行整词替换的多字词) - - 单字替换概率:0.3 - - 多字词中的单字替换概率:0.3 * (0.7 ^ (词长-1)) - - 替换流程: - a. 获取字的拼音 - b. 声调错误处理(20%概率) - c. 获取同音字列表 - d. 过滤条件: - - 字频必须达到最小阈值 - - 频率差异不能过大(指数衰减计算) - e. 按频率排序选择替换字 - -5. 频率控制机制 - - 字频控制:使用归一化的字频(0-1000范围) - - 词频控制:使用jieba词典中的词频 - - 频率差异计算:使用指数衰减函数 - - 最小频率阈值:确保替换字/词不会太生僻 - -6. 输出信息 - - 原文和错字版本的对照 - - 每个替换的详细信息(原字/词、替换后字/词、拼音、频率) - - 替换类型说明(整词替换/声调错误/同音字替换) - - 词语分析和完整拼音 - -注意事项: -1. 所有替换都必须使用有意义的词语 -2. 替换词的使用频率不能过低 -3. 多字词优先考虑整词替换 -4. 考虑声调变化的情况 -5. 保持标点符号和空格不变 +错别字生成器 - 基于拼音和字频的中文错别字生成工具 """ from pypinyin import pinyin, Style from collections import defaultdict import json import os -import unicodedata import jieba -import jieba.posseg as pseg from pathlib import Path import random import math import time -def load_or_create_char_frequency(): - """ - 加载或创建汉字频率字典 - """ - cache_file = Path("char_frequency.json") - - # 如果缓存文件存在,直接加载 - if cache_file.exists(): - with open(cache_file, 'r', encoding='utf-8') as f: - return json.load(f) - - # 使用内置的词频文件 - char_freq = defaultdict(int) - dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') - - # 读取jieba的词典文件 - with open(dict_path, 'r', encoding='utf-8') as f: - for line in f: - word, freq = line.strip().split()[:2] - # 对词中的每个字进行频率累加 - for char in word: - if is_chinese_char(char): - char_freq[char] += int(freq) - - # 归一化频率值 - max_freq = max(char_freq.values()) - normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} - - # 保存到缓存文件 - with open(cache_file, 'w', encoding='utf-8') as f: - json.dump(normalized_freq, f, ensure_ascii=False, indent=2) - - return normalized_freq - -# 创建拼音到汉字的映射字典 -def create_pinyin_dict(): - """ - 创建拼音到汉字的映射字典 - """ - # 常用汉字范围 - chars = [chr(i) for i in range(0x4e00, 0x9fff)] - pinyin_dict = defaultdict(list) - - # 为每个汉字建立拼音映射 - for char in chars: - try: - py = pinyin(char, style=Style.TONE3)[0][0] - pinyin_dict[py].append(char) - except Exception: - continue - - return pinyin_dict - -def is_chinese_char(char): - """ - 判断是否为汉字 - """ - try: - return '\u4e00' <= char <= '\u9fff' - except: - return False - -def get_pinyin(sentence): - """ - 将中文句子拆分成单个汉字并获取其拼音 - :param sentence: 输入的中文句子 - :return: 每个汉字及其拼音的列表 - """ - # 将句子拆分成单个字符 - characters = list(sentence) - - # 获取每个字符的拼音 - result = [] - for char in characters: - # 跳过空格和非汉字字符 - if char.isspace() or not is_chinese_char(char): - continue - # 获取拼音(数字声调) - py = pinyin(char, style=Style.TONE3)[0][0] - result.append((char, py)) - - return result - -def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5): - """ - 获取同音字,按照使用频率排序 - """ - homophones = pinyin_dict[py] - # 移除原字并过滤低频字 - if char in homophones: - homophones.remove(char) - - # 过滤掉低频字 - homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq] - - # 按照字频排序 - sorted_homophones = sorted(homophones, - key=lambda x: char_frequency.get(x, 0), - reverse=True) - - # 只返回前10个同音字,避免输出过多 - return sorted_homophones[:10] - -def get_similar_tone_pinyin(py): - """ - 获取相似声调的拼音 - 例如:'ni3' 可能返回 'ni2' 或 'ni4' - 处理特殊情况: - 1. 轻声(如 'de5' 或 'le') - 2. 
非数字结尾的拼音 - """ - # 检查拼音是否为空或无效 - if not py or len(py) < 1: - return py +class ChineseTypoGenerator: + def __init__(self, + error_rate=0.3, + min_freq=5, + tone_error_rate=0.2, + word_replace_rate=0.3, + max_freq_diff=200): + """ + 初始化错别字生成器 - # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 - if not py[-1].isdigit(): - # 为非数字结尾的拼音添加数字声调1 - return py + '1' - - base = py[:-1] # 去掉声调 - tone = int(py[-1]) # 获取声调 - - # 处理轻声(通常用5表示)或无效声调 - if tone not in [1, 2, 3, 4]: - return base + str(random.choice([1, 2, 3, 4])) - - # 正常处理声调 - possible_tones = [1, 2, 3, 4] - possible_tones.remove(tone) # 移除原声调 - new_tone = random.choice(possible_tones) # 随机选择一个新声调 - return base + str(new_tone) - -def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200): - """ - 根据频率差计算替换概率 - 频率差越大,概率越低 - :param orig_freq: 原字频率 - :param target_freq: 目标字频率 - :param max_freq_diff: 最大允许的频率差 - :return: 0-1之间的概率值 - """ - if target_freq > orig_freq: - return 1.0 # 如果替换字频率更高,保持原有概率 - - freq_diff = orig_freq - target_freq - if freq_diff > max_freq_diff: - return 0.0 # 频率差太大,不替换 - - # 使用指数衰减函数计算概率 - # 频率差为0时概率为1,频率差为max_freq_diff时概率接近0 - return math.exp(-3 * freq_diff / max_freq_diff) - -def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2): - """ - 获取与给定字频率相近的同音字,可能包含声调错误 - """ - homophones = [] - - # 有20%的概率使用错误声调 - if random.random() < tone_error_rate: - wrong_tone_py = get_similar_tone_pinyin(py) - homophones.extend(pinyin_dict[wrong_tone_py]) - - # 添加正确声调的同音字 - homophones.extend(pinyin_dict[py]) - - if not homophones: - return None + 参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + self.error_rate = error_rate + self.min_freq = min_freq + self.tone_error_rate = tone_error_rate + self.word_replace_rate = word_replace_rate + self.max_freq_diff = max_freq_diff - # 获取原字的频率 - orig_freq = char_frequency.get(char, 0) + # 加载数据 + print("正在加载汉字数据库,请稍候...") + self.pinyin_dict = self._create_pinyin_dict() + self.char_frequency = self._load_or_create_char_frequency() - # 计算所有同音字与原字的频率差,并过滤掉低频字 - freq_diff = [(h, char_frequency.get(h, 0)) - for h in homophones - if h != char and char_frequency.get(h, 0) >= min_freq] - - if not freq_diff: - return None - - # 计算每个候选字的替换概率 - candidates_with_prob = [] - for h, freq in freq_diff: - prob = calculate_replacement_probability(orig_freq, freq) - if prob > 0: # 只保留有效概率的候选字 - candidates_with_prob.append((h, prob)) - - if not candidates_with_prob: - return None - - # 根据概率排序 - candidates_with_prob.sort(key=lambda x: x[1], reverse=True) - - # 返回概率最高的几个字 - return [char for char, _ in candidates_with_prob[:num_candidates]] - -def get_word_pinyin(word): - """ - 获取词语的拼音列表 - """ - return [py[0] for py in pinyin(word, style=Style.TONE3)] - -def segment_sentence(sentence): - """ - 使用jieba分词,返回词语列表 - """ - return list(jieba.cut(sentence)) - -def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5): - """ - 获取整个词的同音词,只返回高频的有意义词语 - :param word: 输入词语 - :param pinyin_dict: 拼音字典 - :param char_frequency: 字频字典 - :param min_freq: 最小频率阈值 - :return: 同音词列表 - """ - if len(word) == 1: - return [] + def _load_or_create_char_frequency(self): + """ + 加载或创建汉字频率字典 + """ + cache_file = Path("char_frequency.json") - # 获取词的拼音 - word_pinyin = get_word_pinyin(word) - word_pinyin_str = ''.join(word_pinyin) - - # 创建词语频率字典 - word_freq = defaultdict(float) - - # 遍历所有可能的同音字组合 - candidates = [] - for py in word_pinyin: - chars = pinyin_dict.get(py, []) - if not chars: - return 
[] - candidates.append(chars) - - # 生成所有可能的组合 - import itertools - all_combinations = itertools.product(*candidates) - - # 获取jieba词典和词频信息 - dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') - valid_words = {} # 改用字典存储词语及其频率 - with open(dict_path, 'r', encoding='utf-8') as f: - for line in f: - parts = line.strip().split() - if len(parts) >= 2: - word_text = parts[0] - word_freq = float(parts[1]) # 获取词频 - valid_words[word_text] = word_freq - - # 获取原词的词频作为参考 - original_word_freq = valid_words.get(word, 0) - min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% - - # 过滤和计算频率 - homophones = [] - for combo in all_combinations: - new_word = ''.join(combo) - if new_word != word and new_word in valid_words: - new_word_freq = valid_words[new_word] - # 只保留词频达到阈值的词 - if new_word_freq >= min_word_freq: - # 计算词的平均字频(考虑字频和词频) - char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word) - # 综合评分:结合词频和字频 - combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) - if combined_score >= min_freq: - homophones.append((new_word, combined_score)) - - # 按综合分数排序并限制返回数量 - sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) - return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 - -def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3): - """ - 创建包含同音字错误的句子,支持词语级别和字级别的替换 - 只使用高频的有意义词语进行替换 - """ - result = [] - typo_info = [] - - # 分词 - words = segment_sentence(sentence) - - for word in words: - # 如果是标点符号或空格,直接添加 - if all(not is_chinese_char(c) for c in word): - result.append(word) - continue - - # 获取词语的拼音 - word_pinyin = get_word_pinyin(word) + # 如果缓存文件存在,直接加载 + if cache_file.exists(): + with open(cache_file, 'r', encoding='utf-8') as f: + return json.load(f) - # 尝试整词替换 - if len(word) > 1 and random.random() < word_replace_rate: - word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq) - if word_homophones: - typo_word = random.choice(word_homophones) - # 计算词的平均频率 - orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word) - typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word) - - # 添加到结果中 - result.append(typo_word) - typo_info.append((word, typo_word, - ' '.join(word_pinyin), - ' '.join(get_word_pinyin(typo_word)), - orig_freq, typo_freq)) + # 使用内置的词频文件 + char_freq = defaultdict(int) + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + + # 读取jieba的词典文件 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + word, freq = line.strip().split()[:2] + # 对词中的每个字进行频率累加 + for char in word: + if self._is_chinese_char(char): + char_freq[char] += int(freq) + + # 归一化频率值 + max_freq = max(char_freq.values()) + normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} + + # 保存到缓存文件 + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(normalized_freq, f, ensure_ascii=False, indent=2) + + return normalized_freq + + def _create_pinyin_dict(self): + """ + 创建拼音到汉字的映射字典 + """ + # 常用汉字范围 + chars = [chr(i) for i in range(0x4e00, 0x9fff)] + pinyin_dict = defaultdict(list) + + # 为每个汉字建立拼音映射 + for char in chars: + try: + py = pinyin(char, style=Style.TONE3)[0][0] + pinyin_dict[py].append(char) + except Exception: continue - # 如果不进行整词替换,则进行单字替换 + return pinyin_dict + + def _is_chinese_char(self, char): + """ + 判断是否为汉字 + """ + try: + return '\u4e00' <= char <= '\u9fff' + except: + return False + + def _get_pinyin(self, sentence): + """ + 
将中文句子拆分成单个汉字并获取其拼音 + """ + # 将句子拆分成单个字符 + characters = list(sentence) + + # 获取每个字符的拼音 + result = [] + for char in characters: + # 跳过空格和非汉字字符 + if char.isspace() or not self._is_chinese_char(char): + continue + # 获取拼音(数字声调) + py = pinyin(char, style=Style.TONE3)[0][0] + result.append((char, py)) + + return result + + def _get_similar_tone_pinyin(self, py): + """ + 获取相似声调的拼音 + """ + # 检查拼音是否为空或无效 + if not py or len(py) < 1: + return py + + # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 + if not py[-1].isdigit(): + # 为非数字结尾的拼音添加数字声调1 + return py + '1' + + base = py[:-1] # 去掉声调 + tone = int(py[-1]) # 获取声调 + + # 处理轻声(通常用5表示)或无效声调 + if tone not in [1, 2, 3, 4]: + return base + str(random.choice([1, 2, 3, 4])) + + # 正常处理声调 + possible_tones = [1, 2, 3, 4] + possible_tones.remove(tone) # 移除原声调 + new_tone = random.choice(possible_tones) # 随机选择一个新声调 + return base + str(new_tone) + + def _calculate_replacement_probability(self, orig_freq, target_freq): + """ + 根据频率差计算替换概率 + """ + if target_freq > orig_freq: + return 1.0 # 如果替换字频率更高,保持原有概率 + + freq_diff = orig_freq - target_freq + if freq_diff > self.max_freq_diff: + return 0.0 # 频率差太大,不替换 + + # 使用指数衰减函数计算概率 + # 频率差为0时概率为1,频率差为max_freq_diff时概率接近0 + return math.exp(-3 * freq_diff / self.max_freq_diff) + + def _get_similar_frequency_chars(self, char, py, num_candidates=5): + """ + 获取与给定字频率相近的同音字,可能包含声调错误 + """ + homophones = [] + + # 有一定概率使用错误声调 + if random.random() < self.tone_error_rate: + wrong_tone_py = self._get_similar_tone_pinyin(py) + homophones.extend(self.pinyin_dict[wrong_tone_py]) + + # 添加正确声调的同音字 + homophones.extend(self.pinyin_dict[py]) + + if not homophones: + return None + + # 获取原字的频率 + orig_freq = self.char_frequency.get(char, 0) + + # 计算所有同音字与原字的频率差,并过滤掉低频字 + freq_diff = [(h, self.char_frequency.get(h, 0)) + for h in homophones + if h != char and self.char_frequency.get(h, 0) >= self.min_freq] + + if not freq_diff: + return None + + # 计算每个候选字的替换概率 + candidates_with_prob = [] + for h, freq in freq_diff: + prob = self._calculate_replacement_probability(orig_freq, freq) + if prob > 0: # 只保留有效概率的候选字 + candidates_with_prob.append((h, prob)) + + if not candidates_with_prob: + return None + + # 根据概率排序 + candidates_with_prob.sort(key=lambda x: x[1], reverse=True) + + # 返回概率最高的几个字 + return [char for char, _ in candidates_with_prob[:num_candidates]] + + def _get_word_pinyin(self, word): + """ + 获取词语的拼音列表 + """ + return [py[0] for py in pinyin(word, style=Style.TONE3)] + + def _segment_sentence(self, sentence): + """ + 使用jieba分词,返回词语列表 + """ + return list(jieba.cut(sentence)) + + def _get_word_homophones(self, word): + """ + 获取整个词的同音词,只返回高频的有意义词语 + """ if len(word) == 1: - char = word - py = word_pinyin[0] - if random.random() < error_rate: - similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, - min_freq=min_freq, tone_error_rate=tone_error_rate) - if similar_chars: - typo_char = random.choice(similar_chars) - typo_freq = char_frequency.get(typo_char, 0) - orig_freq = char_frequency.get(char, 0) - replace_prob = calculate_replacement_probability(orig_freq, typo_freq) - if random.random() < replace_prob: - result.append(typo_char) - typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] - typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) - continue - result.append(char) - else: - # 处理多字词的单字替换 - word_result = [] - for i, (char, py) in enumerate(zip(word, word_pinyin)): - # 词中的字替换概率降低 - word_error_rate = error_rate * (0.7 ** (len(word) - 1)) + return [] + + # 获取词的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 
遍历所有可能的同音字组合 + candidates = [] + for py in word_pinyin: + chars = self.pinyin_dict.get(py, []) + if not chars: + return [] + candidates.append(chars) + + # 生成所有可能的组合 + import itertools + all_combinations = itertools.product(*candidates) + + # 获取jieba词典和词频信息 + dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') + valid_words = {} # 改用字典存储词语及其频率 + with open(dict_path, 'r', encoding='utf-8') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 2: + word_text = parts[0] + word_freq = float(parts[1]) # 获取词频 + valid_words[word_text] = word_freq + + # 获取原词的词频作为参考 + original_word_freq = valid_words.get(word, 0) + min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% + + # 过滤和计算频率 + homophones = [] + for combo in all_combinations: + new_word = ''.join(combo) + if new_word != word and new_word in valid_words: + new_word_freq = valid_words[new_word] + # 只保留词频达到阈值的词 + if new_word_freq >= min_word_freq: + # 计算词的平均字频(考虑字频和词频) + char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word) + # 综合评分:结合词频和字频 + combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) + if combined_score >= self.min_freq: + homophones.append((new_word, combined_score)) + + # 按综合分数排序并限制返回数量 + sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) + return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 + + def create_typo_sentence(self, sentence): + """ + 创建包含同音字错误的句子,支持词语级别和字级别的替换 + + 参数: + sentence: 输入的中文句子 + + 返回: + typo_sentence: 包含错别字的句子 + typo_info: 错别字信息列表 + """ + result = [] + typo_info = [] + + # 分词 + words = self._segment_sentence(sentence) + + for word in words: + # 如果是标点符号或空格,直接添加 + if all(not self._is_chinese_char(c) for c in word): + result.append(word) + continue - if random.random() < word_error_rate: - similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, - min_freq=min_freq, tone_error_rate=tone_error_rate) + # 获取词语的拼音 + word_pinyin = self._get_word_pinyin(word) + + # 尝试整词替换 + if len(word) > 1 and random.random() < self.word_replace_rate: + word_homophones = self._get_word_homophones(word) + if word_homophones: + typo_word = random.choice(word_homophones) + # 计算词的平均频率 + orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word) + typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word) + + # 添加到结果中 + result.append(typo_word) + typo_info.append((word, typo_word, + ' '.join(word_pinyin), + ' '.join(self._get_word_pinyin(typo_word)), + orig_freq, typo_freq)) + continue + + # 如果不进行整词替换,则进行单字替换 + if len(word) == 1: + char = word + py = word_pinyin[0] + if random.random() < self.error_rate: + similar_chars = self._get_similar_frequency_chars(char, py) if similar_chars: typo_char = random.choice(similar_chars) - typo_freq = char_frequency.get(typo_char, 0) - orig_freq = char_frequency.get(char, 0) - replace_prob = calculate_replacement_probability(orig_freq, typo_freq) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) if random.random() < replace_prob: - word_result.append(typo_char) + result.append(typo_char) typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) continue - word_result.append(char) - result.append(''.join(word_result)) - - return ''.join(result), typo_info + result.append(char) + else: + # 处理多字词的单字替换 + word_result = [] + for i, (char, py) in 
enumerate(zip(word, word_pinyin)): + # 词中的字替换概率降低 + word_error_rate = self.error_rate * (0.7 ** (len(word) - 1)) + + if random.random() < word_error_rate: + similar_chars = self._get_similar_frequency_chars(char, py) + if similar_chars: + typo_char = random.choice(similar_chars) + typo_freq = self.char_frequency.get(typo_char, 0) + orig_freq = self.char_frequency.get(char, 0) + replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq) + if random.random() < replace_prob: + word_result.append(typo_char) + typo_py = pinyin(typo_char, style=Style.TONE3)[0][0] + typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq)) + continue + word_result.append(char) + result.append(''.join(word_result)) + + return ''.join(result), typo_info -def format_frequency(freq): - """ - 格式化频率显示 - """ - return f"{freq:.2f}" - -def main(): - # 记录开始时间 - start_time = time.time() - - # 首先创建拼音字典和加载字频统计 - print("正在加载汉字数据库,请稍候...") - pinyin_dict = create_pinyin_dict() - char_frequency = load_or_create_char_frequency() - - # 获取用户输入 - sentence = input("请输入中文句子:") - - # 创建包含错别字的句子 - typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency, - error_rate=0.3, min_freq=5, - tone_error_rate=0.2, word_replace_rate=0.3) - - # 打印结果 - print("\n原句:", sentence) - print("错字版:", typo_sentence) - - if typo_info: - print("\n错别字信息:") + def format_typo_info(self, typo_info): + """ + 格式化错别字信息 + + 参数: + typo_info: 错别字信息列表 + + 返回: + 格式化后的错别字信息字符串 + """ + if not typo_info: + return "未生成错别字" + + result = [] for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info: # 判断是否为词语替换 is_word = ' ' in orig_py @@ -459,25 +380,53 @@ def main(): tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1] error_type = "声调错误" if tone_error else "同音字替换" - print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> " - f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]") + result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> " + f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]") + + return "\n".join(result) - # 获取拼音结果 - result = get_pinyin(sentence) + def set_params(self, **kwargs): + """ + 设置参数 + + 可设置参数: + error_rate: 单字替换概率 + min_freq: 最小字频阈值 + tone_error_rate: 声调错误概率 + word_replace_rate: 整词替换概率 + max_freq_diff: 最大允许的频率差异 + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + print(f"参数 {key} 已设置为 {value}") + else: + print(f"警告: 参数 {key} 不存在") + +def main(): + # 创建错别字生成器实例 + typo_generator = ChineseTypoGenerator( + error_rate=0.03, + min_freq=7, + tone_error_rate=0.02, + word_replace_rate=0.3 + ) - # 打印完整拼音 - print("\n完整拼音:") - print(" ".join(py for _, py in result)) + # 获取用户输入 + sentence = input("请输入中文句子:") - # 打印词语分析 - print("\n词语分析:") - words = segment_sentence(sentence) - for word in words: - if any(is_chinese_char(c) for c in word): - word_pinyin = get_word_pinyin(word) - print(f"词语:{word}") - print(f"拼音:{' '.join(word_pinyin)}") - print("---") + # 创建包含错别字的句子 + start_time = time.time() + typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence) + + # 打印结果 + print("\n原句:", sentence) + print("错字版:", typo_sentence) + + # 打印错别字信息 + if typo_info: + print("\n错别字信息:") + print(typo_generator.format_typo_info(typo_info)) # 计算并打印总耗时 end_time = time.time() diff --git a/config/auto_format.py b/templete/auto_format.py similarity index 100% rename from config/auto_format.py rename to templete/auto_format.py diff --git a/config/bot_config_template.toml 
b/templete/bot_config_template.toml similarity index 60% rename from config/bot_config_template.toml rename to templete/bot_config_template.toml index 28ffb0ce3..e6246be07 100644 --- a/config/bot_config_template.toml +++ b/templete/bot_config_template.toml @@ -20,14 +20,18 @@ ban_words = [ [emoji] check_interval = 120 # 检查表情包的时间间隔 register_interval = 10 # 注册表情包的时间间隔 +auto_save = true # 自动偷表情包 +enable_check = false # 是否启用表情包过滤 +check_prompt = "符合公序良俗" # 表情包过滤要求 [cq_code] enable_pic_translate = false [response] -model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率 -model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 -model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 +model_r1_probability = 0.8 # 麦麦回答时选择主要回复模型1的概率 +model_v3_probability = 0.1 # 麦麦回答时选择次要回复模型2的概率 +model_r1_distill_probability = 0.1 # 麦麦回答时选择次要回复模型3的概率 +max_response_length = 1024 # 麦麦回答的最大token数 [memory] build_memory_interval = 300 # 记忆构建间隔 单位秒 @@ -58,17 +62,23 @@ ban_user_id = [] #禁止回复消息的QQ号 #下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env.prod自定义的宏,使用自定义模型则选择定位相似的模型自己填写 -[model.llm_reasoning] #R1 +#推理模型 + +[model.llm_reasoning] #回复模型1 主要回复模型 name = "Pro/deepseek-ai/DeepSeek-R1" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" +pri_in = 0 #模型的输入价格(非必填,可以记录消耗) +pri_out = 0 #模型的输出价格(非必填,可以记录消耗) -[model.llm_reasoning_minor] #R1蒸馏 +[model.llm_reasoning_minor] #回复模型3 次要回复模型 name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" -[model.llm_normal] #V3 +#非推理模型 + +[model.llm_normal] #V3 回复模型2 次要回复模型 name = "Pro/deepseek-ai/DeepSeek-V3" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" @@ -78,21 +88,42 @@ name = "deepseek-ai/DeepSeek-V2.5" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" -[model.vlm] #图像识别 -name = "deepseek-ai/deepseek-vl2" +[model.llm_emotion_judge] #情感判断 0.7/m +name = "Qwen/Qwen2.5-14B-Instruct" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" +[model.llm_topic_judge] #主题判断:建议使用qwen2.5 7b +name = "Pro/Qwen/Qwen2.5-7B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + +[model.llm_summary_by_topic] #主题总结:建议使用qwen2.5 32b 及以上 +name = "Qwen/Qwen2.5-32B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +pri_in = 0 +pri_out = 0 + +[model.moderation] #内容审核 未启用 +name = "" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" +pri_in = 0 +pri_out = 0 + +# 识图模型 + +[model.vlm] #图像识别 0.35/m +name = "Pro/Qwen/Qwen2-VL-7B-Instruct" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" + + + +#嵌入模型 + [model.embedding] #嵌入 name = "BAAI/bge-m3" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" - -# 主题提取,jieba和snownlp不用api,llm需要api -[topic] -topic_extract='snownlp' # 只支持jieba,snownlp,llm三种选项 - -[topic.llm_topic] -name = "Pro/deepseek-ai/DeepSeek-V3" -base_url = "SILICONFLOW_BASE_URL" -key = "SILICONFLOW_KEY"
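The new `pri_in`/`pri_out` fields line up with the `cost` field that `statistic.py` aggregates per model, user, and request type. The template does not state the units, but the `0.7/m` and `0.35/m` comments read as ¥ per million tokens; under that assumption (the actual bookkeeping code is not part of this diff, so the formula below is only illustrative), the recorded cost would be computed roughly as:

```python
def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  pri_in: float, pri_out: float) -> float:
    """Assumed pricing model: pri_in / pri_out are prices in ¥ per million tokens."""
    return prompt_tokens / 1_000_000 * pri_in + completion_tokens / 1_000_000 * pri_out

# e.g. 2000 prompt + 500 completion tokens on a 0.7/m model:
print(f"¥{estimate_cost(2000, 500, 0.7, 0.7):.6f}")  # ¥0.001750
```

Leaving both prices at the default `0` keeps the statistics report functional while recording every model's spend as zero, which matches the "非必填" note on the fields.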