6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.git
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.DS_Store
|
||||||
26
.env
Normal file
26
.env
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
||||||
|
ENVIRONMENT=dev
|
||||||
|
# HOST=127.0.0.1
|
||||||
|
# PORT=8080
|
||||||
|
|
||||||
|
# COMMAND_START=["/"]
|
||||||
|
|
||||||
|
# # 插件配置
|
||||||
|
# PLUGINS=["src2.plugins.chat"]
|
||||||
|
|
||||||
|
# # 默认配置
|
||||||
|
# MONGODB_HOST=127.0.0.1
|
||||||
|
# MONGODB_PORT=27017
|
||||||
|
# DATABASE_NAME=MegBot
|
||||||
|
|
||||||
|
# MONGODB_USERNAME = "" # 默认空值
|
||||||
|
# MONGODB_PASSWORD = "" # 默认空值
|
||||||
|
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
|
# #key and url
|
||||||
|
# CHAT_ANY_WHERE_KEY=
|
||||||
|
# SILICONFLOW_KEY=
|
||||||
|
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
|
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
|
# DEEP_SEEK_KEY=
|
||||||
|
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
ENVIRONMENT=dev
|
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
@@ -11,15 +10,17 @@ PLUGINS=["src2.plugins.chat"]
|
|||||||
MONGODB_HOST=127.0.0.1
|
MONGODB_HOST=127.0.0.1
|
||||||
MONGODB_PORT=27017
|
MONGODB_PORT=27017
|
||||||
DATABASE_NAME=MegBot
|
DATABASE_NAME=MegBot
|
||||||
|
|
||||||
MONGODB_USERNAME = "" # 默认空值
|
MONGODB_USERNAME = "" # 默认空值
|
||||||
MONGODB_PASSWORD = "" # 默认空值
|
MONGODB_PASSWORD = "" # 默认空值
|
||||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
#api配置项
|
#key and url
|
||||||
SILICONFLOW_KEY=
|
|
||||||
|
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_KEY=
|
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
DEEP_SEEK_KEY=
|
||||||
|
CHAT_ANY_WHERE_KEY=
|
||||||
|
SILICONFLOW_KEY=
|
||||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -2,19 +2,15 @@ data/
|
|||||||
mongodb/
|
mongodb/
|
||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
log/
|
log/
|
||||||
src/plugins/memory
|
|
||||||
config/bot_config.toml
|
|
||||||
/test
|
/test
|
||||||
message_queue_content.txt
|
message_queue_content.txt
|
||||||
message_queue_content.bat
|
message_queue_content.bat
|
||||||
message_queue_window.bat
|
message_queue_window.bat
|
||||||
message_queue_window.txt
|
message_queue_window.txt
|
||||||
reasoning_content.txt
|
|
||||||
reasoning_content.bat
|
|
||||||
reasoning_window.bat
|
|
||||||
queue_update.txt
|
queue_update.txt
|
||||||
memory_graph.gml
|
memory_graph.gml
|
||||||
|
.env.*
|
||||||
|
config/bot_config_dev.toml
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
@@ -146,7 +142,6 @@ celerybeat.pid
|
|||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
# Environments
|
# Environments
|
||||||
.env
|
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
@@ -187,3 +182,4 @@ cython_debug/
|
|||||||
|
|
||||||
# PyPI configuration file
|
# PyPI configuration file
|
||||||
.pypirc
|
.pypirc
|
||||||
|
.env
|
||||||
|
|||||||
16
Dockerfile
16
Dockerfile
@@ -1,10 +1,18 @@
|
|||||||
FROM nonebot/nb-cli:latest
|
FROM nonebot/nb-cli:latest
|
||||||
WORKDIR /
|
|
||||||
COPY . /MaiMBot/
|
# 设置工作目录
|
||||||
WORKDIR /MaiMBot
|
WORKDIR /MaiMBot
|
||||||
RUN mv config/env.example config/.env \
|
|
||||||
&& mv config/bot_config_toml config/bot_config.toml
|
# 先复制依赖列表
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# 安装依赖(这层会被缓存直到requirements.txt改变)
|
||||||
RUN pip install --upgrade -r requirements.txt
|
RUN pip install --upgrade -r requirements.txt
|
||||||
|
|
||||||
|
# 然后复制项目代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
VOLUME [ "/MaiMBot/config" ]
|
VOLUME [ "/MaiMBot/config" ]
|
||||||
|
VOLUME [ "/MaiMBot/data" ]
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
ENTRYPOINT [ "nb","run" ]
|
ENTRYPOINT [ "nb","run" ]
|
||||||
176
README.md
176
README.md
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
@@ -12,165 +11,33 @@
|
|||||||
|
|
||||||
## 📝 项目简介
|
## 📝 项目简介
|
||||||
|
|
||||||
**麦麦qq机器人的源代码仓库**
|
**🍔麦麦是一个基于大语言模型的智能QQ群聊机器人**
|
||||||
|
|
||||||
基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot
|
- 🤖 基于 nonebot2 框架开发
|
||||||
|
- 🧠 LLM 提供对话能力
|
||||||
|
- 💾 MongoDB 提供数据持久化支持
|
||||||
|
- 🐧 NapCat 作为QQ协议端支持
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||||
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="麦麦演示视频">
|
<img src="docs/video.png" width="300" alt="麦麦演示视频">
|
||||||
<br>
|
<br>
|
||||||
👆 点击观看麦麦演示视频 👆
|
👆 点击观看麦麦演示视频 👆
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
> ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本
|
> ⚠️ **注意事项**
|
||||||
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
> - 项目处于活跃开发阶段,代码可能随时更改
|
||||||
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
> - 文档未完善,有问题可以提交 Issue 或者 Discussion
|
||||||
|
> - QQ机器人存在被限制风险,请自行了解,谨慎使用
|
||||||
|
> - 由于持续迭代,可能存在一些已知或未知的bug
|
||||||
|
|
||||||
关于麦麦的开发和建议相关的讨论群(不建议发布无关消息)这里不会有麦麦发言!
|
**交流群**: 766798517(仅用于开发和建议相关讨论)
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
## 📚 文档
|
||||||
|
|
||||||
- 兼容gif的解析和保存
|
- [安装与配置指南](docs/installation.md) - 详细的部署和配置说明
|
||||||
- 小程序转发链接解析
|
- [项目架构说明](docs/doc1.md) - 项目结构和核心功能实现细节
|
||||||
- 对思考链长度限制
|
|
||||||
- 修复已知bug
|
|
||||||
- 完善文档
|
|
||||||
- 修复转发
|
|
||||||
- config自动生成和检测
|
|
||||||
- log别用print
|
|
||||||
- 给发送消息写专门的类
|
|
||||||
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
<img src="docs/qq.png" width="300" />
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## 📚 详细文档
|
|
||||||
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
|
|
||||||
|
|
||||||
### 安装方法(还没测试好,现在部署可能遇到未知问题!!!!)
|
|
||||||
|
|
||||||
#### Linux 使用 Docker Compose 部署
|
|
||||||
获取项目根目录中的```docker-compose.yml```文件,运行以下命令
|
|
||||||
```bash
|
|
||||||
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
|
|
||||||
```
|
|
||||||
配置文件修改完成后,运行以下命令
|
|
||||||
```bash
|
|
||||||
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 手动运行
|
|
||||||
1. **创建Python环境**
|
|
||||||
推荐使用conda或其他虚拟环境进行依赖安装,防止出现依赖版本冲突问题
|
|
||||||
```bash
|
|
||||||
# 安装requirements
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
2. **MongoDB设置**
|
|
||||||
- 安装并运行mongodb
|
|
||||||
- 麦麦bot会自动连接默认的mongodb,端口和数据库名可配置
|
|
||||||
|
|
||||||
3. **Napcat配置**
|
|
||||||
- 安装并运行Napcat,登录
|
|
||||||
- 在Napcat的网络设置中添加ws反向代理:ws://localhost:8080/onebot/v11/ws
|
|
||||||
|
|
||||||
4. **配置文件设置**
|
|
||||||
- 将.env文件打开,填上你的apikey(硅基流动或deepseekapi)
|
|
||||||
- 将bot_config.toml文件打开,并填写相关内容,不然无法正常运行
|
|
||||||
|
|
||||||
#### .env 文件配置说明
|
|
||||||
```ini
|
|
||||||
# 环境配置
|
|
||||||
ENVIRONMENT=dev # 开发环境设置
|
|
||||||
HOST=127.0.0.1 # 主机地址
|
|
||||||
PORT=8080 # 端口号
|
|
||||||
|
|
||||||
# 命令前缀设置
|
|
||||||
COMMAND_START=["/"] # 命令起始符
|
|
||||||
|
|
||||||
# 插件配置
|
|
||||||
PLUGINS=["src2.plugins.chat"] # 启用的插件列表
|
|
||||||
|
|
||||||
# MongoDB配置
|
|
||||||
MONGODB_HOST=127.0.0.1 # MongoDB主机地址
|
|
||||||
MONGODB_PORT=27017 # MongoDB端口
|
|
||||||
DATABASE_NAME=MegBot # 数据库名称
|
|
||||||
MONGODB_USERNAME="" # MongoDB用户名(可选)
|
|
||||||
MONGODB_PASSWORD="" # MongoDB密码(可选)
|
|
||||||
MONGODB_AUTH_SOURCE="" # MongoDB认证源(可选)
|
|
||||||
|
|
||||||
#api配置项,建议siliconflow必填,识图需要这个
|
|
||||||
SILICONFLOW_KEY=
|
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
|
||||||
DEEP_SEEK_KEY=
|
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### bot_config.toml 文件配置说明
|
|
||||||
```toml
|
|
||||||
# 数据库设置
|
|
||||||
[database]
|
|
||||||
host = "127.0.0.1" # MongoDB主机地址
|
|
||||||
port = 27017 # MongoDB端口
|
|
||||||
name = "MegBot" # 数据库名称
|
|
||||||
|
|
||||||
# 机器人基本设置
|
|
||||||
[bot]
|
|
||||||
qq = # 你的机器人QQ号(必填)
|
|
||||||
nickname = "麦麦" # 机器人昵称
|
|
||||||
|
|
||||||
# 消息处理设置
|
|
||||||
[message]
|
|
||||||
min_text_length = 2 # 最小响应文本长度
|
|
||||||
max_context_size = 15 # 上下文最大保存数量
|
|
||||||
emoji_chance = 0.2 # 表情包使用概率
|
|
||||||
|
|
||||||
# 表情包功能设置
|
|
||||||
[emoji]
|
|
||||||
check_interval = 120 # 表情检查间隔(秒)
|
|
||||||
register_interval = 10 # 表情注册间隔(秒)
|
|
||||||
|
|
||||||
# CQ码设置
|
|
||||||
[cq_code]
|
|
||||||
enable_pic_translate = false # 是否启用图片转换(无效)
|
|
||||||
|
|
||||||
# 响应设置
|
|
||||||
[response]
|
|
||||||
api_using = "siliconflow" # 回复使用的API(siliconflow/deepseek)
|
|
||||||
model_r1_probability = 0.8 # R1模型使用概率
|
|
||||||
model_v3_probability = 0.1 # V3模型使用概率
|
|
||||||
model_r1_distill_probability = 0.1 # R1蒸馏模型使用概率(对deepseek api 无效)
|
|
||||||
|
|
||||||
# 其他设置
|
|
||||||
[others]
|
|
||||||
enable_advance_output = false # 是否启用详细日志输出
|
|
||||||
|
|
||||||
# 群组设置
|
|
||||||
[groups]
|
|
||||||
talk_allowed = [ # 允许回复的群号列表
|
|
||||||
# 在这里添加群号,逗号隔开
|
|
||||||
]
|
|
||||||
|
|
||||||
talk_frequency_down = [ # 降低回复频率的群号列表
|
|
||||||
# 在这里添加群号,逗号隔开
|
|
||||||
]
|
|
||||||
|
|
||||||
ban_user_id = [ # 禁止回复的用户QQ号列表
|
|
||||||
# 在这里添加QQ号,逗号隔开
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
5. **运行麦麦**
|
|
||||||
在含有bot.py程序的目录下运行(如果使用了虚拟环境需要先进入虚拟环境)
|
|
||||||
```bash
|
|
||||||
nb run
|
|
||||||
```
|
|
||||||
6. **运行其他组件**
|
|
||||||
run_thingking.bat 可以启动可视化的推理界面(未完善)和消息队列及其他信息预览(WIP)
|
|
||||||
knowledge.bat可以将/data/raw_info下的文本文档载入到数据库(未启动)
|
|
||||||
|
|
||||||
## 🎯 功能介绍
|
## 🎯 功能介绍
|
||||||
|
|
||||||
@@ -206,6 +73,19 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|||||||
- 幽默和meme功能:WIP的WIP
|
- 幽默和meme功能:WIP的WIP
|
||||||
- 让麦麦玩mc:WIP的WIP的WIP
|
- 让麦麦玩mc:WIP的WIP的WIP
|
||||||
|
|
||||||
|
## 开发计划TODO:LIST
|
||||||
|
|
||||||
|
- 兼容gif的解析和保存
|
||||||
|
- 小程序转发链接解析
|
||||||
|
- 对思考链长度限制
|
||||||
|
- 修复已知bug
|
||||||
|
- 完善文档
|
||||||
|
- 修复转发
|
||||||
|
- config自动生成和检测
|
||||||
|
- log别用print
|
||||||
|
- 给发送消息写专门的类
|
||||||
|
- 改进表情包发送逻辑
|
||||||
|
|
||||||
## 📌 注意事项
|
## 📌 注意事项
|
||||||
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
||||||
|
|
||||||
|
|||||||
61
bot.py
61
bot.py
@@ -4,28 +4,50 @@ from nonebot.adapters.onebot.v11 import Adapter
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# 加载全局环境变量
|
'''彩蛋'''
|
||||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
from colorama import init, Fore
|
||||||
env_path=os.path.join(root_dir, "config",'.env')
|
init()
|
||||||
|
text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||||
|
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
||||||
|
rainbow_text = ""
|
||||||
|
for i, char in enumerate(text):
|
||||||
|
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
|
||||||
|
print(rainbow_text)
|
||||||
|
'''彩蛋'''
|
||||||
|
|
||||||
logger.info(f"尝试从 {env_path} 加载环境变量配置")
|
# 首先加载基础环境变量
|
||||||
if os.path.exists(env_path):
|
if os.path.exists(".env"):
|
||||||
load_dotenv(env_path)
|
load_dotenv(".env")
|
||||||
logger.success("成功加载环境变量配置")
|
logger.success("成功加载基础环境变量配置")
|
||||||
else:
|
else:
|
||||||
logger.error(f"环境变量配置文件不存在: {env_path}")
|
logger.error("基础环境变量配置文件 .env 不存在")
|
||||||
|
exit(1)
|
||||||
|
# 根据 ENVIRONMENT 加载对应的环境配置
|
||||||
|
env = os.getenv("ENVIRONMENT")
|
||||||
|
env_file = f".env.{env}"
|
||||||
|
|
||||||
# 初始化 NoneBot
|
if env_file == ".env.dev" and os.path.exists(env_file):
|
||||||
nonebot.init(
|
logger.success("加载开发环境变量配置")
|
||||||
# napcat 默认使用 8080 端口
|
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
websocket_port=8080,
|
elif os.path.exists(".env.prod"):
|
||||||
# 设置日志级别
|
logger.success("加载环境变量配置")
|
||||||
log_level="INFO",
|
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
# 设置超级用户
|
else:
|
||||||
superusers={"你的QQ号"},
|
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||||
# TODO: 这样写会忽略环境变量需要优化 https://nonebot.dev/docs/appendices/config
|
exit(1)
|
||||||
_env_file=env_path
|
|
||||||
)
|
# 获取所有环境变量
|
||||||
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
|
|
||||||
|
# 设置基础配置
|
||||||
|
base_config = {
|
||||||
|
"websocket_port": int(env_config.get("PORT", 8080)),
|
||||||
|
"host": env_config.get("HOST", "127.0.0.1"),
|
||||||
|
"log_level": "INFO",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 合并配置
|
||||||
|
nonebot.init(**base_config, **env_config)
|
||||||
|
|
||||||
# 注册适配器
|
# 注册适配器
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
@@ -35,4 +57,5 @@ driver.register_adapter(Adapter)
|
|||||||
nonebot.load_plugins("src/plugins")
|
nonebot.load_plugins("src/plugins")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
nonebot.run()
|
nonebot.run()
|
||||||
12012
char_frequency.json
Normal file
12012
char_frequency.json
Normal file
File diff suppressed because it is too large
Load Diff
46
config/auto_format.py
Normal file
46
config/auto_format.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import tomli
|
||||||
|
import tomli_w
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
def sync_configs():
|
||||||
|
# 读取两个配置文件
|
||||||
|
try:
|
||||||
|
with open('bot_config_dev.toml', 'rb') as f: # tomli需要使用二进制模式读取
|
||||||
|
dev_config = tomli.load(f)
|
||||||
|
|
||||||
|
with open('bot_config.toml', 'rb') as f:
|
||||||
|
prod_config = tomli.load(f)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f"错误:找不到配置文件 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except tomli.TOMLDecodeError as e:
|
||||||
|
print(f"错误:TOML格式解析失败 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 递归合并配置
|
||||||
|
def merge_configs(source, target):
|
||||||
|
for key, value in source.items():
|
||||||
|
if key not in target:
|
||||||
|
target[key] = value
|
||||||
|
elif isinstance(value, dict) and isinstance(target[key], dict):
|
||||||
|
merge_configs(value, target[key])
|
||||||
|
|
||||||
|
# 将dev配置的新属性合并到prod配置中
|
||||||
|
merge_configs(dev_config, prod_config)
|
||||||
|
|
||||||
|
# 保存更新后的配置
|
||||||
|
try:
|
||||||
|
with open('bot_config.toml', 'wb') as f: # tomli_w需要使用二进制模式写入
|
||||||
|
tomli_w.dump(prod_config, f)
|
||||||
|
print("配置文件同步完成!")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误:保存配置文件失败 - {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 确保在正确的目录下运行
|
||||||
|
script_dir = Path(__file__).parent
|
||||||
|
os.chdir(script_dir)
|
||||||
|
sync_configs()
|
||||||
61
config/bot_config.toml
Normal file
61
config/bot_config.toml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
[bot]
|
||||||
|
qq = 123
|
||||||
|
nickname = "麦麦"
|
||||||
|
|
||||||
|
[message]
|
||||||
|
min_text_length = 2
|
||||||
|
max_context_size = 15
|
||||||
|
emoji_chance = 0.2
|
||||||
|
|
||||||
|
[emoji]
|
||||||
|
check_interval = 120
|
||||||
|
register_interval = 10
|
||||||
|
|
||||||
|
[cq_code]
|
||||||
|
enable_pic_translate = false
|
||||||
|
|
||||||
|
[response]
|
||||||
|
api_using = "siliconflow"
|
||||||
|
api_paid = true
|
||||||
|
model_r1_probability = 0.8
|
||||||
|
model_v3_probability = 0.1
|
||||||
|
model_r1_distill_probability = 0.1
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300
|
||||||
|
|
||||||
|
[others]
|
||||||
|
enable_advance_output = true
|
||||||
|
|
||||||
|
[groups]
|
||||||
|
talk_allowed = [
|
||||||
|
123,
|
||||||
|
123,
|
||||||
|
]
|
||||||
|
talk_frequency_down = []
|
||||||
|
ban_user_id = []
|
||||||
|
|
||||||
|
[model.llm_reasoning]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.vlm]
|
||||||
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
[bot]
|
|
||||||
qq = 123456 #填入你的机器人QQ
|
|
||||||
nickname = "麦麦" #你希望bot被称呼的名字
|
|
||||||
|
|
||||||
[message]
|
|
||||||
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
|
|
||||||
max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃
|
|
||||||
emoji_chance = 0.2 # 麦麦使用表情包的概率
|
|
||||||
|
|
||||||
[emoji]
|
|
||||||
check_interval = 120
|
|
||||||
register_interval = 10
|
|
||||||
|
|
||||||
[cq_code]
|
|
||||||
enable_pic_translate = false
|
|
||||||
|
|
||||||
|
|
||||||
[response]
|
|
||||||
api_using = "siliconflow" # 选择大模型API,可选值为siliconflow,deepseek,建议使用siliconflow,因为识图api目前只支持siliconflow的deepseek-vl2模型
|
|
||||||
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
|
||||||
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
|
||||||
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
|
||||||
|
|
||||||
[memory]
|
|
||||||
build_memory_interval = 300 # 记忆构建间隔
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[others]
|
|
||||||
enable_advance_output = true # 开启后输出更多日志,false关闭true开启
|
|
||||||
|
|
||||||
|
|
||||||
[groups]
|
|
||||||
|
|
||||||
talk_allowed = [
|
|
||||||
123456,12345678
|
|
||||||
] #可以回复消息的群
|
|
||||||
|
|
||||||
talk_frequency_down = [
|
|
||||||
123456,12345678
|
|
||||||
] #降低回复频率的群
|
|
||||||
|
|
||||||
ban_user_id = [
|
|
||||||
123456,12345678
|
|
||||||
] #禁止回复消息的QQ号
|
|
||||||
@@ -27,7 +27,7 @@ services:
|
|||||||
- mongodb:/data/db
|
- mongodb:/data/db
|
||||||
- mongodbCONFIG:/data/configdb
|
- mongodbCONFIG:/data/configdb
|
||||||
image: mongo:latest
|
image: mongo:latest
|
||||||
|
|
||||||
maimbot:
|
maimbot:
|
||||||
container_name: maimbot
|
container_name: maimbot
|
||||||
environment:
|
environment:
|
||||||
@@ -41,8 +41,8 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- maimbotCONFIG:/MaiMBot/config
|
- maimbotCONFIG:/MaiMBot/config
|
||||||
- maimbotDATA:/MaiMBot/data
|
- maimbotDATA:/MaiMBot/data
|
||||||
|
- ./.env.prod:/MaiMBot/.env.prod
|
||||||
image: sengokucola/maimbot:latest
|
image: sengokucola/maimbot:latest
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
maimbotCONFIG:
|
maimbotCONFIG:
|
||||||
@@ -51,4 +51,5 @@ volumes:
|
|||||||
napcatCONFIG:
|
napcatCONFIG:
|
||||||
mongodb:
|
mongodb:
|
||||||
mongodbCONFIG:
|
mongodbCONFIG:
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
102
docs/installation.md
Normal file
102
docs/installation.md
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# 🔧 安装与配置指南
|
||||||
|
|
||||||
|
## 部署方式
|
||||||
|
|
||||||
|
### 🐳 Docker部署(推荐)
|
||||||
|
|
||||||
|
1. 获取配置文件:
|
||||||
|
```bash
|
||||||
|
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 启动服务:
|
||||||
|
```bash
|
||||||
|
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 修改配置后重启:
|
||||||
|
```bash
|
||||||
|
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
||||||
|
```
|
||||||
|
|
||||||
|
### 📦 手动部署
|
||||||
|
|
||||||
|
1. **环境准备**
|
||||||
|
```bash
|
||||||
|
# 创建虚拟环境(推荐)
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # Linux
|
||||||
|
venv\\Scripts\\activate # Windows
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **配置MongoDB**
|
||||||
|
- 安装并启动MongoDB服务
|
||||||
|
- 默认连接本地27017端口
|
||||||
|
|
||||||
|
3. **配置NapCat**
|
||||||
|
- 安装并登录NapCat
|
||||||
|
- 添加反向WS:`ws://localhost:8080/onebot/v11/ws`
|
||||||
|
|
||||||
|
4. **配置文件设置**
|
||||||
|
- 复制并修改环境配置:`.env.prod`
|
||||||
|
- 复制并修改机器人配置:`bot_config.toml`
|
||||||
|
|
||||||
|
5. **启动服务**
|
||||||
|
```bash
|
||||||
|
nb run
|
||||||
|
```
|
||||||
|
|
||||||
|
6. **其他组件**
|
||||||
|
- `run_thingking.bat`: 启动可视化推理界面(未完善)和消息队列预览
|
||||||
|
- `knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库
|
||||||
|
|
||||||
|
## ⚙️ 配置说明
|
||||||
|
|
||||||
|
### 环境配置 (.env.prod)
|
||||||
|
```ini
|
||||||
|
# API配置(必填)
|
||||||
|
SILICONFLOW_KEY=your_key
|
||||||
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
|
DEEP_SEEK_KEY=your_key
|
||||||
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
# 服务配置
|
||||||
|
HOST=127.0.0.1
|
||||||
|
PORT=8080
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
MONGODB_HOST=127.0.0.1
|
||||||
|
MONGODB_PORT=27017
|
||||||
|
DATABASE_NAME=MegBot
|
||||||
|
```
|
||||||
|
|
||||||
|
### 机器人配置 (bot_config.toml)
|
||||||
|
```toml
|
||||||
|
[bot]
|
||||||
|
qq = "你的机器人QQ号"
|
||||||
|
nickname = "麦麦"
|
||||||
|
|
||||||
|
[message]
|
||||||
|
max_context_size = 15
|
||||||
|
emoji_chance = 0.2
|
||||||
|
|
||||||
|
[response]
|
||||||
|
api_using = "siliconflow" # 或 "deepseek"
|
||||||
|
|
||||||
|
[others]
|
||||||
|
enable_advance_output = false # 是否启用详细日志输出
|
||||||
|
|
||||||
|
[groups]
|
||||||
|
talk_allowed = [] # 允许回复的群号列表
|
||||||
|
talk_frequency_down = [] # 降低回复频率的群号列表
|
||||||
|
ban_user_id = [] # 禁止回复的用户QQ号列表
|
||||||
|
```
|
||||||
|
|
||||||
|
## ⚠️ 注意事项
|
||||||
|
|
||||||
|
- 目前部署方案仍在测试中,可能存在未知问题
|
||||||
|
- 配置文件中的API密钥请妥善保管,不要泄露
|
||||||
|
- 建议先在测试环境中运行,确认无误后再部署到生产环境
|
||||||
BIN
docs/qq.png
BIN
docs/qq.png
Binary file not shown.
|
Before Width: | Height: | Size: 191 KiB |
BIN
docs/video.png
Normal file
BIN
docs/video.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@@ -2,4 +2,5 @@ call conda activate niuniu
|
|||||||
cd .
|
cd .
|
||||||
|
|
||||||
REM 执行nb run命令
|
REM 执行nb run命令
|
||||||
nb run
|
nb run
|
||||||
|
pause
|
||||||
@@ -7,6 +7,23 @@ import threading
|
|||||||
import queue
|
import queue
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 获取当前文件的目录
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# 获取项目根目录
|
||||||
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
||||||
|
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
||||||
|
print("成功加载开发环境配置")
|
||||||
|
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
||||||
|
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
||||||
|
print("成功加载生产环境配置")
|
||||||
|
else:
|
||||||
|
print("未找到环境配置文件")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -14,14 +31,23 @@ from typing import Optional
|
|||||||
class Database:
|
class Database:
|
||||||
_instance: Optional["Database"] = None
|
_instance: Optional["Database"] = None
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, db_name: str):
|
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
|
||||||
self.client = MongoClient(host, port)
|
if username and password:
|
||||||
|
self.client = MongoClient(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
authSource=auth_source or 'admin'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.client = MongoClient(host, port)
|
||||||
self.db = self.client[db_name]
|
self.db = self.client[db_name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, host: str, port: int, db_name: str) -> "Database":
|
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls(host, port, db_name)
|
cls._instance = cls(host, port, db_name, username, password, auth_source)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
|
||||||
|
|
||||||
# 获取驱动器
|
# 获取驱动器
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source= config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ emoji_manager.initialize()
|
|||||||
|
|
||||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
||||||
# 创建机器人实例
|
# 创建机器人实例
|
||||||
chat_bot = ChatBot(global_config)
|
chat_bot = ChatBot()
|
||||||
# 注册消息处理器
|
# 注册消息处理器
|
||||||
group_msg = on_message()
|
group_msg = on_message()
|
||||||
# 创建定时任务
|
# 创建定时任务
|
||||||
@@ -50,6 +52,7 @@ async def start_background_tasks():
|
|||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
# 只启动表情包管理任务
|
# 只启动表情包管理任务
|
||||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||||
|
await bot_schedule.initialize()
|
||||||
bot_schedule.print_schedule()
|
bot_schedule.print_schedule()
|
||||||
|
|
||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
@@ -88,7 +91,7 @@ async def monitor_relationships():
|
|||||||
async def build_memory_task():
|
async def build_memory_task():
|
||||||
"""每30秒执行一次记忆构建"""
|
"""每30秒执行一次记忆构建"""
|
||||||
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
||||||
hippocampus.build_memory(chat_size=12)
|
await hippocampus.build_memory(chat_size=30)
|
||||||
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessa
|
|||||||
from .message import Message,MessageSet
|
from .message import Message,MessageSet
|
||||||
from .config import BotConfig, global_config
|
from .config import BotConfig, global_config
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .llm_generator import LLMResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message_stream import MessageStream, MessageStreamContainer
|
from .message_stream import MessageStream, MessageStreamContainer
|
||||||
from .topic_identifier import topic_identifier
|
from .topic_identifier import topic_identifier
|
||||||
from random import random, choice
|
from random import random, choice
|
||||||
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
|||||||
from ..memory_system.memory import memory_graph
|
from ..memory_system.memory import memory_graph
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self, config: BotConfig):
|
def __init__(self):
|
||||||
self.config = config
|
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
self.gpt = LLMResponseGenerator(config)
|
self.gpt = ResponseGenerator()
|
||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
@@ -39,11 +38,11 @@ class ChatBot:
|
|||||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||||
"""处理收到的群消息"""
|
"""处理收到的群消息"""
|
||||||
|
|
||||||
if event.group_id not in self.config.talk_allowed_groups:
|
if event.group_id not in global_config.talk_allowed_groups:
|
||||||
return
|
return
|
||||||
self.bot = bot # 更新 bot 实例
|
self.bot = bot # 更新 bot 实例
|
||||||
|
|
||||||
if event.user_id in self.config.ban_user_id:
|
if event.user_id in global_config.ban_user_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 打印原始消息内容
|
# 打印原始消息内容
|
||||||
@@ -120,7 +119,7 @@ class ChatBot:
|
|||||||
event.group_id,
|
event.group_id,
|
||||||
topic[0] if topic else None,
|
topic[0] if topic else None,
|
||||||
is_mentioned,
|
is_mentioned,
|
||||||
self.config,
|
global_config,
|
||||||
event.user_id,
|
event.user_id,
|
||||||
message.is_emoji,
|
message.is_emoji,
|
||||||
interested_rate
|
interested_rate
|
||||||
@@ -143,10 +142,14 @@ class ChatBot:
|
|||||||
response, emotion = await self.gpt.generate_response(message)
|
response, emotion = await self.gpt.generate_response(message)
|
||||||
|
|
||||||
# 如果生成了回复,发送并记录
|
# 如果生成了回复,发送并记录
|
||||||
|
|
||||||
|
'''
|
||||||
|
生成回复后的内容
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
|
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
for msg in response:
|
for msg in response:
|
||||||
print(f"当前消息: {msg}")
|
print(f"当前消息: {msg}")
|
||||||
@@ -157,7 +160,7 @@ class ChatBot:
|
|||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=self.config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
message_based_id=event.message_id,
|
message_based_id=event.message_id,
|
||||||
raw_message=msg,
|
raw_message=msg,
|
||||||
@@ -174,7 +177,7 @@ class ChatBot:
|
|||||||
|
|
||||||
|
|
||||||
bot_response_time = tinking_time_point
|
bot_response_time = tinking_time_point
|
||||||
if random() < self.config.emoji_chance:
|
if random() < global_config.emoji_chance:
|
||||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
||||||
if emoji_path:
|
if emoji_path:
|
||||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||||
@@ -186,7 +189,7 @@ class ChatBot:
|
|||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=self.config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=0,
|
message_id=0,
|
||||||
raw_message=emoji_cq,
|
raw_message=emoji_cq,
|
||||||
plain_text=emoji_cq,
|
plain_text=emoji_cq,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Any, Optional, Set
|
from typing import Dict, Any, Optional, Set
|
||||||
import os
|
import os
|
||||||
from nonebot.log import logger, default_format
|
from nonebot.log import logger, default_format
|
||||||
@@ -7,6 +7,7 @@ import configparser
|
|||||||
import tomli
|
import tomli
|
||||||
import sys
|
import sys
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -31,28 +32,36 @@ class BotConfig:
|
|||||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
API_USING: str = "siliconflow" # 使用的API
|
API_USING: str = "siliconflow" # 使用的API
|
||||||
|
API_PAID: bool = False # 是否使用付费API
|
||||||
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
||||||
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
|
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
|
||||||
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
|
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
|
||||||
|
|
||||||
enable_advance_output: bool = False # 是否启用高级输出
|
enable_advance_output: bool = False # 是否启用高级输出
|
||||||
|
enable_kuuki_read: bool = True # 是否启用读空气功能
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_config_path() -> str:
|
def get_config_dir() -> str:
|
||||||
"""获取默认配置文件路径"""
|
"""获取配置文件目录"""
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||||
config_dir = os.path.join(root_dir, 'config')
|
config_dir = os.path.join(root_dir, 'config')
|
||||||
return os.path.join(config_dir, 'bot_config.toml')
|
if not os.path.exists(config_dir):
|
||||||
|
os.makedirs(config_dir)
|
||||||
|
return config_dir
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, config_path: str = None) -> "BotConfig":
|
def load_config(cls, config_path: str = None) -> "BotConfig":
|
||||||
"""从TOML配置文件加载配置"""
|
"""从TOML配置文件加载配置"""
|
||||||
if config_path is None:
|
|
||||||
config_path = cls.get_default_config_path()
|
|
||||||
logger.info(f"使用默认配置文件路径: {config_path}")
|
|
||||||
|
|
||||||
config = cls()
|
config = cls()
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
with open(config_path, "rb") as f:
|
with open(config_path, "rb") as f:
|
||||||
@@ -80,6 +89,26 @@ class BotConfig:
|
|||||||
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
||||||
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
||||||
config.API_USING = response_config.get("api_using", config.API_USING)
|
config.API_USING = response_config.get("api_using", config.API_USING)
|
||||||
|
config.API_PAID = response_config.get("api_paid", config.API_PAID)
|
||||||
|
|
||||||
|
# 加载模型配置
|
||||||
|
if "model" in toml_dict:
|
||||||
|
model_config = toml_dict["model"]
|
||||||
|
|
||||||
|
if "llm_reasoning" in model_config:
|
||||||
|
config.llm_reasoning = model_config["llm_reasoning"]
|
||||||
|
|
||||||
|
if "llm_reasoning_minor" in model_config:
|
||||||
|
config.llm_reasoning_minor = model_config["llm_reasoning_minor"]
|
||||||
|
|
||||||
|
if "llm_normal" in model_config:
|
||||||
|
config.llm_normal = model_config["llm_normal"]
|
||||||
|
|
||||||
|
if "llm_normal_minor" in model_config:
|
||||||
|
config.llm_normal_minor = model_config["llm_normal_minor"]
|
||||||
|
|
||||||
|
if "vlm" in model_config:
|
||||||
|
config.vlm = model_config["vlm"]
|
||||||
|
|
||||||
# 消息配置
|
# 消息配置
|
||||||
if "message" in toml_dict:
|
if "message" in toml_dict:
|
||||||
@@ -108,13 +137,21 @@ class BotConfig:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
# 获取配置文件路径
|
# 获取配置文件路径
|
||||||
bot_config_path = BotConfig.get_default_config_path()
|
|
||||||
config_dir = os.path.dirname(bot_config_path)
|
|
||||||
env_path = os.path.join(config_dir, '.env')
|
|
||||||
|
|
||||||
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
bot_config_floder_path = BotConfig.get_config_dir()
|
||||||
|
print(f"正在品鉴配置文件目录: {bot_config_floder_path}")
|
||||||
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
||||||
|
if not os.path.exists(bot_config_path):
|
||||||
|
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||||
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||||
|
logger.info("使用默认配置文件")
|
||||||
|
else:
|
||||||
|
logger.info("已找到开发环境配置文件")
|
||||||
|
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMConfig:
|
class LLMConfig:
|
||||||
"""机器人配置类"""
|
"""机器人配置类"""
|
||||||
@@ -125,12 +162,14 @@ class LLMConfig:
|
|||||||
DEEP_SEEK_BASE_URL: str = None
|
DEEP_SEEK_BASE_URL: str = None
|
||||||
|
|
||||||
llm_config = LLMConfig()
|
llm_config = LLMConfig()
|
||||||
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
|
config = get_driver().config
|
||||||
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
|
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
||||||
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
|
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
||||||
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
|
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
||||||
|
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# logger.remove()
|
# logger.remove()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -7,15 +7,20 @@ from PIL import Image
|
|||||||
import os
|
import os
|
||||||
from random import random
|
from random import random
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
from .config import global_config, llm_config
|
from .config import global_config
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from .utils_image import storage_image,storage_emoji
|
from .utils_image import storage_image,storage_emoji
|
||||||
from .utils_user import get_user_nickname
|
from .utils_user import get_user_nickname
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
#解析各种CQ码
|
#解析各种CQ码
|
||||||
#包含CQ码类
|
#包含CQ码类
|
||||||
import urllib3
|
import urllib3
|
||||||
from urllib3.util import create_urllib3_context
|
from urllib3.util import create_urllib3_context
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
||||||
ctx = create_urllib3_context()
|
ctx = create_urllib3_context()
|
||||||
@@ -53,6 +58,11 @@ class CQCode:
|
|||||||
translated_plain_text: Optional[str] = None
|
translated_plain_text: Optional[str] = None
|
||||||
reply_message: Dict = None # 存储回复消息
|
reply_message: Dict = None # 存储回复消息
|
||||||
image_base64: Optional[str] = None
|
image_base64: Optional[str] = None
|
||||||
|
_llm: Optional[LLM_request] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""初始化LLM实例"""
|
||||||
|
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
||||||
|
|
||||||
def translate(self):
|
def translate(self):
|
||||||
"""根据CQ码类型进行相应的翻译处理"""
|
"""根据CQ码类型进行相应的翻译处理"""
|
||||||
@@ -157,7 +167,7 @@ class CQCode:
|
|||||||
# 将 base64 字符串转换为字节类型
|
# 将 base64 字符串转换为字节类型
|
||||||
image_bytes = base64.b64decode(base64_str)
|
image_bytes = base64.b64decode(base64_str)
|
||||||
storage_emoji(image_bytes)
|
storage_emoji(image_bytes)
|
||||||
return self.get_image_description(base64_str)
|
return self.get_emoji_description(base64_str)
|
||||||
else:
|
else:
|
||||||
return '[表情包]'
|
return '[表情包]'
|
||||||
|
|
||||||
@@ -177,93 +187,23 @@ class CQCode:
|
|||||||
|
|
||||||
def get_emoji_description(self, image_base64: str) -> str:
|
def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取表情包描述"""
|
"""调用AI接口获取表情包描述"""
|
||||||
headers = {
|
try:
|
||||||
"Content-Type": "application/json",
|
prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
}
|
return f"[表情包:{description}]"
|
||||||
|
except Exception as e:
|
||||||
payload = {
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
return "[表情包]"
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 50,
|
|
||||||
"temperature": 0.4
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
|
||||||
headers=headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result_json = response.json()
|
|
||||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
|
||||||
description = result_json["choices"][0]["message"]["content"]
|
|
||||||
return f"[表情包:{description}]"
|
|
||||||
|
|
||||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
|
||||||
|
|
||||||
def get_image_description(self, image_base64: str) -> str:
|
def get_image_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取普通图片描述"""
|
"""调用AI接口获取普通图片描述"""
|
||||||
headers = {
|
try:
|
||||||
"Content-Type": "application/json",
|
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
}
|
return f"[图片:{description}]"
|
||||||
|
except Exception as e:
|
||||||
payload = {
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
return "[图片]"
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 300,
|
|
||||||
"temperature": 0.6
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
|
||||||
headers=headers,
|
|
||||||
json=payload,
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result_json = response.json()
|
|
||||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
|
||||||
description = result_json["choices"][0]["message"]["content"]
|
|
||||||
return f"[图片:{description}]"
|
|
||||||
|
|
||||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
|
||||||
|
|
||||||
def translate_forward(self) -> str:
|
def translate_forward(self) -> str:
|
||||||
"""处理转发消息"""
|
"""处理转发消息"""
|
||||||
@@ -345,7 +285,7 @@ class CQCode:
|
|||||||
# 创建Message对象
|
# 创建Message对象
|
||||||
from .message import Message
|
from .message import Message
|
||||||
if self.reply_message == None:
|
if self.reply_message == None:
|
||||||
print(f"\033[1;31m[错误]\033[0m 回复消息为空")
|
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
|
||||||
return '[回复某人消息]'
|
return '[回复某人消息]'
|
||||||
|
|
||||||
if self.reply_message.sender.user_id:
|
if self.reply_message.sender.user_id:
|
||||||
|
|||||||
@@ -10,10 +10,16 @@ import hashlib
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import base64
|
import base64
|
||||||
import shutil
|
import shutil
|
||||||
from .config import global_config, llm_config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
|
from ..chat.config import global_config
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -39,6 +45,7 @@ class EmojiManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
|
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
@@ -83,55 +90,23 @@ class EmojiManager:
|
|||||||
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
|
||||||
|
|
||||||
async def _get_emotion_from_text(self, text: str) -> List[str]:
|
async def _get_emotion_from_text(self, text: str) -> List[str]:
|
||||||
"""从文本中识别情感关键词,使用DeepSeek API进行分析
|
"""从文本中识别情感关键词
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: 匹配到的情感标签列表
|
List[str]: 匹配到的情感标签列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 准备请求数据
|
prompt = f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。'
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
content, _ = await self.llm.generate_response(prompt)
|
||||||
"model": "deepseek-ai/DeepSeek-V3",
|
emotion = content.strip().lower()
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 50,
|
|
||||||
"temperature": 0.3
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
if emotion in self.EMOTION_KEYWORDS:
|
||||||
async with session.post(
|
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
return [emotion]
|
||||||
headers=headers,
|
|
||||||
json=payload
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m API请求失败: {await response.text()}")
|
|
||||||
return ['neutral']
|
|
||||||
|
|
||||||
result = json.loads(await response.text())
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
emotion = result["choices"][0]["message"]["content"].strip().lower()
|
|
||||||
# 确保返回的标签是有效的
|
|
||||||
if emotion in self.EMOTION_KEYWORDS:
|
|
||||||
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
|
|
||||||
return [emotion] # 返回单个情感标签的列表
|
|
||||||
|
|
||||||
return ['neutral'] # 如果无法识别情感,返回neutral
|
return ['neutral']
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
|
||||||
@@ -246,52 +221,20 @@ class EmojiManager:
|
|||||||
|
|
||||||
async def _get_emoji_tag(self, image_base64: str) -> str:
|
async def _get_emoji_tag(self, image_base64: str) -> str:
|
||||||
"""获取表情包的标签"""
|
"""获取表情包的标签"""
|
||||||
async with aiohttp.ClientSession() as session:
|
try:
|
||||||
headers = {
|
prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好'
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
|
||||||
"model": "deepseek-ai/deepseek-vl2",
|
tag_result = content.strip().lower()
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好'
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 60,
|
|
||||||
"temperature": 0.3
|
|
||||||
}
|
|
||||||
|
|
||||||
async with session.post(
|
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
|
||||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
for tag_match in valid_tags:
|
||||||
headers=headers,
|
if tag_match in tag_result or tag_match == tag_result:
|
||||||
json=payload
|
return tag_match
|
||||||
) as response:
|
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过")
|
||||||
if response.status == 200:
|
|
||||||
result = await response.json()
|
except Exception as e:
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
|
||||||
tag_result = result["choices"][0]["message"]["content"].strip().lower()
|
|
||||||
|
|
||||||
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
|
|
||||||
for tag_match in valid_tags:
|
|
||||||
if tag_match in tag_result or tag_match == tag_result:
|
|
||||||
return tag_match
|
|
||||||
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_match}, 跳过")
|
|
||||||
else:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m 获取标签失败, 状态码: {response.status}")
|
|
||||||
|
|
||||||
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
||||||
return "skip" # 默认标签
|
return "skip" # 默认标签
|
||||||
|
|||||||
@@ -1,204 +1,125 @@
|
|||||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .config import BotConfig, global_config
|
from .config import global_config
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import os
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from ..schedule.schedule_generator import bot_schedule
|
|
||||||
from .prompt_builder import prompt_builder
|
from .prompt_builder import prompt_builder
|
||||||
from .config import llm_config, global_config
|
from .config import global_config
|
||||||
from .utils import process_llm_response
|
from .utils import process_llm_response
|
||||||
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class LLMResponseGenerator:
|
class ResponseGenerator:
|
||||||
def __init__(self, config: BotConfig):
|
def __init__(self):
|
||||||
self.config = config
|
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
||||||
if self.config.API_USING == "siliconflow":
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
||||||
self.client = OpenAI(
|
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
|
||||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
|
||||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
|
||||||
)
|
|
||||||
elif self.config.API_USING == "deepseek":
|
|
||||||
self.client = OpenAI(
|
|
||||||
api_key=llm_config.DEEP_SEEK_API_KEY,
|
|
||||||
base_url=llm_config.DEEP_SEEK_BASE_URL
|
|
||||||
)
|
|
||||||
|
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
# 当前使用的模型类型
|
|
||||||
self.current_model_type = 'r1' # 默认使用 R1
|
self.current_model_type = 'r1' # 默认使用 R1
|
||||||
|
|
||||||
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值
|
# 从global_config中获取模型概率值并选择模型
|
||||||
model_r1_probability = global_config.MODEL_R1_PROBABILITY
|
|
||||||
model_v3_probability = global_config.MODEL_V3_PROBABILITY
|
|
||||||
model_r1_distill_probability = global_config.MODEL_R1_DISTILL_PROBABILITY
|
|
||||||
|
|
||||||
# 生成随机数并根据概率选择模型
|
|
||||||
rand = random.random()
|
rand = random.random()
|
||||||
if rand < model_r1_probability:
|
if rand < global_config.MODEL_R1_PROBABILITY:
|
||||||
self.current_model_type = 'r1'
|
self.current_model_type = 'r1'
|
||||||
elif rand < model_r1_probability + model_v3_probability:
|
current_model = self.model_r1
|
||||||
|
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
|
||||||
self.current_model_type = 'v3'
|
self.current_model_type = 'v3'
|
||||||
|
current_model = self.model_v3
|
||||||
else:
|
else:
|
||||||
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
|
self.current_model_type = 'r1_distill'
|
||||||
|
current_model = self.model_r1_distill
|
||||||
|
|
||||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
||||||
if self.current_model_type == 'r1':
|
|
||||||
model_response = await self._generate_r1_response(message)
|
|
||||||
elif self.current_model_type == 'v3':
|
|
||||||
model_response = await self._generate_v3_response(message)
|
|
||||||
else:
|
|
||||||
model_response = await self._generate_r1_distill_response(message)
|
|
||||||
|
|
||||||
# 打印情感标签
|
model_response = await self._generate_response_with_model(message, current_model)
|
||||||
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
|
||||||
model_response, emotion = await self._process_response(model_response)
|
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
||||||
|
model_response, emotion = await self._process_response(model_response)
|
||||||
return model_response, emotion
|
if model_response:
|
||||||
|
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
||||||
|
return model_response, emotion
|
||||||
|
return None, []
|
||||||
|
|
||||||
async def _generate_base_response(
|
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
||||||
self,
|
"""使用指定的模型生成回复"""
|
||||||
message: Message,
|
|
||||||
model_name: str,
|
|
||||||
model_params: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||||
|
|
||||||
# 获取关系值
|
# 获取关系值
|
||||||
if relationship_manager.get_relationship(message.user_id):
|
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
||||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
|
if relationship_value != 0.0:
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||||
else:
|
|
||||||
relationship_value = 0.0
|
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
prompt = prompt_builder._build_prompt(
|
prompt, prompt_check = prompt_builder._build_prompt(
|
||||||
message_txt=message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
relationship_value=relationship_value,
|
relationship_value=relationship_value,
|
||||||
group_id=message.group_id
|
group_id=message.group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 读空气模块
|
||||||
|
if global_config.enable_kuuki_read:
|
||||||
|
content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
|
||||||
|
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
||||||
|
if 'yes' not in content_check.lower() and random.random() < 0.3:
|
||||||
|
self._save_to_db(
|
||||||
|
message=message,
|
||||||
|
sender_name=sender_name,
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_check=prompt_check,
|
||||||
|
content="",
|
||||||
|
content_check=content_check,
|
||||||
|
reasoning_content="",
|
||||||
|
reasoning_content_check=reasoning_content_check
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 生成回复
|
||||||
|
content, reasoning_content = await model.generate_response(prompt)
|
||||||
|
|
||||||
# 设置默认参数
|
|
||||||
default_params = {
|
|
||||||
"model": model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"stream": False,
|
|
||||||
"max_tokens": 1024,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新参数
|
|
||||||
if model_params:
|
|
||||||
default_params.update(model_params)
|
|
||||||
|
|
||||||
def create_completion():
|
|
||||||
return self.client.chat.completions.create(**default_params)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
response = await loop.run_in_executor(None, create_completion)
|
|
||||||
|
|
||||||
# 检查响应内容
|
|
||||||
if not response:
|
|
||||||
print("请求未返回任何内容")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not response.choices or not response.choices[0].message.content:
|
|
||||||
print("请求返回的内容无效:", response)
|
|
||||||
return None
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
|
|
||||||
# 获取推理内容
|
|
||||||
reasoning_content = ""
|
|
||||||
if hasattr(response.choices[0].message, "reasoning"):
|
|
||||||
reasoning_content = response.choices[0].message.reasoning or reasoning_content
|
|
||||||
elif hasattr(response.choices[0].message, "reasoning_content"):
|
|
||||||
reasoning_content = response.choices[0].message.reasoning_content or reasoning_content
|
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
|
self._save_to_db(
|
||||||
|
message=message,
|
||||||
|
sender_name=sender_name,
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_check=prompt_check,
|
||||||
|
content=content,
|
||||||
|
content_check=content_check if global_config.enable_kuuki_read else "",
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||||
|
content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||||
|
"""保存对话记录到数据库"""
|
||||||
self.db.db.reasoning_logs.insert_one({
|
self.db.db.reasoning_logs.insert_one({
|
||||||
'time': time.time(),
|
'time': time.time(),
|
||||||
'group_id': message.group_id,
|
'group_id': message.group_id,
|
||||||
'user': sender_name,
|
'user': sender_name,
|
||||||
'message': message.processed_plain_text,
|
'message': message.processed_plain_text,
|
||||||
'model': model_name,
|
'model': self.current_model_type,
|
||||||
|
'reasoning_check': reasoning_content_check,
|
||||||
|
'response_check': content_check,
|
||||||
'reasoning': reasoning_content,
|
'reasoning': reasoning_content,
|
||||||
'response': content,
|
'response': content,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'model_params': default_params
|
'prompt_check': prompt_check
|
||||||
})
|
})
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _generate_r1_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-R1 模型生成回复"""
|
|
||||||
if self.config.API_USING == "deepseek":
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-reasoner",
|
|
||||||
{"temperature": 0.7, "max_tokens": 1024}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"Pro/deepseek-ai/DeepSeek-R1",
|
|
||||||
{"temperature": 0.7, "max_tokens": 1024}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _generate_v3_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-V3 模型生成回复"""
|
|
||||||
if self.config.API_USING == "deepseek":
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-chat",
|
|
||||||
{"temperature": 0.8, "max_tokens": 1024}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"Pro/deepseek-ai/DeepSeek-V3",
|
|
||||||
{"temperature": 0.8, "max_tokens": 1024}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
|
|
||||||
"""使用 DeepSeek-R1-Distill-Qwen-32B 模型生成回复"""
|
|
||||||
return await self._generate_base_response(
|
|
||||||
message,
|
|
||||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
|
||||||
{"temperature": 0.7, "max_tokens": 1024}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_group_chat_context(self, message: Message) -> str:
|
|
||||||
"""获取群聊上下文"""
|
|
||||||
recent_messages = self.db.db.messages.find(
|
|
||||||
{"group_id": message.group_id}
|
|
||||||
).sort("time", -1).limit(15)
|
|
||||||
|
|
||||||
messages_list = list(recent_messages)[::-1]
|
|
||||||
group_chat = ""
|
|
||||||
|
|
||||||
for msg_dict in messages_list:
|
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(msg_dict['time']))
|
|
||||||
display_name = msg_dict.get('user_nickname', f"用户{msg_dict['user_id']}")
|
|
||||||
content = msg_dict.get('processed_plain_text', msg_dict['plain_text'])
|
|
||||||
|
|
||||||
group_chat += f"[{time_str}] {display_name}: {content}\n"
|
|
||||||
|
|
||||||
return group_chat
|
|
||||||
|
|
||||||
async def _get_emotion_tags(self, content: str) -> List[str]:
|
async def _get_emotion_tags(self, content: str) -> List[str]:
|
||||||
"""提取情感标签"""
|
"""提取情感标签"""
|
||||||
@@ -209,33 +130,12 @@ class LLMResponseGenerator:
|
|||||||
输出:
|
输出:
|
||||||
'''
|
'''
|
||||||
|
|
||||||
messages = [{"role": "user", "content": prompt}]
|
content, _ = await self.model_v3.generate_response(prompt)
|
||||||
|
return [content.strip()] if content else ["neutral"]
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if self.config.API_USING == "deepseek":
|
|
||||||
model = "deepseek-chat"
|
|
||||||
else:
|
|
||||||
model = "Pro/deepseek-ai/DeepSeek-V3"
|
|
||||||
create_completion = partial(
|
|
||||||
self.client.chat.completions.create,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=30,
|
|
||||||
temperature=0.6
|
|
||||||
)
|
|
||||||
response = await loop.run_in_executor(None, create_completion)
|
|
||||||
|
|
||||||
if response.choices[0].message.content:
|
|
||||||
# 确保返回的是列表格式
|
|
||||||
emotion_tag = response.choices[0].message.content.strip()
|
|
||||||
return [emotion_tag] # 将单个标签包装成列表返回
|
|
||||||
|
|
||||||
return ["neutral"] # 如果无法获取情感标签,返回默认值
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取情感标签时出错: {e}")
|
print(f"获取情感标签时出错: {e}")
|
||||||
return ["neutral"] # 发生错误时返回默认值
|
return ["neutral"]
|
||||||
|
|
||||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||||
@@ -243,10 +143,6 @@ class LLMResponseGenerator:
|
|||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
emotion_tags = await self._get_emotion_tags(content)
|
emotion_tags = await self._get_emotion_tags(content)
|
||||||
|
|
||||||
processed_response = process_llm_response(content)
|
processed_response = process_llm_response(content)
|
||||||
|
|
||||||
return processed_response, emotion_tags
|
return processed_response, emotion_tags
|
||||||
|
|
||||||
# 创建全局实例
|
|
||||||
llm_response = LLMResponseGenerator(global_config)
|
|
||||||
@@ -66,15 +66,21 @@ class PromptBuilder:
|
|||||||
overlapping_second_layer.update(overlap)
|
overlapping_second_layer.update(overlap)
|
||||||
|
|
||||||
# 合并所有需要的记忆
|
# 合并所有需要的记忆
|
||||||
if all_first_layer_items:
|
# if all_first_layer_items:
|
||||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||||
if overlapping_second_layer:
|
# if overlapping_second_layer:
|
||||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||||
|
|
||||||
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
# 使用集合去重
|
||||||
|
# 从每个来源随机选择2条记忆(如果有的话)
|
||||||
|
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
|
||||||
|
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
|
||||||
|
|
||||||
if all_memories: # 只在列表非空时选择随机项
|
# 合并并去重
|
||||||
random_item = choice(all_memories)
|
all_memories = list(set(selected_first_layer + selected_second_layer))
|
||||||
|
if all_memories:
|
||||||
|
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
|
||||||
|
random_item = " ".join(all_memories)
|
||||||
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
|
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
|
||||||
else:
|
else:
|
||||||
memory_prompt = "" # 如果没有记忆,则返回空字符串
|
memory_prompt = "" # 如果没有记忆,则返回空字符串
|
||||||
@@ -112,7 +118,7 @@ class PromptBuilder:
|
|||||||
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
||||||
promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
||||||
@@ -144,9 +150,9 @@ class PromptBuilder:
|
|||||||
prompt_personality = ''
|
prompt_personality = ''
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
if personality_choice < 4/6: # 第一种人格
|
if personality_choice < 4/6: # 第一种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
||||||
请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。'''
|
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
||||||
elif personality_choice < 1: # 第二种人格
|
elif personality_choice < 1: # 第二种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
|
||||||
|
|
||||||
@@ -179,9 +185,23 @@ class PromptBuilder:
|
|||||||
# prompt += f"{activate_prompt}\n"
|
# prompt += f"{activate_prompt}\n"
|
||||||
prompt += f"{prompt_personality}\n"
|
prompt += f"{prompt_personality}\n"
|
||||||
prompt += f"{prompt_ger}\n"
|
prompt += f"{prompt_ger}\n"
|
||||||
prompt += f"{extra_info}\n"
|
prompt += f"{extra_info}\n"
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
'''读空气prompt处理'''
|
||||||
|
|
||||||
|
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
|
||||||
|
prompt_personality_check = ''
|
||||||
|
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
||||||
|
if personality_choice < 4/6: # 第一种人格
|
||||||
|
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
||||||
|
elif personality_choice < 1: # 第二种人格
|
||||||
|
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
||||||
|
|
||||||
|
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
|
||||||
|
|
||||||
|
return prompt,prompt_check_if_response
|
||||||
|
|
||||||
def get_prompt_info(self,message:str,threshold:float):
|
def get_prompt_info(self,message:str,threshold:float):
|
||||||
related_info = ''
|
related_info = ''
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict, List
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .config import global_config, llm_config
|
|
||||||
import jieba
|
import jieba
|
||||||
|
from nonebot import get_driver
|
||||||
|
from .config import global_config
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
api_key=config.siliconflow_key,
|
||||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
base_url=config.siliconflow_base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
def identify_topic_llm(self, text: str) -> Optional[str]:
|
||||||
@@ -21,7 +25,7 @@ class TopicIdentifier:
|
|||||||
消息内容:{text}"""
|
消息内容:{text}"""
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model="Pro/deepseek-ai/DeepSeek-V3",
|
model=global_config.SILICONFLOW_MODEL_V3,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
max_tokens=10
|
max_tokens=10
|
||||||
|
|||||||
@@ -4,11 +4,15 @@ from typing import List
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
import requests
|
import requests
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .config import llm_config, global_config
|
from .config import global_config
|
||||||
import re
|
import re
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import math
|
import math
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
def combine_messages(messages: List[Message]) -> str:
|
def combine_messages(messages: List[Message]) -> str:
|
||||||
@@ -64,7 +68,7 @@ def get_embedding(text):
|
|||||||
"encoding_format": "float"
|
"encoding_format": "float"
|
||||||
}
|
}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
"Authorization": f"Bearer {config.siliconflow_key}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,6 +185,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
|
|||||||
message_detailed_plain_text = ''
|
message_detailed_plain_text = ''
|
||||||
message_detailed_plain_text_list = []
|
message_detailed_plain_text_list = []
|
||||||
|
|
||||||
|
# 反转消息列表,使最新的消息在最后
|
||||||
|
recent_messages.reverse()
|
||||||
|
|
||||||
if combine:
|
if combine:
|
||||||
for msg_db_data in recent_messages:
|
for msg_db_data in recent_messages:
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ from ...common.database import Database
|
|||||||
import zlib # 用于 CRC32
|
import zlib # 用于 CRC32
|
||||||
import base64
|
import base64
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
||||||
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|||||||
|
|
||||||
# 连接数据库
|
# 连接数据库
|
||||||
db = Database(
|
db = Database(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否已存在相同哈希值的图片
|
# 检查是否已存在相同哈希值的图片
|
||||||
|
|||||||
@@ -58,8 +58,8 @@ class WillingManager:
|
|||||||
if group_id in config.talk_frequency_down_groups:
|
if group_id in config.talk_frequency_down_groups:
|
||||||
reply_probability = reply_probability / 3.5
|
reply_probability = reply_probability / 3.5
|
||||||
|
|
||||||
if is_mentioned_bot and user_id == int(1026294844):
|
# if is_mentioned_bot and user_id == int(1026294844):
|
||||||
reply_probability = 1
|
# reply_probability = 1
|
||||||
|
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,10 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
|
|||||||
|
|
||||||
# 直接配置数据库连接信息
|
# 直接配置数据库连接信息
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
class KnowledgeLibrary:
|
||||||
|
|||||||
@@ -168,10 +168,12 @@ def main():
|
|||||||
memory_graph.load_graph_from_db()
|
memory_graph.load_graph_from_db()
|
||||||
# 展示两种不同的可视化方式
|
# 展示两种不同的可视化方式
|
||||||
print("\n按连接数量着色的图谱:")
|
print("\n按连接数量着色的图谱:")
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
# visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
visualize_graph_lite(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
print("\n按记忆数量着色的图谱:")
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
# visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
visualize_graph_lite(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
# memory_graph.save_graph_to_db()
|
# memory_graph.save_graph_to_db()
|
||||||
|
|
||||||
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
|
# 设置中文字体
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 创建一个新图用于可视化
|
||||||
|
H = G.copy()
|
||||||
|
|
||||||
|
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||||
|
nodes_to_remove = []
|
||||||
|
for node in H.nodes():
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
degree = H.degree(node)
|
||||||
|
if memory_count <= 2 or degree <= 2:
|
||||||
|
nodes_to_remove.append(node)
|
||||||
|
|
||||||
|
H.remove_nodes_from(nodes_to_remove)
|
||||||
|
|
||||||
|
# 如果过滤后没有节点,则返回
|
||||||
|
if len(H.nodes()) == 0:
|
||||||
|
print("过滤后没有符合条件的节点可显示")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 保存图到本地
|
||||||
|
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
|
node_colors = []
|
||||||
|
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||||
|
|
||||||
|
if color_by_memory:
|
||||||
|
# 计算每个节点的记忆数量
|
||||||
|
memory_counts = []
|
||||||
|
for node in nodes:
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
count = len(memory_items)
|
||||||
|
else:
|
||||||
|
count = 1 if memory_items else 0
|
||||||
|
memory_counts.append(count)
|
||||||
|
max_memories = max(memory_counts) if memory_counts else 1
|
||||||
|
|
||||||
|
for count in memory_counts:
|
||||||
|
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
||||||
|
if max_memories > 0:
|
||||||
|
intensity = min(1.0, count / max_memories)
|
||||||
|
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
||||||
|
node_colors.append(color)
|
||||||
|
else:
|
||||||
|
# 使用原来的连接数量着色方案
|
||||||
|
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||||
|
for node in nodes:
|
||||||
|
degree = H.degree(node)
|
||||||
|
if max_degree > 0:
|
||||||
|
red = min(1.0, degree / max_degree)
|
||||||
|
blue = 1.0 - red
|
||||||
|
color = (red, 0, blue)
|
||||||
|
else:
|
||||||
|
color = (0, 0, 1)
|
||||||
|
node_colors.append(color)
|
||||||
|
|
||||||
|
# 绘制图形
|
||||||
|
plt.figure(figsize=(12, 8))
|
||||||
|
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||||
|
nx.draw(H, pos,
|
||||||
|
with_labels=True,
|
||||||
|
node_color=node_colors,
|
||||||
|
node_size=2000,
|
||||||
|
font_size=10,
|
||||||
|
font_family='SimHei',
|
||||||
|
font_weight='bold')
|
||||||
|
|
||||||
|
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
||||||
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
from typing import Tuple, Union
|
|
||||||
import time
|
|
||||||
|
|
||||||
class LLMModel:
|
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
|
||||||
self.model_name = model_name
|
|
||||||
self.params = kwargs
|
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
|
||||||
"""根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.5,
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
max_retries = 3
|
|
||||||
base_wait_time = 15 # 基础等待时间(秒)
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
return f"请求失败: {str(e)}", ""
|
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
|
||||||
@@ -2,22 +2,28 @@ import os
|
|||||||
import requests
|
import requests
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import time
|
import time
|
||||||
from ..chat.config import BotConfig
|
from nonebot import get_driver
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from src.plugins.chat.config import BotConfig, global_config
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
class LLMModel:
|
class LLMModel:
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
|
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
self.api_key = config.siliconflow_key
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
self.base_url = config.siliconflow_base_url
|
||||||
|
|
||||||
if not self.api_key or not self.base_url:
|
if not self.api_key or not self.base_url:
|
||||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||||
|
|
||||||
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
|
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
"""根据输入的提示生成模型的响应"""
|
"""根据输入的提示生成模型的响应"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
@@ -40,28 +46,28 @@ class LLMModel:
|
|||||||
|
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
if response.status_code == 429:
|
except Exception as e:
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
time.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
import os
|
||||||
import jieba
|
import jieba
|
||||||
from .llm_module import LLMModel
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
import sys
|
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -169,8 +166,8 @@ class Memory_graph:
|
|||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self,memory_graph:Memory_graph):
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
self.memory_graph = memory_graph
|
self.memory_graph = memory_graph
|
||||||
self.llm_model = LLMModel()
|
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
|
||||||
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
|
||||||
|
|
||||||
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
@@ -193,13 +190,29 @@ class Hippocampus:
|
|||||||
chat_text.append(chat_)
|
chat_text.append(chat_)
|
||||||
return chat_text
|
return chat_text
|
||||||
|
|
||||||
def build_memory(self,chat_size=12):
|
async def memory_compress(self, input_text, rate=1):
|
||||||
|
information_content = calculate_information_content(input_text)
|
||||||
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
|
topic_response = await self.llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
compressed_memory = set()
|
||||||
|
for topic in topics:
|
||||||
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
|
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
|
||||||
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
|
return compressed_memory
|
||||||
|
|
||||||
|
async def build_memory(self,chat_size=12):
|
||||||
#最近消息获取频率
|
#最近消息获取频率
|
||||||
time_frequency = {'near':1,'mid':2,'far':2}
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
||||||
|
|
||||||
|
|
||||||
for i, input_text in enumerate(memory_sample, 1):
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
#加载进度可视化
|
#加载进度可视化
|
||||||
progress = (i / len(memory_sample)) * 100
|
progress = (i / len(memory_sample)) * 100
|
||||||
@@ -207,44 +220,23 @@ class Hippocampus:
|
|||||||
filled_length = int(bar_length * i // len(memory_sample))
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
|
if input_text:
|
||||||
# 生成压缩后记忆
|
# 生成压缩后记忆
|
||||||
first_memory = set()
|
first_memory = set()
|
||||||
first_memory = self.memory_compress(input_text, 2.5)
|
first_memory = await self.memory_compress(input_text, 2.5)
|
||||||
# 延时防止访问超频
|
#将记忆加入到图谱中
|
||||||
# time.sleep(5)
|
for topic, memory in first_memory:
|
||||||
#将记忆加入到图谱中
|
topics = segment_text(topic)
|
||||||
for topic, memory in first_memory:
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
topics = segment_text(topic)
|
for split_topic in topics:
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
self.memory_graph.add_dot(split_topic,memory)
|
||||||
for split_topic in topics:
|
for split_topic in topics:
|
||||||
self.memory_graph.add_dot(split_topic,memory)
|
for other_split_topic in topics:
|
||||||
for split_topic in topics:
|
if split_topic != other_split_topic:
|
||||||
for other_split_topic in topics:
|
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
if split_topic != other_split_topic:
|
else:
|
||||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
print(f"空消息 跳过")
|
||||||
|
self.memory_graph.save_graph_to_db()
|
||||||
self.memory_graph.save_graph_to_db()
|
|
||||||
|
|
||||||
def memory_compress(self, input_text, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
# print(topic_num)
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = self.llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
# print(topics)
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
@@ -260,16 +252,19 @@ def topic_what(text, topic):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
#创建记忆图
|
#创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
|
|||||||
@@ -13,7 +13,38 @@ import os
|
|||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
from src.plugins.memory_system.llm_module import LLMModel
|
from src.plugins.memory_system.llm_module import LLMModel
|
||||||
|
|
||||||
|
def calculate_information_content(text):
|
||||||
|
"""计算文本的信息量(熵)"""
|
||||||
|
# 统计字符频率
|
||||||
|
char_count = Counter(text)
|
||||||
|
total_chars = len(text)
|
||||||
|
|
||||||
|
# 计算熵
|
||||||
|
entropy = 0
|
||||||
|
for count in char_count.values():
|
||||||
|
probability = count / total_chars
|
||||||
|
entropy -= probability * math.log2(probability)
|
||||||
|
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
|
"""从数据库中获取最接近指定时间戳的聊天记录"""
|
||||||
|
chat_text = ''
|
||||||
|
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||||
|
|
||||||
|
if closest_record:
|
||||||
|
closest_time = closest_record['time']
|
||||||
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
|
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
|
for record in chat_record:
|
||||||
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
|
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
|
||||||
|
return chat_text
|
||||||
|
|
||||||
|
return ''
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -102,7 +133,8 @@ class Memory_graph:
|
|||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
|
||||||
|
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||||
|
|
||||||
if closest_record:
|
if closest_record:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
@@ -110,8 +142,9 @@ class Memory_graph:
|
|||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
for record in chat_record:
|
for record in chat_record:
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
if record:
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
|
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||||
return chat_text
|
return chat_text
|
||||||
|
|
||||||
return [] # 如果没有找到记录,返回空列表
|
return [] # 如果没有找到记录,返回空列表
|
||||||
@@ -187,155 +220,80 @@ class Memory_graph:
|
|||||||
for edge in edges:
|
for edge in edges:
|
||||||
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
||||||
|
|
||||||
def calculate_information_content(text):
|
# 海马体
|
||||||
|
class Hippocampus:
|
||||||
"""计算文本的信息量(熵)"""
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
# 统计字符频率
|
self.memory_graph = memory_graph
|
||||||
char_count = Counter(text)
|
self.llm_model = LLMModel()
|
||||||
total_chars = len(text)
|
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
|
||||||
# 计算熵
|
|
||||||
entropy = 0
|
|
||||||
for count in char_count.values():
|
|
||||||
probability = count / total_chars
|
|
||||||
entropy -= probability * math.log2(probability)
|
|
||||||
|
|
||||||
return entropy
|
|
||||||
|
|
||||||
|
|
||||||
# Database.initialize(
|
|
||||||
# global_config.MONGODB_HOST,
|
|
||||||
# global_config.MONGODB_PORT,
|
|
||||||
# global_config.DATABASE_NAME
|
|
||||||
# )
|
|
||||||
# memory_graph = Memory_graph()
|
|
||||||
|
|
||||||
# llm_model = LLMModel()
|
|
||||||
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
# memory_graph.load_graph_from_db()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 初始化数据库
|
|
||||||
Database.initialize(
|
|
||||||
host= os.getenv("MONGODB_HOST"),
|
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
|
||||||
)
|
|
||||||
|
|
||||||
memory_graph = Memory_graph()
|
|
||||||
# 创建LLM模型实例
|
|
||||||
llm_model = LLMModel()
|
|
||||||
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
# 使用当前时间戳进行测试
|
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
|
||||||
chat_text = []
|
|
||||||
|
|
||||||
chat_size =25
|
|
||||||
|
|
||||||
for _ in range(30): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
|
|
||||||
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
|
|
||||||
chat_text.append(chat_) # 拼接所有text
|
|
||||||
# time.sleep(1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for i, input_text in enumerate(chat_text, 1):
|
|
||||||
|
|
||||||
progress = (i / len(chat_text)) * 100
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
bar_length = 30
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
filled_length = int(bar_length * i // len(chat_text))
|
chat_text = []
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
#短期:1h 中期:4h 长期:24h
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
|
for _ in range(time_frequency.get('near')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('far')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
return chat_text
|
||||||
|
|
||||||
|
def build_memory(self,chat_size=12):
|
||||||
|
#最近消息获取频率
|
||||||
|
time_frequency = {'near':1,'mid':2,'far':2}
|
||||||
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
|
|
||||||
# print(input_text)
|
#加载进度可视化
|
||||||
first_memory = set()
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
|
progress = (i / len(memory_sample)) * 100
|
||||||
# time.sleep(5)
|
bar_length = 30
|
||||||
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
#将记忆加入到图谱中
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
for topic, memory in first_memory:
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
topics = segment_text(topic)
|
# print(f"第{i}条消息: {input_text}")
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
if input_text:
|
||||||
for split_topic in topics:
|
# 生成压缩后记忆
|
||||||
memory_graph.add_dot(split_topic,memory)
|
first_memory = set()
|
||||||
for split_topic in topics:
|
first_memory = self.memory_compress(input_text, 2.5)
|
||||||
for other_split_topic in topics:
|
#将记忆加入到图谱中
|
||||||
if split_topic != other_split_topic:
|
for topic, memory in first_memory:
|
||||||
memory_graph.connect_dot(split_topic, other_split_topic)
|
topics = segment_text(topic)
|
||||||
|
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
||||||
# memory_graph.store_memory()
|
for split_topic in topics:
|
||||||
|
self.memory_graph.add_dot(split_topic,memory)
|
||||||
# 展示两种不同的可视化方式
|
for split_topic in topics:
|
||||||
print("\n按连接数量着色的图谱:")
|
for other_split_topic in topics:
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
if split_topic != other_split_topic:
|
||||||
|
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
||||||
print("\n按记忆数量着色的图谱:")
|
else:
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
print(f"空消息 跳过")
|
||||||
|
|
||||||
memory_graph.save_graph_to_db()
|
|
||||||
# memory_graph.load_graph_from_db()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
items_list = memory_graph.get_related_item(query)
|
|
||||||
if items_list:
|
|
||||||
# print(items_list)
|
|
||||||
for memory_item in items_list:
|
|
||||||
print(memory_item)
|
|
||||||
else:
|
|
||||||
print("未找到相关记忆。")
|
|
||||||
|
|
||||||
while True:
|
self.memory_graph.save_graph_to_db()
|
||||||
query = input("请输入问题:")
|
|
||||||
|
def memory_compress(self, input_text, rate=1):
|
||||||
if query.lower() == '退出':
|
information_content = calculate_information_content(input_text)
|
||||||
break
|
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
||||||
|
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
||||||
topic_prompt = find_topic(query, 3)
|
topic_prompt = find_topic(input_text, topic_num)
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
topic_response = self.llm_model.generate_response(topic_prompt)
|
||||||
# 检查 topic_response 是否为元组
|
# 检查 topic_response 是否为元组
|
||||||
if isinstance(topic_response, tuple):
|
if isinstance(topic_response, tuple):
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
else:
|
else:
|
||||||
topics = topic_response.split(",")
|
topics = topic_response.split(",")
|
||||||
print(topics)
|
compressed_memory = set()
|
||||||
|
for topic in topics:
|
||||||
for keyword in topics:
|
topic_what_prompt = topic_what(input_text,topic)
|
||||||
items_list = memory_graph.get_related_item(keyword)
|
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
||||||
if items_list:
|
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
||||||
print(items_list)
|
return compressed_memory
|
||||||
|
|
||||||
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
print(topic_num)
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
print(topics)
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
|
|
||||||
G = memory_graph.G
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 创建一个新图用于可视化
|
||||||
|
H = G.copy()
|
||||||
|
|
||||||
|
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||||
|
nodes_to_remove = []
|
||||||
|
for node in H.nodes():
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
degree = H.degree(node)
|
||||||
|
if memory_count <= 1 or degree <= 2:
|
||||||
|
nodes_to_remove.append(node)
|
||||||
|
|
||||||
|
H.remove_nodes_from(nodes_to_remove)
|
||||||
|
|
||||||
|
# 如果过滤后没有节点,则返回
|
||||||
|
if len(H.nodes()) == 0:
|
||||||
|
print("过滤后没有符合条件的节点可显示")
|
||||||
|
return
|
||||||
|
|
||||||
# 保存图到本地
|
# 保存图到本地
|
||||||
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
|
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
# 根据连接条数或记忆数量设置节点颜色
|
# 根据连接条数或记忆数量设置节点颜色
|
||||||
node_colors = []
|
node_colors = []
|
||||||
nodes = list(G.nodes()) # 获取图中实际的节点列表
|
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
||||||
|
|
||||||
if color_by_memory:
|
if color_by_memory:
|
||||||
# 计算每个节点的记忆数量
|
# 计算每个节点的记忆数量
|
||||||
memory_counts = []
|
memory_counts = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
memory_items = G.nodes[node].get('memory_items', [])
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
if isinstance(memory_items, list):
|
if isinstance(memory_items, list):
|
||||||
count = len(memory_items)
|
count = len(memory_items)
|
||||||
else:
|
else:
|
||||||
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
node_colors.append(color)
|
node_colors.append(color)
|
||||||
else:
|
else:
|
||||||
# 使用原来的连接数量着色方案
|
# 使用原来的连接数量着色方案
|
||||||
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
|
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
degree = G.degree(node)
|
degree = H.degree(node)
|
||||||
if max_degree > 0:
|
if max_degree > 0:
|
||||||
red = min(1.0, degree / max_degree)
|
red = min(1.0, degree / max_degree)
|
||||||
blue = 1.0 - red
|
blue = 1.0 - red
|
||||||
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
|
|
||||||
# 绘制图形
|
# 绘制图形
|
||||||
plt.figure(figsize=(12, 8))
|
plt.figure(figsize=(12, 8))
|
||||||
pos = nx.spring_layout(G, k=1, iterations=50)
|
pos = nx.spring_layout(H, k=1, iterations=50)
|
||||||
nx.draw(G, pos,
|
nx.draw(H, pos,
|
||||||
with_labels=True,
|
with_labels=True,
|
||||||
node_color=node_colors,
|
node_color=node_colors,
|
||||||
node_size=2000,
|
node_size=2000,
|
||||||
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 初始化数据库
|
||||||
|
Database.initialize(
|
||||||
|
host= os.getenv("MONGODB_HOST"),
|
||||||
|
port= int(os.getenv("MONGODB_PORT")),
|
||||||
|
db_name= os.getenv("DATABASE_NAME"),
|
||||||
|
username= os.getenv("MONGODB_USERNAME"),
|
||||||
|
password= os.getenv("MONGODB_PASSWORD"),
|
||||||
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 创建记忆图
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
# 加载数据库中存储的记忆图
|
||||||
|
memory_graph.load_graph_from_db()
|
||||||
|
# 创建海马体
|
||||||
|
hippocampus = Hippocampus(memory_graph)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
|
|
||||||
|
# 构建记忆
|
||||||
|
hippocampus.build_memory(chat_size=25)
|
||||||
|
|
||||||
|
# 展示两种不同的可视化方式
|
||||||
|
print("\n按连接数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=False)
|
||||||
|
|
||||||
|
print("\n按记忆数量着色的图谱:")
|
||||||
|
visualize_graph(memory_graph, color_by_memory=True)
|
||||||
|
|
||||||
|
# 交互式查询
|
||||||
|
while True:
|
||||||
|
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
items_list = memory_graph.get_related_item(query)
|
||||||
|
if items_list:
|
||||||
|
for memory_item in items_list:
|
||||||
|
print(memory_item)
|
||||||
|
else:
|
||||||
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
query = input("请输入问题:")
|
||||||
|
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
|
||||||
|
topic_prompt = find_topic(query, 3)
|
||||||
|
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
|
||||||
|
# 检查 topic_response 是否为元组
|
||||||
|
if isinstance(topic_response, tuple):
|
||||||
|
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
||||||
|
else:
|
||||||
|
topics = topic_response.split(",")
|
||||||
|
print(topics)
|
||||||
|
|
||||||
|
for keyword in topics:
|
||||||
|
items_list = memory_graph.get_related_item(keyword)
|
||||||
|
if items_list:
|
||||||
|
print(items_list)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
199
src/plugins/models/utils_model.py
Normal file
199
src/plugins/models/utils_model.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
from typing import Tuple, Union
|
||||||
|
from nonebot import get_driver
|
||||||
|
from ..chat.config import global_config
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
class LLM_request:
|
||||||
|
def __init__(self, model = global_config.llm_normal,**kwargs):
|
||||||
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
|
try:
|
||||||
|
self.api_key = getattr(config, model["key"])
|
||||||
|
self.base_url = getattr(config, model["base_url"])
|
||||||
|
except AttributeError as e:
|
||||||
|
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
|
self.model_name = model["name"]
|
||||||
|
self.params = kwargs
|
||||||
|
|
||||||
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示生成模型的异步响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示和图片生成模型的异步响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
|
"""同步方法:根据输入的提示和图片生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": prompt
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 2
|
||||||
|
base_wait_time = 6
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
@@ -1,37 +1,47 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from .schedule_llm_module import LLMModel
|
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.config import global_config
|
from src.plugins.chat.config import global_config
|
||||||
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host= config.mongodb_host,
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port= int(config.mongodb_port),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name= config.database_name,
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username= config.mongodb_username,
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password= config.mongodb_password,
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=config.mongodb_auth_source
|
||||||
)
|
)
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if global_config.API_USING == "siliconflow":
|
#根据global_config.llm_normal这一字典配置指定模型
|
||||||
self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3")
|
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
|
||||||
elif global_config.API_USING == "deepseek":
|
self.llm_scheduler = LLM_request(model = global_config.llm_normal,temperature=0.9)
|
||||||
self.llm_scheduler = LLMModel(model_name="deepseek-chat",api_using="deepseek")
|
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
self.today_schedule_text = ""
|
||||||
|
self.today_schedule = {}
|
||||||
|
self.tomorrow_schedule_text = ""
|
||||||
|
self.tomorrow_schedule = {}
|
||||||
|
self.yesterday_schedule_text = ""
|
||||||
|
self.yesterday_schedule = {}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
today = datetime.datetime.now()
|
today = datetime.datetime.now()
|
||||||
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
|
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
|
||||||
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
|
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
|
||||||
|
|
||||||
self.today_schedule_text, self.today_schedule = self.generate_daily_schedule(target_date=today)
|
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today)
|
||||||
|
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,read_only=True)
|
||||||
self.tomorrow_schedule_text, self.tomorrow_schedule = self.generate_daily_schedule(target_date=tomorrow,read_only=True)
|
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
||||||
self.yesterday_schedule_text, self.yesterday_schedule = self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
|
||||||
|
|
||||||
def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
||||||
if target_date is None:
|
if target_date is None:
|
||||||
target_date = datetime.datetime.now()
|
target_date = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -55,7 +65,7 @@ class ScheduleGenerator:
|
|||||||
3. 晚上的计划和休息时间
|
3. 晚上的计划和休息时间
|
||||||
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
||||||
|
|
||||||
schedule_text, _ = self.llm_scheduler.generate_response(prompt)
|
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
||||||
# print(self.schedule_text)
|
# print(self.schedule_text)
|
||||||
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
class LLMModel:
|
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
|
||||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
|
|
||||||
if api_using == "deepseek":
|
|
||||||
self.api_key = os.getenv("DEEP_SEEK_KEY")
|
|
||||||
self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
|
|
||||||
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
|
|
||||||
self.model_name = model_name
|
|
||||||
else:
|
|
||||||
self.model_name = "deepseek-reasoner"
|
|
||||||
else:
|
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
|
||||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
|
||||||
self.model_name = model_name
|
|
||||||
self.params = kwargs
|
|
||||||
|
|
||||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
|
||||||
"""根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.9,
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
|
||||||
response.raise_for_status() # 检查响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content # 返回内容和推理内容
|
|
||||||
return "没有返回结果", "" # 返回两个值
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串
|
|
||||||
|
|
||||||
# 示例用法
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model = LLMModel() # 默认使用 DeepSeek-V3 模型
|
|
||||||
prompt = "你好,你喜欢我吗?"
|
|
||||||
result, reasoning = model.generate_response(prompt)
|
|
||||||
print("回复内容:", result)
|
|
||||||
print("推理内容:", reasoning)
|
|
||||||
70
src/test/emotion_cal.py
Normal file
70
src/test/emotion_cal.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from textblob import TextBlob
|
||||||
|
import jieba
|
||||||
|
from translate import Translator
|
||||||
|
|
||||||
|
def analyze_emotion(text):
|
||||||
|
"""
|
||||||
|
分析文本的情感,返回情感极性和主观性得分
|
||||||
|
:param text: 输入文本
|
||||||
|
:return: (情感极性, 主观性) 元组
|
||||||
|
情感极性: -1(非常消极) 到 1(非常积极)
|
||||||
|
主观性: 0(客观) 到 1(主观)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 创建翻译器
|
||||||
|
translator = Translator(to_lang="en", from_lang="zh")
|
||||||
|
|
||||||
|
# 如果是中文文本,先翻译成英文
|
||||||
|
# 因为TextBlob的情感分析主要基于英文
|
||||||
|
translated_text = translator.translate(text)
|
||||||
|
|
||||||
|
# 创建TextBlob对象
|
||||||
|
blob = TextBlob(translated_text)
|
||||||
|
|
||||||
|
# 获取情感极性和主观性
|
||||||
|
polarity = blob.sentiment.polarity
|
||||||
|
subjectivity = blob.sentiment.subjectivity
|
||||||
|
|
||||||
|
return polarity, subjectivity
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"分析过程中出现错误: {str(e)}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def get_emotion_description(polarity, subjectivity):
|
||||||
|
"""
|
||||||
|
根据情感极性和主观性生成描述性文字
|
||||||
|
"""
|
||||||
|
if polarity is None or subjectivity is None:
|
||||||
|
return "无法分析情感"
|
||||||
|
|
||||||
|
# 情感极性描述
|
||||||
|
if polarity > 0.5:
|
||||||
|
emotion = "非常积极"
|
||||||
|
elif polarity > 0:
|
||||||
|
emotion = "较为积极"
|
||||||
|
elif polarity == 0:
|
||||||
|
emotion = "中性"
|
||||||
|
elif polarity > -0.5:
|
||||||
|
emotion = "较为消极"
|
||||||
|
else:
|
||||||
|
emotion = "非常消极"
|
||||||
|
|
||||||
|
# 主观性描述
|
||||||
|
if subjectivity > 0.7:
|
||||||
|
subj = "非常主观"
|
||||||
|
elif subjectivity > 0.3:
|
||||||
|
subj = "较为主观"
|
||||||
|
else:
|
||||||
|
subj = "较为客观"
|
||||||
|
|
||||||
|
return f"情感倾向: {emotion}, 表达方式: {subj}"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试样例
|
||||||
|
test_text = "今天天气真好,我感到非常开心!"
|
||||||
|
polarity, subjectivity = analyze_emotion(test_text)
|
||||||
|
print(f"测试文本: {test_text}")
|
||||||
|
print(f"情感极性: {polarity:.2f}")
|
||||||
|
print(f"主观性得分: {subjectivity:.2f}")
|
||||||
|
print(get_emotion_description(polarity, subjectivity))
|
||||||
74
src/test/emotion_cal_bert.py
Normal file
74
src/test/emotion_cal_bert.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
def setup_bert_analyzer():
|
||||||
|
"""
|
||||||
|
设置中文BERT情感分析器
|
||||||
|
"""
|
||||||
|
# 使用专门针对中文情感分析的模型
|
||||||
|
model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载模型和分词器
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# 创建情感分析pipeline
|
||||||
|
analyzer = pipeline("sentiment-analysis",
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer)
|
||||||
|
|
||||||
|
return analyzer
|
||||||
|
except Exception as e:
|
||||||
|
print(f"模型加载错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def analyze_emotion_bert(text, analyzer):
|
||||||
|
"""
|
||||||
|
使用BERT模型进行中文情感分析
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not analyzer:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 进行情感分析
|
||||||
|
result = analyzer(text)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'label': result['label'],
|
||||||
|
'score': result['score']
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"分析过程中出现错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_emotion_description_bert(result):
|
||||||
|
"""
|
||||||
|
将BERT的情感分析结果转换为描述性文字
|
||||||
|
"""
|
||||||
|
if not result:
|
||||||
|
return "无法分析情感"
|
||||||
|
|
||||||
|
label = "积极" if result['label'] == 'positive' else "消极"
|
||||||
|
confidence = result['score']
|
||||||
|
|
||||||
|
if confidence > 0.9:
|
||||||
|
strength = "强烈"
|
||||||
|
elif confidence > 0.7:
|
||||||
|
strength = "明显"
|
||||||
|
else:
|
||||||
|
strength = "轻微"
|
||||||
|
|
||||||
|
return f"{strength}{label}"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 初始化分析器
|
||||||
|
analyzer = setup_bert_analyzer()
|
||||||
|
|
||||||
|
# 测试样例
|
||||||
|
test_text = "这个产品质量很好,使用起来非常方便,推荐购买!"
|
||||||
|
result = analyze_emotion_bert(test_text, analyzer)
|
||||||
|
|
||||||
|
print(f"测试文本: {test_text}")
|
||||||
|
if result:
|
||||||
|
print(f"情感倾向: {get_emotion_description_bert(result)}")
|
||||||
|
print(f"置信度: {result['score']:.2f}")
|
||||||
62
src/test/emotion_cal_hanlp.py
Normal file
62
src/test/emotion_cal_hanlp.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import hanlp
|
||||||
|
|
||||||
|
def analyze_emotion_hanlp(text):
|
||||||
|
"""
|
||||||
|
使用HanLP进行中文情感分析
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用更基础的模型
|
||||||
|
tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
words = tokenizer(text)
|
||||||
|
|
||||||
|
# 简单的情感词典方法
|
||||||
|
positive_words = {'好', '棒', '优秀', '喜欢', '开心', '快乐', '美味', '推荐', '优质', '满意'}
|
||||||
|
negative_words = {'差', '糟', '烂', '讨厌', '失望', '难受', '恶心', '不满', '差劲', '垃圾'}
|
||||||
|
|
||||||
|
# 计算情感得分
|
||||||
|
score = 0
|
||||||
|
for word in words:
|
||||||
|
if word in positive_words:
|
||||||
|
score += 1
|
||||||
|
elif word in negative_words:
|
||||||
|
score -= 1
|
||||||
|
|
||||||
|
# 归一化得分
|
||||||
|
if score > 0:
|
||||||
|
return 1
|
||||||
|
elif score < 0:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"分析过程中出现错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_emotion_description_hanlp(score):
|
||||||
|
"""
|
||||||
|
将HanLP的情感分析结果转换为描述性文字
|
||||||
|
"""
|
||||||
|
if score is None:
|
||||||
|
return "无法分析情感"
|
||||||
|
elif score == 1:
|
||||||
|
return "积极"
|
||||||
|
elif score == 0:
|
||||||
|
return "消极"
|
||||||
|
else:
|
||||||
|
return "中性"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试样例
|
||||||
|
test_texts = [
|
||||||
|
"这家餐厅的服务态度很好,菜品也很美味!",
|
||||||
|
"这个产品质量太差了,一点都不值这个价",
|
||||||
|
"今天天气不错,但是工作很累"
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_text in test_texts:
|
||||||
|
result = analyze_emotion_hanlp(test_text)
|
||||||
|
print(f"\n测试文本: {test_text}")
|
||||||
|
print(f"情感倾向: {get_emotion_description_hanlp(result)}")
|
||||||
53
src/test/emotion_cal_snownlp.py
Normal file
53
src/test/emotion_cal_snownlp.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from snownlp import SnowNLP
|
||||||
|
|
||||||
|
def analyze_emotion_snownlp(text):
|
||||||
|
"""
|
||||||
|
使用SnowNLP进行中文情感分析
|
||||||
|
:param text: 输入文本
|
||||||
|
:return: 情感得分(0-1之间,越接近1越积极)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
s = SnowNLP(text)
|
||||||
|
sentiment_score = s.sentiments
|
||||||
|
|
||||||
|
# 获取文本的关键词
|
||||||
|
keywords = s.keywords(3)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'sentiment_score': sentiment_score,
|
||||||
|
'keywords': keywords,
|
||||||
|
'summary': s.summary(1) # 生成文本摘要
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"分析过程中出现错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_emotion_description_snownlp(score):
|
||||||
|
"""
|
||||||
|
将情感得分转换为描述性文字
|
||||||
|
"""
|
||||||
|
if score is None:
|
||||||
|
return "无法分析情感"
|
||||||
|
|
||||||
|
if score > 0.8:
|
||||||
|
return "非常积极"
|
||||||
|
elif score > 0.6:
|
||||||
|
return "较为积极"
|
||||||
|
elif score > 0.4:
|
||||||
|
return "中性偏积极"
|
||||||
|
elif score > 0.2:
|
||||||
|
return "中性偏消极"
|
||||||
|
else:
|
||||||
|
return "消极"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试样例
|
||||||
|
test_text = "我们学校有免费的gpt4用"
|
||||||
|
result = analyze_emotion_snownlp(test_text)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
print(f"测试文本: {test_text}")
|
||||||
|
print(f"情感得分: {result['sentiment_score']:.2f}")
|
||||||
|
print(f"情感倾向: {get_emotion_description_snownlp(result['sentiment_score'])}")
|
||||||
|
print(f"关键词: {', '.join(result['keywords'])}")
|
||||||
|
print(f"文本摘要: {result['summary'][0]}")
|
||||||
54
src/test/snownlp_demo.py
Normal file
54
src/test/snownlp_demo.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from snownlp import SnowNLP
|
||||||
|
|
||||||
|
def demo_snownlp_features(text):
|
||||||
|
"""
|
||||||
|
展示SnowNLP的主要功能
|
||||||
|
:param text: 输入文本
|
||||||
|
"""
|
||||||
|
print(f"\n=== SnowNLP功能演示 ===")
|
||||||
|
print(f"输入文本: {text}")
|
||||||
|
|
||||||
|
# 创建SnowNLP对象
|
||||||
|
s = SnowNLP(text)
|
||||||
|
|
||||||
|
# 1. 分词
|
||||||
|
print(f"\n1. 分词结果:")
|
||||||
|
print(f" {' | '.join(s.words)}")
|
||||||
|
|
||||||
|
# 2. 情感分析
|
||||||
|
print(f"\n2. 情感分析:")
|
||||||
|
sentiment = s.sentiments
|
||||||
|
print(f" 情感得分: {sentiment:.2f}")
|
||||||
|
print(f" 情感倾向: {'积极' if sentiment > 0.5 else '消极' if sentiment < 0.5 else '中性'}")
|
||||||
|
|
||||||
|
# 3. 关键词提取
|
||||||
|
print(f"\n3. 关键词提取:")
|
||||||
|
print(f" {', '.join(s.keywords(3))}")
|
||||||
|
|
||||||
|
# 4. 词性标注
|
||||||
|
print(f"\n4. 词性标注:")
|
||||||
|
print(f" {' '.join([f'{word}/{tag}' for word, tag in s.tags])}")
|
||||||
|
|
||||||
|
# 5. 拼音转换
|
||||||
|
print(f"\n5. 拼音:")
|
||||||
|
print(f" {' '.join(s.pinyin)}")
|
||||||
|
|
||||||
|
# 6. 文本摘要
|
||||||
|
if len(text) > 100: # 只对较长文本生成摘要
|
||||||
|
print(f"\n6. 文本摘要:")
|
||||||
|
print(f" {' '.join(s.summary(3))}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试用例
|
||||||
|
test_texts = [
|
||||||
|
"这家新开的餐厅很不错,菜品种类丰富,味道可口,服务态度也很好,价格实惠,强烈推荐大家来尝试!",
|
||||||
|
"这部电影剧情混乱,演技浮夸,特效粗糙,配乐难听,完全浪费了我的时间和票价。",
|
||||||
|
"""人工智能正在改变我们的生活方式。它能够帮助我们完成复杂的计算任务,
|
||||||
|
提供个性化的服务推荐,优化交通路线,辅助医疗诊断。但同时我们也要警惕
|
||||||
|
人工智能带来的问题,比如隐私安全、就业变化等。如何正确认识和利用人工智能,
|
||||||
|
是我们每个人都需要思考的问题。"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for text in test_texts:
|
||||||
|
demo_snownlp_features(text)
|
||||||
|
print("\n" + "="*50)
|
||||||
488
src/test/typo.py
Normal file
488
src/test/typo.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
错别字生成器 - 流程说明
|
||||||
|
|
||||||
|
整体替换逻辑:
|
||||||
|
1. 数据准备
|
||||||
|
- 加载字频词典:使用jieba词典计算汉字使用频率
|
||||||
|
- 创建拼音映射:建立拼音到汉字的映射关系
|
||||||
|
- 加载词频信息:从jieba词典获取词语使用频率
|
||||||
|
|
||||||
|
2. 分词处理
|
||||||
|
- 使用jieba将输入句子分词
|
||||||
|
- 区分单字词和多字词
|
||||||
|
- 保留标点符号和空格
|
||||||
|
|
||||||
|
3. 词语级别替换(针对多字词)
|
||||||
|
- 触发条件:词长>1 且 随机概率<0.3
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取词语拼音
|
||||||
|
b. 生成所有可能的同音字组合
|
||||||
|
c. 过滤条件:
|
||||||
|
- 必须是jieba词典中的有效词
|
||||||
|
- 词频必须达到原词频的10%以上
|
||||||
|
- 综合评分(词频70%+字频30%)必须达到阈值
|
||||||
|
d. 按综合评分排序,选择最合适的替换词
|
||||||
|
|
||||||
|
4. 字级别替换(针对单字词或未进行整词替换的多字词)
|
||||||
|
- 单字替换概率:0.3
|
||||||
|
- 多字词中的单字替换概率:0.3 * (0.7 ^ (词长-1))
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取字的拼音
|
||||||
|
b. 声调错误处理(20%概率)
|
||||||
|
c. 获取同音字列表
|
||||||
|
d. 过滤条件:
|
||||||
|
- 字频必须达到最小阈值
|
||||||
|
- 频率差异不能过大(指数衰减计算)
|
||||||
|
e. 按频率排序选择替换字
|
||||||
|
|
||||||
|
5. 频率控制机制
|
||||||
|
- 字频控制:使用归一化的字频(0-1000范围)
|
||||||
|
- 词频控制:使用jieba词典中的词频
|
||||||
|
- 频率差异计算:使用指数衰减函数
|
||||||
|
- 最小频率阈值:确保替换字/词不会太生僻
|
||||||
|
|
||||||
|
6. 输出信息
|
||||||
|
- 原文和错字版本的对照
|
||||||
|
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
|
||||||
|
- 替换类型说明(整词替换/声调错误/同音字替换)
|
||||||
|
- 词语分析和完整拼音
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
1. 所有替换都必须使用有意义的词语
|
||||||
|
2. 替换词的使用频率不能过低
|
||||||
|
3. 多字词优先考虑整词替换
|
||||||
|
4. 考虑声调变化的情况
|
||||||
|
5. 保持标点符号和空格不变
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pypinyin import pinyin, Style
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
import jieba
|
||||||
|
import jieba.posseg as pseg
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
def load_or_create_char_frequency():
|
||||||
|
"""
|
||||||
|
加载或创建汉字频率字典
|
||||||
|
"""
|
||||||
|
cache_file = Path("char_frequency.json")
|
||||||
|
|
||||||
|
# 如果缓存文件存在,直接加载
|
||||||
|
if cache_file.exists():
|
||||||
|
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# 使用内置的词频文件
|
||||||
|
char_freq = defaultdict(int)
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
|
||||||
|
# 读取jieba的词典文件
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
word, freq = line.strip().split()[:2]
|
||||||
|
# 对词中的每个字进行频率累加
|
||||||
|
for char in word:
|
||||||
|
if is_chinese_char(char):
|
||||||
|
char_freq[char] += int(freq)
|
||||||
|
|
||||||
|
# 归一化频率值
|
||||||
|
max_freq = max(char_freq.values())
|
||||||
|
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||||
|
|
||||||
|
# 保存到缓存文件
|
||||||
|
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return normalized_freq
|
||||||
|
|
||||||
|
# 创建拼音到汉字的映射字典
|
||||||
|
def create_pinyin_dict():
|
||||||
|
"""
|
||||||
|
创建拼音到汉字的映射字典
|
||||||
|
"""
|
||||||
|
# 常用汉字范围
|
||||||
|
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||||
|
pinyin_dict = defaultdict(list)
|
||||||
|
|
||||||
|
# 为每个汉字建立拼音映射
|
||||||
|
for char in chars:
|
||||||
|
try:
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
pinyin_dict[py].append(char)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return pinyin_dict
|
||||||
|
|
||||||
|
def is_chinese_char(char):
|
||||||
|
"""
|
||||||
|
判断是否为汉字
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return '\u4e00' <= char <= '\u9fff'
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pinyin(sentence):
|
||||||
|
"""
|
||||||
|
将中文句子拆分成单个汉字并获取其拼音
|
||||||
|
:param sentence: 输入的中文句子
|
||||||
|
:return: 每个汉字及其拼音的列表
|
||||||
|
"""
|
||||||
|
# 将句子拆分成单个字符
|
||||||
|
characters = list(sentence)
|
||||||
|
|
||||||
|
# 获取每个字符的拼音
|
||||||
|
result = []
|
||||||
|
for char in characters:
|
||||||
|
# 跳过空格和非汉字字符
|
||||||
|
if char.isspace() or not is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
# 获取拼音(数字声调)
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
result.append((char, py))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取同音字,按照使用频率排序
|
||||||
|
"""
|
||||||
|
homophones = pinyin_dict[py]
|
||||||
|
# 移除原字并过滤低频字
|
||||||
|
if char in homophones:
|
||||||
|
homophones.remove(char)
|
||||||
|
|
||||||
|
# 过滤掉低频字
|
||||||
|
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
# 按照字频排序
|
||||||
|
sorted_homophones = sorted(homophones,
|
||||||
|
key=lambda x: char_frequency.get(x, 0),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# 只返回前10个同音字,避免输出过多
|
||||||
|
return sorted_homophones[:10]
|
||||||
|
|
||||||
|
def get_similar_tone_pinyin(py):
|
||||||
|
"""
|
||||||
|
获取相似声调的拼音
|
||||||
|
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
||||||
|
处理特殊情况:
|
||||||
|
1. 轻声(如 'de5' 或 'le')
|
||||||
|
2. 非数字结尾的拼音
|
||||||
|
"""
|
||||||
|
# 检查拼音是否为空或无效
|
||||||
|
if not py or len(py) < 1:
|
||||||
|
return py
|
||||||
|
|
||||||
|
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||||
|
if not py[-1].isdigit():
|
||||||
|
# 为非数字结尾的拼音添加数字声调1
|
||||||
|
return py + '1'
|
||||||
|
|
||||||
|
base = py[:-1] # 去掉声调
|
||||||
|
tone = int(py[-1]) # 获取声调
|
||||||
|
|
||||||
|
# 处理轻声(通常用5表示)或无效声调
|
||||||
|
if tone not in [1, 2, 3, 4]:
|
||||||
|
return base + str(random.choice([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
# 正常处理声调
|
||||||
|
possible_tones = [1, 2, 3, 4]
|
||||||
|
possible_tones.remove(tone) # 移除原声调
|
||||||
|
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||||
|
return base + str(new_tone)
|
||||||
|
|
||||||
|
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
||||||
|
"""
|
||||||
|
根据频率差计算替换概率
|
||||||
|
频率差越大,概率越低
|
||||||
|
:param orig_freq: 原字频率
|
||||||
|
:param target_freq: 目标字频率
|
||||||
|
:param max_freq_diff: 最大允许的频率差
|
||||||
|
:return: 0-1之间的概率值
|
||||||
|
"""
|
||||||
|
if target_freq > orig_freq:
|
||||||
|
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||||
|
|
||||||
|
freq_diff = orig_freq - target_freq
|
||||||
|
if freq_diff > max_freq_diff:
|
||||||
|
return 0.0 # 频率差太大,不替换
|
||||||
|
|
||||||
|
# 使用指数衰减函数计算概率
|
||||||
|
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||||
|
return math.exp(-3 * freq_diff / max_freq_diff)
|
||||||
|
|
||||||
|
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
获取与给定字频率相近的同音字,可能包含声调错误
|
||||||
|
"""
|
||||||
|
homophones = []
|
||||||
|
|
||||||
|
# 有20%的概率使用错误声调
|
||||||
|
if random.random() < tone_error_rate:
|
||||||
|
wrong_tone_py = get_similar_tone_pinyin(py)
|
||||||
|
homophones.extend(pinyin_dict[wrong_tone_py])
|
||||||
|
|
||||||
|
# 添加正确声调的同音字
|
||||||
|
homophones.extend(pinyin_dict[py])
|
||||||
|
|
||||||
|
if not homophones:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取原字的频率
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||||
|
freq_diff = [(h, char_frequency.get(h, 0))
|
||||||
|
for h in homophones
|
||||||
|
if h != char and char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
if not freq_diff:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算每个候选字的替换概率
|
||||||
|
candidates_with_prob = []
|
||||||
|
for h, freq in freq_diff:
|
||||||
|
prob = calculate_replacement_probability(orig_freq, freq)
|
||||||
|
if prob > 0: # 只保留有效概率的候选字
|
||||||
|
candidates_with_prob.append((h, prob))
|
||||||
|
|
||||||
|
if not candidates_with_prob:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 根据概率排序
|
||||||
|
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 返回概率最高的几个字
|
||||||
|
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||||
|
|
||||||
|
def get_word_pinyin(word):
|
||||||
|
"""
|
||||||
|
获取词语的拼音列表
|
||||||
|
"""
|
||||||
|
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||||
|
|
||||||
|
def segment_sentence(sentence):
|
||||||
|
"""
|
||||||
|
使用jieba分词,返回词语列表
|
||||||
|
"""
|
||||||
|
return list(jieba.cut(sentence))
|
||||||
|
|
||||||
|
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取整个词的同音词,只返回高频的有意义词语
|
||||||
|
:param word: 输入词语
|
||||||
|
:param pinyin_dict: 拼音字典
|
||||||
|
:param char_frequency: 字频字典
|
||||||
|
:param min_freq: 最小频率阈值
|
||||||
|
:return: 同音词列表
|
||||||
|
"""
|
||||||
|
if len(word) == 1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取词的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
word_pinyin_str = ''.join(word_pinyin)
|
||||||
|
|
||||||
|
# 创建词语频率字典
|
||||||
|
word_freq = defaultdict(float)
|
||||||
|
|
||||||
|
# 遍历所有可能的同音字组合
|
||||||
|
candidates = []
|
||||||
|
for py in word_pinyin:
|
||||||
|
chars = pinyin_dict.get(py, [])
|
||||||
|
if not chars:
|
||||||
|
return []
|
||||||
|
candidates.append(chars)
|
||||||
|
|
||||||
|
# 生成所有可能的组合
|
||||||
|
import itertools
|
||||||
|
all_combinations = itertools.product(*candidates)
|
||||||
|
|
||||||
|
# 获取jieba词典和词频信息
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
valid_words = {} # 改用字典存储词语及其频率
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
word_text = parts[0]
|
||||||
|
word_freq = float(parts[1]) # 获取词频
|
||||||
|
valid_words[word_text] = word_freq
|
||||||
|
|
||||||
|
# 获取原词的词频作为参考
|
||||||
|
original_word_freq = valid_words.get(word, 0)
|
||||||
|
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||||
|
|
||||||
|
# 过滤和计算频率
|
||||||
|
homophones = []
|
||||||
|
for combo in all_combinations:
|
||||||
|
new_word = ''.join(combo)
|
||||||
|
if new_word != word and new_word in valid_words:
|
||||||
|
new_word_freq = valid_words[new_word]
|
||||||
|
# 只保留词频达到阈值的词
|
||||||
|
if new_word_freq >= min_word_freq:
|
||||||
|
# 计算词的平均字频(考虑字频和词频)
|
||||||
|
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||||
|
# 综合评分:结合词频和字频
|
||||||
|
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
|
||||||
|
if combined_score >= min_freq:
|
||||||
|
homophones.append((new_word, combined_score))
|
||||||
|
|
||||||
|
# 按综合分数排序并限制返回数量
|
||||||
|
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||||
|
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||||
|
|
||||||
|
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
|
||||||
|
"""
|
||||||
|
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||||
|
只使用高频的有意义词语进行替换
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
typo_info = []
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
# 如果是标点符号或空格,直接添加
|
||||||
|
if all(not is_chinese_char(c) for c in word):
|
||||||
|
result.append(word)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取词语的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
|
||||||
|
# 尝试整词替换
|
||||||
|
if len(word) > 1 and random.random() < word_replace_rate:
|
||||||
|
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
|
||||||
|
if word_homophones:
|
||||||
|
typo_word = random.choice(word_homophones)
|
||||||
|
# 计算词的平均频率
|
||||||
|
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
|
||||||
|
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||||
|
|
||||||
|
# 添加到结果中
|
||||||
|
result.append(typo_word)
|
||||||
|
typo_info.append((word, typo_word,
|
||||||
|
' '.join(word_pinyin),
|
||||||
|
' '.join(get_word_pinyin(typo_word)),
|
||||||
|
orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果不进行整词替换,则进行单字替换
|
||||||
|
if len(word) == 1:
|
||||||
|
char = word
|
||||||
|
py = word_pinyin[0]
|
||||||
|
if random.random() < error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
# 处理多字词的单字替换
|
||||||
|
word_result = []
|
||||||
|
for i, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||||
|
# 词中的字替换概率降低
|
||||||
|
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
|
||||||
|
|
||||||
|
if random.random() < word_error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
word_result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
word_result.append(char)
|
||||||
|
result.append(''.join(word_result))
|
||||||
|
|
||||||
|
return ''.join(result), typo_info
|
||||||
|
|
||||||
|
def format_frequency(freq):
|
||||||
|
"""
|
||||||
|
格式化频率显示
|
||||||
|
"""
|
||||||
|
return f"{freq:.2f}"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 记录开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 首先创建拼音字典和加载字频统计
|
||||||
|
print("正在加载汉字数据库,请稍候...")
|
||||||
|
pinyin_dict = create_pinyin_dict()
|
||||||
|
char_frequency = load_or_create_char_frequency()
|
||||||
|
|
||||||
|
# 获取用户输入
|
||||||
|
sentence = input("请输入中文句子:")
|
||||||
|
|
||||||
|
# 创建包含错别字的句子
|
||||||
|
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
||||||
|
error_rate=0.3, min_freq=5,
|
||||||
|
tone_error_rate=0.2, word_replace_rate=0.3)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\n原句:", sentence)
|
||||||
|
print("错字版:", typo_sentence)
|
||||||
|
|
||||||
|
if typo_info:
|
||||||
|
print("\n错别字信息:")
|
||||||
|
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||||
|
# 判断是否为词语替换
|
||||||
|
is_word = ' ' in orig_py
|
||||||
|
if is_word:
|
||||||
|
error_type = "整词替换"
|
||||||
|
else:
|
||||||
|
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||||
|
error_type = "声调错误" if tone_error else "同音字替换"
|
||||||
|
|
||||||
|
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
||||||
|
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
||||||
|
|
||||||
|
# 获取拼音结果
|
||||||
|
result = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 打印完整拼音
|
||||||
|
print("\n完整拼音:")
|
||||||
|
print(" ".join(py for _, py in result))
|
||||||
|
|
||||||
|
# 打印词语分析
|
||||||
|
print("\n词语分析:")
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
for word in words:
|
||||||
|
if any(is_chinese_char(c) for c in word):
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
print(f"词语:{word}")
|
||||||
|
print(f"拼音:{' '.join(word_pinyin)}")
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
# 计算并打印总耗时
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
print(f"\n总耗时:{total_time:.2f}秒")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
301
src/test/typo_word.py
Normal file
301
src/test/typo_word.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
from pypinyin import pinyin, Style
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
import jieba
|
||||||
|
import jieba.posseg as pseg
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
|
def load_or_create_char_frequency():
|
||||||
|
"""
|
||||||
|
加载或创建汉字频率字典
|
||||||
|
"""
|
||||||
|
cache_file = Path("char_frequency.json")
|
||||||
|
|
||||||
|
# 如果缓存文件存在,直接加载
|
||||||
|
if cache_file.exists():
|
||||||
|
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# 使用内置的词频文件
|
||||||
|
char_freq = defaultdict(int)
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
|
||||||
|
# 读取jieba的词典文件
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
word, freq = line.strip().split()[:2]
|
||||||
|
# 对词中的每个字进行频率累加
|
||||||
|
for char in word:
|
||||||
|
if is_chinese_char(char):
|
||||||
|
char_freq[char] += int(freq)
|
||||||
|
|
||||||
|
# 归一化频率值
|
||||||
|
max_freq = max(char_freq.values())
|
||||||
|
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||||
|
|
||||||
|
# 保存到缓存文件
|
||||||
|
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return normalized_freq
|
||||||
|
|
||||||
|
# 创建拼音到汉字的映射字典
|
||||||
|
def create_pinyin_dict():
|
||||||
|
"""
|
||||||
|
创建拼音到汉字的映射字典
|
||||||
|
"""
|
||||||
|
# 常用汉字范围
|
||||||
|
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||||
|
pinyin_dict = defaultdict(list)
|
||||||
|
|
||||||
|
# 为每个汉字建立拼音映射
|
||||||
|
for char in chars:
|
||||||
|
try:
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
pinyin_dict[py].append(char)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return pinyin_dict
|
||||||
|
|
||||||
|
def is_chinese_char(char):
|
||||||
|
"""
|
||||||
|
判断是否为汉字
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return '\u4e00' <= char <= '\u9fff'
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pinyin(sentence):
|
||||||
|
"""
|
||||||
|
将中文句子拆分成单个汉字并获取其拼音
|
||||||
|
:param sentence: 输入的中文句子
|
||||||
|
:return: 每个汉字及其拼音的列表
|
||||||
|
"""
|
||||||
|
# 将句子拆分成单个字符
|
||||||
|
characters = list(sentence)
|
||||||
|
|
||||||
|
# 获取每个字符的拼音
|
||||||
|
result = []
|
||||||
|
for char in characters:
|
||||||
|
# 跳过空格和非汉字字符
|
||||||
|
if char.isspace() or not is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
# 获取拼音(数字声调)
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
result.append((char, py))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取同音字,按照使用频率排序
|
||||||
|
"""
|
||||||
|
homophones = pinyin_dict[py]
|
||||||
|
# 移除原字并过滤低频字
|
||||||
|
if char in homophones:
|
||||||
|
homophones.remove(char)
|
||||||
|
|
||||||
|
# 过滤掉低频字
|
||||||
|
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
# 按照字频排序
|
||||||
|
sorted_homophones = sorted(homophones,
|
||||||
|
key=lambda x: char_frequency.get(x, 0),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# 只返回前10个同音字,避免输出过多
|
||||||
|
return sorted_homophones[:10]
|
||||||
|
|
||||||
|
def get_similar_tone_pinyin(py):
|
||||||
|
"""
|
||||||
|
获取相似声调的拼音
|
||||||
|
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
||||||
|
"""
|
||||||
|
base = py[:-1] # 去掉声调
|
||||||
|
tone = int(py[-1]) # 获取声调
|
||||||
|
possible_tones = [1, 2, 3, 4]
|
||||||
|
possible_tones.remove(tone) # 移除原声调
|
||||||
|
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||||
|
return base + str(new_tone)
|
||||||
|
|
||||||
|
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
||||||
|
"""
|
||||||
|
根据频率差计算替换概率
|
||||||
|
频率差越大,概率越低
|
||||||
|
:param orig_freq: 原字频率
|
||||||
|
:param target_freq: 目标字频率
|
||||||
|
:param max_freq_diff: 最大允许的频率差
|
||||||
|
:return: 0-1之间的概率值
|
||||||
|
"""
|
||||||
|
if target_freq > orig_freq:
|
||||||
|
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||||
|
|
||||||
|
freq_diff = orig_freq - target_freq
|
||||||
|
if freq_diff > max_freq_diff:
|
||||||
|
return 0.0 # 频率差太大,不替换
|
||||||
|
|
||||||
|
# 使用指数衰减函数计算概率
|
||||||
|
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||||
|
return math.exp(-3 * freq_diff / max_freq_diff)
|
||||||
|
|
||||||
|
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
获取与给定字频率相近的同音字,可能包含声调错误
|
||||||
|
"""
|
||||||
|
homophones = []
|
||||||
|
|
||||||
|
# 有20%的概率使用错误声调
|
||||||
|
if random.random() < tone_error_rate:
|
||||||
|
wrong_tone_py = get_similar_tone_pinyin(py)
|
||||||
|
homophones.extend(pinyin_dict[wrong_tone_py])
|
||||||
|
|
||||||
|
# 添加正确声调的同音字
|
||||||
|
homophones.extend(pinyin_dict[py])
|
||||||
|
|
||||||
|
if not homophones:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取原字的频率
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||||
|
freq_diff = [(h, char_frequency.get(h, 0))
|
||||||
|
for h in homophones
|
||||||
|
if h != char and char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
if not freq_diff:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算每个候选字的替换概率
|
||||||
|
candidates_with_prob = []
|
||||||
|
for h, freq in freq_diff:
|
||||||
|
prob = calculate_replacement_probability(orig_freq, freq)
|
||||||
|
if prob > 0: # 只保留有效概率的候选字
|
||||||
|
candidates_with_prob.append((h, prob))
|
||||||
|
|
||||||
|
if not candidates_with_prob:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 根据概率排序
|
||||||
|
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 返回概率最高的几个字
|
||||||
|
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||||
|
|
||||||
|
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
创建包含同音字错误的句子,保留原文标点符号
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
typo_info = []
|
||||||
|
|
||||||
|
# 获取每个字的拼音
|
||||||
|
chars_with_pinyin = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 创建原字到拼音的映射,用于跟踪已处理的字符
|
||||||
|
processed_chars = {char: py for char, py in chars_with_pinyin}
|
||||||
|
|
||||||
|
# 遍历原句中的每个字符
|
||||||
|
char_index = 0
|
||||||
|
for i, char in enumerate(sentence):
|
||||||
|
if char.isspace():
|
||||||
|
# 保留空格
|
||||||
|
result.append(char)
|
||||||
|
elif char in processed_chars:
|
||||||
|
# 处理汉字
|
||||||
|
py = processed_chars[char]
|
||||||
|
# 基础错误率
|
||||||
|
if random.random() < error_rate:
|
||||||
|
# 获取频率相近的同音字(可能包含声调错误)
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
# 随机选择一个替换字
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
# 获取替换字的频率
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算实际替换概率
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
|
||||||
|
# 根据频率差进行概率替换
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
result.append(typo_char)
|
||||||
|
# 获取替换字的实际拼音
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
char_index += 1
|
||||||
|
else:
|
||||||
|
# 保留非汉字字符(标点符号等)
|
||||||
|
result.append(char)
|
||||||
|
|
||||||
|
return ''.join(result), typo_info
|
||||||
|
|
||||||
|
def format_frequency(freq):
|
||||||
|
"""
|
||||||
|
格式化频率显示
|
||||||
|
"""
|
||||||
|
return f"{freq:.2f}"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 首先创建拼音字典和加载字频统计
|
||||||
|
print("正在加载汉字数据库,请稍候...")
|
||||||
|
pinyin_dict = create_pinyin_dict()
|
||||||
|
char_frequency = load_or_create_char_frequency()
|
||||||
|
|
||||||
|
# 获取用户输入
|
||||||
|
sentence = input("请输入中文句子:")
|
||||||
|
|
||||||
|
# 创建包含错别字的句子
|
||||||
|
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
||||||
|
min_freq=5, tone_error_rate=0.2)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\n原句:", sentence)
|
||||||
|
print("错字版:", typo_sentence)
|
||||||
|
|
||||||
|
if typo_info:
|
||||||
|
print("\n错别字信息:")
|
||||||
|
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||||
|
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||||
|
error_type = "声调错误" if tone_error else "同音字替换"
|
||||||
|
print(f"原字:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
||||||
|
f"错字:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
||||||
|
|
||||||
|
# 获取拼音结果
|
||||||
|
result = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 打印完整拼音
|
||||||
|
print("\n完整拼音:")
|
||||||
|
print(" ".join(py for _, py in result))
|
||||||
|
|
||||||
|
# 打印所有可能的同音字
|
||||||
|
print("\n每个字的所有同音字(按频率排序,仅显示频率>=5的字):")
|
||||||
|
for char, py in result:
|
||||||
|
homophones = get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5)
|
||||||
|
char_freq = char_frequency.get(char, 0)
|
||||||
|
print(f"{char}: {py} [频率:{format_frequency(char_freq)}]")
|
||||||
|
if homophones:
|
||||||
|
homophone_info = []
|
||||||
|
for h in homophones:
|
||||||
|
h_freq = char_frequency.get(h, 0)
|
||||||
|
homophone_info.append(f"{h}[{format_frequency(h_freq)}]")
|
||||||
|
print(f"同音字: {','.join(homophone_info)}")
|
||||||
|
else:
|
||||||
|
print("没有找到频率>=5的同音字")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user