Merge remote-tracking branch 'upstream/debug' into feature
This commit is contained in:
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.git
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.DS_Store
|
||||||
26
.env
26
.env
@@ -1,26 +1,2 @@
|
|||||||
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
||||||
ENVIRONMENT=dev
|
ENVIRONMENT=.dev
|
||||||
# HOST=127.0.0.1
|
|
||||||
# PORT=8080
|
|
||||||
|
|
||||||
# COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# # 插件配置
|
|
||||||
# PLUGINS=["src2.plugins.chat"]
|
|
||||||
|
|
||||||
# # 默认配置
|
|
||||||
# MONGODB_HOST=127.0.0.1
|
|
||||||
# MONGODB_PORT=27017
|
|
||||||
# DATABASE_NAME=MegBot
|
|
||||||
|
|
||||||
# MONGODB_USERNAME = "" # 默认空值
|
|
||||||
# MONGODB_PASSWORD = "" # 默认空值
|
|
||||||
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
|
||||||
|
|
||||||
# #key and url
|
|
||||||
# CHAT_ANY_WHERE_KEY=
|
|
||||||
# SILICONFLOW_KEY=
|
|
||||||
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
|
||||||
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
|
||||||
# DEEP_SEEK_KEY=
|
|
||||||
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# 插件配置
|
# 插件配置
|
||||||
PLUGINS=["src2.plugins.chat"]
|
PLUGINS=["src2.plugins.chat"]
|
||||||
|
|
||||||
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值
|
|||||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
#key and url
|
#key and url
|
||||||
|
|
||||||
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
#定义你要用的api的base_url
|
||||||
DEEP_SEEK_KEY=
|
DEEP_SEEK_KEY=
|
||||||
CHAT_ANY_WHERE_KEY=
|
CHAT_ANY_WHERE_KEY=
|
||||||
SILICONFLOW_KEY=
|
SILICONFLOW_KEY=
|
||||||
16
Dockerfile
16
Dockerfile
@@ -1,8 +1,18 @@
|
|||||||
FROM nonebot/nb-cli:latest
|
FROM nonebot/nb-cli:latest
|
||||||
WORKDIR /
|
|
||||||
COPY . /MaiMBot/
|
# 设置工作目录
|
||||||
WORKDIR /MaiMBot
|
WORKDIR /MaiMBot
|
||||||
|
|
||||||
|
# 先复制依赖列表
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# 安装依赖(这层会被缓存直到requirements.txt改变)
|
||||||
RUN pip install --upgrade -r requirements.txt
|
RUN pip install --upgrade -r requirements.txt
|
||||||
|
|
||||||
|
# 然后复制项目代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
VOLUME [ "/MaiMBot/config" ]
|
VOLUME [ "/MaiMBot/config" ]
|
||||||
|
VOLUME [ "/MaiMBot/data" ]
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
ENTRYPOINT [ "nb","run" ]
|
ENTRYPOINT [ "nb","run" ]
|
||||||
178
README.md
178
README.md
@@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
|
||||||

|

|
||||||

|

|
||||||

|

|
||||||
@@ -12,163 +11,33 @@
|
|||||||
|
|
||||||
## 📝 项目简介
|
## 📝 项目简介
|
||||||
|
|
||||||
**麦麦qq机器人的源代码仓库**
|
**🍔麦麦是一个基于大语言模型的智能QQ群聊机器人**
|
||||||
|
|
||||||
基于llm、napcat、nonebot和mongodb的专注于群聊天的qqbot
|
- 🤖 基于 nonebot2 框架开发
|
||||||
|
- 🧠 LLM 提供对话能力
|
||||||
|
- 💾 MongoDB 提供数据持久化支持
|
||||||
|
- 🐧 NapCat 作为QQ协议端支持
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
|
||||||
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="麦麦演示视频">
|
<img src="docs/video.png" width="300" alt="麦麦演示视频">
|
||||||
<br>
|
<br>
|
||||||
👆 点击观看麦麦演示视频 👆
|
👆 点击观看麦麦演示视频 👆
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
> ⚠️ **警告**:代码可能随时更改,目前版本不一定是稳定版本
|
> ⚠️ **注意事项**
|
||||||
> ⚠️ **警告**:请自行了解qqbot的风险,麦麦有时候一天被腾讯肘七八次
|
> - 项目处于活跃开发阶段,代码可能随时更改
|
||||||
> ⚠️ **警告**:由于麦麦一直在迭代,所以可能存在一些bug,请自行测试,包括胡言乱语(
|
> - 文档未完善,有问题可以提交 Issue 或者 Discussion
|
||||||
|
> - QQ机器人存在被限制风险,请自行了解,谨慎使用
|
||||||
|
> - 由于持续迭代,可能存在一些已知或未知的bug
|
||||||
|
|
||||||
关于麦麦的开发和建议相关的讨论群:766798517(不建议发布无关消息)这里不会有麦麦发言!
|
**交流群**: 766798517(仅用于开发和建议相关讨论)不建议在群内询问部署问题,我不一定有空回复,会优先写文档和代码
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
## 📚 文档
|
||||||
|
|
||||||
- 兼容gif的解析和保存
|
- [安装与配置指南](docs/installation.md) - 详细的部署和配置说明
|
||||||
- 小程序转发链接解析
|
- [项目架构说明](docs/doc1.md) - 项目结构和核心功能实现细节
|
||||||
- 对思考链长度限制
|
|
||||||
- 修复已知bug
|
|
||||||
- 完善文档
|
|
||||||
- 修复转发
|
|
||||||
- config自动生成和检测
|
|
||||||
- log别用print
|
|
||||||
- 给发送消息写专门的类
|
|
||||||
- 改进表情包发送逻辑l
|
|
||||||
|
|
||||||
|
|
||||||
## 📚 详细文档
|
|
||||||
- [项目详细介绍和架构说明](docs/doc1.md) - 包含完整的项目结构、文件说明和核心功能实现细节(由claude-3.5-sonnet生成)
|
|
||||||
|
|
||||||
### 安装方法(还没测试好,随时outdated ,现在部署可能遇到未知问题!!!!)
|
|
||||||
|
|
||||||
#### Linux 使用 Docker Compose 部署
|
|
||||||
获取项目根目录中的```docker-compose.yml```文件,运行以下命令
|
|
||||||
```bash
|
|
||||||
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
|
|
||||||
```
|
|
||||||
配置文件修改完成后,运行以下命令
|
|
||||||
```bash
|
|
||||||
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 手动运行
|
|
||||||
1. **创建Python环境**
|
|
||||||
推荐使用conda或其他虚拟环境进行依赖安装,防止出现依赖版本冲突问题
|
|
||||||
```bash
|
|
||||||
# 安装requirements
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
2. **MongoDB设置**
|
|
||||||
- 安装并运行mongodb
|
|
||||||
- 麦麦bot会自动连接默认的mongodb,端口和数据库名可配置
|
|
||||||
|
|
||||||
3. **Napcat配置**
|
|
||||||
- 安装并运行Napcat,登录
|
|
||||||
- 在Napcat的网络设置中添加ws反向代理:ws://localhost:8080/onebot/v11/ws
|
|
||||||
|
|
||||||
4. **配置文件设置**
|
|
||||||
- 修改.env的 变量值为 prod
|
|
||||||
- 将.env.prod文件打开,填上你的apikey(硅基流动或deepseekapi)
|
|
||||||
- 将bot_config_toml改名为bot_config.toml,打开并填写相关内容,不然无法正常运行
|
|
||||||
|
|
||||||
#### .env 文件配置说明
|
|
||||||
```ini
|
|
||||||
# 环境配置
|
|
||||||
ENVIRONMENT=dev # 开发环境设置
|
|
||||||
HOST=127.0.0.1 # 主机地址
|
|
||||||
PORT=8080 # 端口号
|
|
||||||
|
|
||||||
# 命令前缀设置
|
|
||||||
COMMAND_START=["/"] # 命令起始符
|
|
||||||
|
|
||||||
# 插件配置
|
|
||||||
PLUGINS=["src2.plugins.chat"] # 启用的插件列表
|
|
||||||
|
|
||||||
# MongoDB配置
|
|
||||||
MONGODB_HOST=127.0.0.1 # MongoDB主机地址
|
|
||||||
MONGODB_PORT=27017 # MongoDB端口
|
|
||||||
DATABASE_NAME=MegBot # 数据库名称
|
|
||||||
MONGODB_USERNAME="" # MongoDB用户名(可选)
|
|
||||||
MONGODB_PASSWORD="" # MongoDB密码(可选)
|
|
||||||
MONGODB_AUTH_SOURCE="" # MongoDB认证源(可选)
|
|
||||||
|
|
||||||
#api配置项,建议siliconflow必填,识图需要这个
|
|
||||||
SILICONFLOW_KEY=
|
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
|
||||||
DEEP_SEEK_KEY=
|
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### bot_config.toml 文件配置说明
|
|
||||||
```toml
|
|
||||||
# 数据库设置
|
|
||||||
[database]
|
|
||||||
host = "127.0.0.1" # MongoDB主机地址
|
|
||||||
port = 27017 # MongoDB端口
|
|
||||||
name = "MegBot" # 数据库名称
|
|
||||||
|
|
||||||
# 机器人基本设置
|
|
||||||
[bot]
|
|
||||||
qq = # 你的机器人QQ号(必填)
|
|
||||||
nickname = "麦麦" # 机器人昵称
|
|
||||||
|
|
||||||
# 消息处理设置
|
|
||||||
[message]
|
|
||||||
min_text_length = 2 # 最小响应文本长度
|
|
||||||
max_context_size = 15 # 上下文最大保存数量
|
|
||||||
emoji_chance = 0.2 # 表情包使用概率
|
|
||||||
|
|
||||||
# 表情包功能设置
|
|
||||||
[emoji]
|
|
||||||
check_interval = 120 # 表情检查间隔(秒)
|
|
||||||
register_interval = 10 # 表情注册间隔(秒)
|
|
||||||
|
|
||||||
# CQ码设置
|
|
||||||
[cq_code]
|
|
||||||
enable_pic_translate = false # 是否启用图片转换(无效)
|
|
||||||
|
|
||||||
# 响应设置
|
|
||||||
[response]
|
|
||||||
api_using = "siliconflow" # 回复使用的API(siliconflow/deepseek)
|
|
||||||
model_r1_probability = 0.8 # R1模型使用概率
|
|
||||||
model_v3_probability = 0.1 # V3模型使用概率
|
|
||||||
model_r1_distill_probability = 0.1 # R1蒸馏模型使用概率(对deepseek api 无效)
|
|
||||||
|
|
||||||
# 其他设置
|
|
||||||
[others]
|
|
||||||
enable_advance_output = false # 是否启用详细日志输出
|
|
||||||
|
|
||||||
# 群组设置
|
|
||||||
[groups]
|
|
||||||
talk_allowed = [ # 允许回复的群号列表
|
|
||||||
# 在这里添加群号,逗号隔开
|
|
||||||
]
|
|
||||||
|
|
||||||
talk_frequency_down = [ # 降低回复频率的群号列表
|
|
||||||
# 在这里添加群号,逗号隔开
|
|
||||||
]
|
|
||||||
|
|
||||||
ban_user_id = [ # 禁止回复的用户QQ号列表
|
|
||||||
# 在这里添加QQ号,逗号隔开
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
5. **运行麦麦**
|
|
||||||
在含有bot.py程序的目录下运行(如果使用了虚拟环境需要先进入虚拟环境)
|
|
||||||
```bash
|
|
||||||
nb run
|
|
||||||
```
|
|
||||||
6. **运行其他组件**
|
|
||||||
run_thingking.bat 可以启动可视化的推理界面(未完善)和消息队列及其他信息预览(WIP)
|
|
||||||
knowledge.bat可以将/data/raw_info下的文本文档载入到数据库(未启动)
|
|
||||||
|
|
||||||
## 🎯 功能介绍
|
## 🎯 功能介绍
|
||||||
|
|
||||||
@@ -204,6 +73,19 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|||||||
- 幽默和meme功能:WIP的WIP
|
- 幽默和meme功能:WIP的WIP
|
||||||
- 让麦麦玩mc:WIP的WIP的WIP
|
- 让麦麦玩mc:WIP的WIP的WIP
|
||||||
|
|
||||||
|
## 开发计划TODO:LIST
|
||||||
|
|
||||||
|
- 兼容gif的解析和保存
|
||||||
|
- 小程序转发链接解析
|
||||||
|
- 对思考链长度限制
|
||||||
|
- 修复已知bug
|
||||||
|
- 完善文档
|
||||||
|
- 修复转发
|
||||||
|
- config自动生成和检测
|
||||||
|
- log别用print
|
||||||
|
- 给发送消息写专门的类
|
||||||
|
- 改进表情包发送逻辑
|
||||||
|
|
||||||
## 📌 注意事项
|
## 📌 注意事项
|
||||||
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
||||||
|
|
||||||
@@ -218,3 +100,7 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|||||||
感谢各位大佬!
|
感谢各位大佬!
|
||||||
|
|
||||||
[](https://github.com/SengokuCola/MaiMBot/graphs/contributors)
|
[](https://github.com/SengokuCola/MaiMBot/graphs/contributors)
|
||||||
|
|
||||||
|
|
||||||
|
## Stargazers over time
|
||||||
|
[](https://starchart.cc/SengokuCola/MaiMBot)
|
||||||
13
bot.py
13
bot.py
@@ -15,25 +15,22 @@ for i, char in enumerate(text):
|
|||||||
print(rainbow_text)
|
print(rainbow_text)
|
||||||
'''彩蛋'''
|
'''彩蛋'''
|
||||||
|
|
||||||
# 首先加载基础环境变量
|
# 首先加载基础环境变量.env
|
||||||
if os.path.exists(".env"):
|
if os.path.exists(".env"):
|
||||||
load_dotenv(".env")
|
load_dotenv(".env")
|
||||||
logger.success("成功加载基础环境变量配置")
|
logger.success("成功加载基础环境变量配置")
|
||||||
else:
|
else:
|
||||||
logger.error("基础环境变量配置文件 .env 不存在")
|
logger.error("基础环境变量配置文件 .env 不存在")
|
||||||
exit(1)
|
exit(1)
|
||||||
# 根据 ENVIRONMENT 加载对应的环境配置
|
|
||||||
env = os.getenv("ENVIRONMENT")
|
if os.path.exists(".env.dev"):
|
||||||
env_file = f".env.{env}"
|
|
||||||
|
|
||||||
if env_file == ".env.dev" and os.path.exists(env_file):
|
|
||||||
logger.success("加载开发环境变量配置")
|
logger.success("加载开发环境变量配置")
|
||||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
elif os.path.exists(".env.prod"):
|
elif os.path.exists(".env.prod"):
|
||||||
logger.success("加载环境变量配置")
|
logger.success("加载环境变量配置")
|
||||||
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
else:
|
else:
|
||||||
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
logger.error(f".env对应的环境配置文件不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# 获取所有环境变量
|
# 获取所有环境变量
|
||||||
|
|||||||
12012
char_frequency.json
Normal file
12012
char_frequency.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,59 +3,69 @@ qq = 123
|
|||||||
nickname = "麦麦"
|
nickname = "麦麦"
|
||||||
|
|
||||||
[message]
|
[message]
|
||||||
min_text_length = 2
|
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
|
||||||
max_context_size = 15
|
max_context_size = 15 # 麦麦获得的上下文数量,超出数量后自动丢弃
|
||||||
emoji_chance = 0.2
|
emoji_chance = 0.2 # 麦麦使用表情包的概率
|
||||||
|
|
||||||
[emoji]
|
[emoji]
|
||||||
check_interval = 120
|
check_interval = 120 # 检查表情包的时间间隔
|
||||||
register_interval = 10
|
register_interval = 10 # 注册表情包的时间间隔
|
||||||
|
|
||||||
[cq_code]
|
[cq_code]
|
||||||
enable_pic_translate = false
|
enable_pic_translate = false
|
||||||
|
|
||||||
[response]
|
[response]
|
||||||
api_using = "siliconflow"
|
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
||||||
api_paid = true
|
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
||||||
model_r1_probability = 0.8
|
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
||||||
model_v3_probability = 0.1
|
|
||||||
model_r1_distill_probability = 0.1
|
|
||||||
|
|
||||||
[memory]
|
[memory]
|
||||||
build_memory_interval = 300
|
build_memory_interval = 300 # 记忆构建间隔 单位秒
|
||||||
|
|
||||||
[others]
|
[others]
|
||||||
enable_advance_output = true
|
enable_advance_output = true # 是否启用高级输出
|
||||||
|
enable_kuuki_read = true # 是否启用读空气功能
|
||||||
|
|
||||||
[groups]
|
[groups]
|
||||||
talk_allowed = [
|
talk_allowed = [
|
||||||
123,
|
123,
|
||||||
123,
|
123,
|
||||||
]
|
] #可以回复消息的群
|
||||||
talk_frequency_down = []
|
talk_frequency_down = [] #降低回复频率的群
|
||||||
ban_user_id = []
|
ban_user_id = [] #禁止回复消息的QQ号
|
||||||
|
|
||||||
[model.llm_reasoning]
|
|
||||||
|
#V3
|
||||||
|
#name = "deepseek-chat"
|
||||||
|
#base_url = "DEEP_SEEK_BASE_URL"
|
||||||
|
#key = "DEEP_SEEK_KEY"
|
||||||
|
|
||||||
|
#R1
|
||||||
|
#name = "deepseek-reasoner"
|
||||||
|
#base_url = "DEEP_SEEK_BASE_URL"
|
||||||
|
#key = "DEEP_SEEK_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning] #R1
|
||||||
name = "Pro/deepseek-ai/DeepSeek-R1"
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
key = "SILICONFLOW_KEY"
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
[model.llm_reasoning_minor]
|
[model.llm_reasoning_minor] #R1蒸馏
|
||||||
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
key = "SILICONFLOW_KEY"
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
[model.llm_normal]
|
[model.llm_normal] #V3
|
||||||
name = "Pro/deepseek-ai/DeepSeek-V3"
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
key = "SILICONFLOW_KEY"
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
[model.llm_normal_minor]
|
[model.llm_normal_minor] #V2.5
|
||||||
name = "deepseek-ai/DeepSeek-V2.5"
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
key = "SILICONFLOW_KEY"
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
[model.vlm]
|
[model.vlm] #图像识别
|
||||||
name = "deepseek-ai/deepseek-vl2"
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
key = "SILICONFLOW_KEY"
|
key = "SILICONFLOW_KEY"
|
||||||
|
|||||||
145
docs/installation.md
Normal file
145
docs/installation.md
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# 🔧 安装与配置指南
|
||||||
|
|
||||||
|
## 部署方式
|
||||||
|
|
||||||
|
如果你不知道Docker是什么,建议寻找相关教程或使用手动部署
|
||||||
|
|
||||||
|
### 🐳 Docker部署(推荐,但不一定是最新)
|
||||||
|
|
||||||
|
1. 获取配置文件:
|
||||||
|
```bash
|
||||||
|
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 启动服务:
|
||||||
|
```bash
|
||||||
|
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 修改配置后重启:
|
||||||
|
```bash
|
||||||
|
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
||||||
|
```
|
||||||
|
|
||||||
|
### 📦 手动部署
|
||||||
|
|
||||||
|
1. **环境准备**
|
||||||
|
```bash
|
||||||
|
# 创建虚拟环境(推荐)
|
||||||
|
python -m venv venv
|
||||||
|
venv\\Scripts\\activate # Windows
|
||||||
|
# 安装依赖
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **配置MongoDB**
|
||||||
|
- 安装并启动MongoDB服务
|
||||||
|
- 默认连接本地27017端口
|
||||||
|
|
||||||
|
3. **配置NapCat**
|
||||||
|
- 安装并登录NapCat
|
||||||
|
- 添加反向WS:`ws://localhost:8080/onebot/v11/ws`
|
||||||
|
|
||||||
|
4. **配置文件设置**
|
||||||
|
- 修改环境配置文件:`.env.prod`
|
||||||
|
- 修改机器人配置文件:`bot_config.toml`
|
||||||
|
|
||||||
|
5. **启动麦麦机器人**
|
||||||
|
- 打开命令行,cd到对应路径
|
||||||
|
```bash
|
||||||
|
nb run
|
||||||
|
```
|
||||||
|
|
||||||
|
6. **其他组件**
|
||||||
|
- `run_thingking.bat`: 启动可视化推理界面(未完善)
|
||||||
|
|
||||||
|
- ~~`knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库~~
|
||||||
|
- 直接运行 knowledge.py生成知识库
|
||||||
|
|
||||||
|
## ⚙️ 配置说明
|
||||||
|
|
||||||
|
### 环境配置 (.env.prod)
|
||||||
|
```ini
|
||||||
|
# API配置,你可以在这里定义你的密钥和base_url
|
||||||
|
# 你可以选择定义其他服务商提供的KEY,完全可以自定义
|
||||||
|
SILICONFLOW_KEY=your_key
|
||||||
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
|
DEEP_SEEK_KEY=your_key
|
||||||
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
# 服务配置,如果你不知道这是什么,保持默认
|
||||||
|
HOST=127.0.0.1
|
||||||
|
PORT=8080
|
||||||
|
|
||||||
|
# 数据库配置,如果你不知道这是什么,保持默认
|
||||||
|
MONGODB_HOST=127.0.0.1
|
||||||
|
MONGODB_PORT=27017
|
||||||
|
DATABASE_NAME=MegBot
|
||||||
|
```
|
||||||
|
|
||||||
|
### 机器人配置 (bot_config.toml)
|
||||||
|
```toml
|
||||||
|
[bot]
|
||||||
|
qq = "你的机器人QQ号"
|
||||||
|
nickname = "麦麦"
|
||||||
|
|
||||||
|
[message]
|
||||||
|
min_text_length = 2
|
||||||
|
max_context_size = 15
|
||||||
|
emoji_chance = 0.2
|
||||||
|
|
||||||
|
[emoji]
|
||||||
|
check_interval = 120
|
||||||
|
register_interval = 10
|
||||||
|
|
||||||
|
[cq_code]
|
||||||
|
enable_pic_translate = false
|
||||||
|
|
||||||
|
[response]
|
||||||
|
#现已移除deepseek或硅基流动选项,可以直接切换分别配置任意模型
|
||||||
|
model_r1_probability = 0.8 #推理模型权重
|
||||||
|
model_v3_probability = 0.1 #非推理模型权重
|
||||||
|
model_r1_distill_probability = 0.1
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300
|
||||||
|
|
||||||
|
[others]
|
||||||
|
enable_advance_output = true # 是否启用详细日志输出
|
||||||
|
|
||||||
|
[groups]
|
||||||
|
talk_allowed = [] # 允许回复的群号列表
|
||||||
|
talk_frequency_down = [] # 降低回复频率的群号列表
|
||||||
|
ban_user_id = [] # 禁止回复的用户QQ号列表
|
||||||
|
|
||||||
|
[model.llm_reasoning]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.vlm]
|
||||||
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
## ⚠️ 注意事项
|
||||||
|
|
||||||
|
- 目前部署方案仍在测试中,可能存在未知问题
|
||||||
|
- 配置文件中的API密钥请妥善保管,不要泄露
|
||||||
|
- 建议先在测试环境中运行,确认无误后再部署到生产环境
|
||||||
BIN
docs/video.png
Normal file
BIN
docs/video.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@@ -1,3 +1,4 @@
|
|||||||
|
chcp 65001
|
||||||
call conda activate niuniu
|
call conda activate niuniu
|
||||||
cd .
|
cd .
|
||||||
|
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ driver = get_driver()
|
|||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source= config.mongodb_auth_source
|
auth_source= config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,13 @@ class ChatBot:
|
|||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||||
|
|
||||||
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||||
|
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
|
||||||
|
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
|
||||||
|
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
|
||||||
|
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
|
||||||
|
topic = topic3
|
||||||
|
|
||||||
all_num = 0
|
all_num = 0
|
||||||
interested_num = 0
|
interested_num = 0
|
||||||
@@ -166,7 +171,6 @@ class ChatBot:
|
|||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
message_based_id=event.message_id,
|
|
||||||
raw_message=msg,
|
raw_message=msg,
|
||||||
plain_text=msg,
|
plain_text=msg,
|
||||||
processed_plain_text=msg,
|
processed_plain_text=msg,
|
||||||
|
|||||||
@@ -116,6 +116,9 @@ class BotConfig:
|
|||||||
|
|
||||||
if "vlm" in model_config:
|
if "vlm" in model_config:
|
||||||
config.vlm = model_config["vlm"]
|
config.vlm = model_config["vlm"]
|
||||||
|
|
||||||
|
if "embedding" in model_config:
|
||||||
|
config.embedding = model_config["embedding"]
|
||||||
|
|
||||||
# 消息配置
|
# 消息配置
|
||||||
if "message" in toml_dict:
|
if "message" in toml_dict:
|
||||||
@@ -138,7 +141,7 @@ class BotConfig:
|
|||||||
if "others" in toml_dict:
|
if "others" in toml_dict:
|
||||||
others_config = toml_dict["others"]
|
others_config = toml_dict["others"]
|
||||||
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
||||||
|
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
|
||||||
|
|
||||||
logger.success(f"成功加载配置文件: {config_path}")
|
logger.success(f"成功加载配置文件: {config_path}")
|
||||||
|
|
||||||
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
|||||||
if not os.path.exists(bot_config_path):
|
if not os.path.exists(bot_config_path):
|
||||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||||
logger.info("使用默认配置文件")
|
logger.info("使用bot配置文件")
|
||||||
else:
|
else:
|
||||||
logger.info("已找到开发环境配置文件")
|
logger.info("已找到开发bot配置文件")
|
||||||
|
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMConfig:
|
|
||||||
"""机器人配置类"""
|
|
||||||
# 基础配置
|
|
||||||
SILICONFLOW_API_KEY: str = None
|
|
||||||
SILICONFLOW_BASE_URL: str = None
|
|
||||||
DEEP_SEEK_API_KEY: str = None
|
|
||||||
DEEP_SEEK_BASE_URL: str = None
|
|
||||||
|
|
||||||
llm_config = LLMConfig()
|
|
||||||
config = get_driver().config
|
|
||||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
|
||||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
|
||||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
|
||||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
|
||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# logger.remove()
|
# logger.remove()
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from ...common.database import Database
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
import urllib3
|
import urllib3
|
||||||
from .utils_user import get_user_nickname,get_user_cardname
|
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
|
||||||
from .utils_cq import parse_cq_code
|
from .utils_cq import parse_cq_code
|
||||||
from .cq_code import cq_code_tool,CQCode
|
from .cq_code import cq_code_tool,CQCode
|
||||||
|
|
||||||
@@ -21,50 +21,47 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||||||
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message:
|
class Message:
|
||||||
"""消息数据类"""
|
"""消息数据类"""
|
||||||
|
message_id: int = None
|
||||||
|
time: float = None
|
||||||
|
|
||||||
group_id: int = None
|
group_id: int = None
|
||||||
|
group_name: str = None # 群名称
|
||||||
|
|
||||||
user_id: int = None
|
user_id: int = None
|
||||||
user_nickname: str = None # 用户昵称
|
user_nickname: str = None # 用户昵称
|
||||||
user_cardname: str=None # 用户群昵称
|
user_cardname: str=None # 用户群昵称
|
||||||
group_name: str = None # 群名称
|
|
||||||
|
|
||||||
message_id: int = None
|
raw_message: str = None # 原始消息,包含未解析的cq码
|
||||||
raw_message: str = None
|
plain_text: str = None # 纯文本
|
||||||
plain_text: str = None
|
|
||||||
|
|
||||||
message_based_id: int = None
|
|
||||||
reply_message: Dict = None # 存储回复消息
|
|
||||||
|
|
||||||
message_segments: List[Dict] = None # 存储解析后的消息片段
|
message_segments: List[Dict] = None # 存储解析后的消息片段
|
||||||
processed_plain_text: str = None # 用于存储处理后的plain_text
|
processed_plain_text: str = None # 用于存储处理后的plain_text
|
||||||
detailed_plain_text: str = None # 用于存储详细可读文本
|
detailed_plain_text: str = None # 用于存储详细可读文本
|
||||||
|
|
||||||
time: float = None
|
reply_message: Dict = None # 存储 回复的 源消息
|
||||||
|
|
||||||
is_emoji: bool = False # 是否是表情包
|
is_emoji: bool = False # 是否是表情包
|
||||||
has_emoji: bool = False # 是否包含表情包
|
has_emoji: bool = False # 是否包含表情包
|
||||||
|
|
||||||
translate_cq: bool = True # 是否翻译cq码
|
translate_cq: bool = True # 是否翻译cq码
|
||||||
|
|
||||||
|
|
||||||
reply_benefits: float = 0.0
|
|
||||||
|
|
||||||
type: str = 'received' # 消息类型,可以是received或者send
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.time is None:
|
if self.time is None:
|
||||||
self.time = int(time.time())
|
self.time = int(time.time())
|
||||||
|
|
||||||
|
if not self.group_name:
|
||||||
|
self.group_name = get_groupname(self.group_id)
|
||||||
|
|
||||||
if not self.user_nickname:
|
if not self.user_nickname:
|
||||||
self.user_nickname = get_user_nickname(self.user_id)
|
self.user_nickname = get_user_nickname(self.user_id)
|
||||||
|
|
||||||
if not self.user_cardname:
|
if not self.user_cardname:
|
||||||
self.user_cardname=get_user_cardname(self.user_id)
|
self.user_cardname=get_user_cardname(self.user_id)
|
||||||
|
|
||||||
if not self.group_name:
|
|
||||||
self.group_name = self.get_groupname(self.group_id)
|
|
||||||
|
|
||||||
if not self.processed_plain_text:
|
if not self.processed_plain_text:
|
||||||
if self.raw_message:
|
if self.raw_message:
|
||||||
self.message_segments = self.parse_message_segments(str(self.raw_message))
|
self.message_segments = self.parse_message_segments(str(self.raw_message))
|
||||||
@@ -244,6 +241,38 @@ class MessageSet:
|
|||||||
return len(self.messages)
|
return len(self.messages)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message_Sending(Message):
|
||||||
|
"""发送消息数据类,继承自Message类"""
|
||||||
|
|
||||||
|
priority: int = 0 # 发送优先级,数字越大优先级越高
|
||||||
|
wait_until: float = None # 等待发送的时间戳
|
||||||
|
continue_thinking: bool = False # 是否继续思考
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.wait_until is None:
|
||||||
|
self.wait_until = self.time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_send(self) -> bool:
|
||||||
|
"""检查是否可以发送消息"""
|
||||||
|
return time.time() >= self.wait_until
|
||||||
|
|
||||||
|
def set_wait_time(self, seconds: float) -> None:
|
||||||
|
"""设置等待发送时间"""
|
||||||
|
self.wait_until = time.time() + seconds
|
||||||
|
|
||||||
|
def set_priority(self, priority: int) -> None:
|
||||||
|
"""设置发送优先级"""
|
||||||
|
self.priority = priority
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
"""重写小于比较,用于优先级排序"""
|
||||||
|
if not isinstance(other, Message_Sending):
|
||||||
|
return NotImplemented
|
||||||
|
return (self.priority, -self.wait_until) < (other.priority, -other.wait_until)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ class MessageSendControl:
|
|||||||
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
|
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
|
||||||
cost_time = round(time.time(), 2) - message.time
|
cost_time = round(time.time(), 2) - message.time
|
||||||
if cost_time > 40:
|
if cost_time > 40:
|
||||||
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
|
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
await self._current_bot.send_group_msg(
|
await self._current_bot.send_group_msg(
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
|
|||||||
0
src/plugins/chat/message_sender.py
Normal file
0
src/plugins/chat/message_sender.py
Normal file
@@ -127,15 +127,15 @@ class MessageStream:
|
|||||||
# 从数据库中查询最近的消息
|
# 从数据库中查询最近的消息
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": self.group_id},
|
{"group_id": self.group_id},
|
||||||
{
|
# {
|
||||||
"time": 1,
|
# "time": 1,
|
||||||
"user_id": 1,
|
# "user_id": 1,
|
||||||
"user_nickname": 1,
|
# "user_nickname": 1,
|
||||||
# "user_cardname": 1,
|
# # "user_cardname": 1,
|
||||||
"message_id": 1,
|
# "message_id": 1,
|
||||||
"raw_message": 1,
|
# "raw_message": 1,
|
||||||
"processed_text": 1
|
# "processed_text": 1
|
||||||
}
|
# }
|
||||||
).sort("time", -1).limit(count))
|
).sort("time", -1).limit(count))
|
||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
@@ -145,17 +145,21 @@ class MessageStream:
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
messages = []
|
messages = []
|
||||||
for msg_data in recent_messages:
|
for msg_data in recent_messages:
|
||||||
msg = Message(
|
try:
|
||||||
time=msg_data["time"],
|
msg = Message(
|
||||||
user_id=msg_data["user_id"],
|
time=msg_data["time"],
|
||||||
user_nickname=msg_data.get("user_nickname", ""),
|
user_id=msg_data["user_id"],
|
||||||
user_cardname=msg_data.get("user_cardname", ""),
|
user_nickname=msg_data.get("user_nickname", ""),
|
||||||
message_id=msg_data["message_id"],
|
user_cardname=msg_data.get("user_cardname", ""),
|
||||||
raw_message=msg_data["raw_message"],
|
message_id=msg_data["message_id"],
|
||||||
processed_plain_text=msg_data.get("processed_text", ""),
|
raw_message=msg_data["raw_message"],
|
||||||
group_id=self.group_id
|
processed_plain_text=msg_data.get("processed_text", ""),
|
||||||
)
|
group_id=self.group_id
|
||||||
messages.append(msg)
|
)
|
||||||
|
messages.append(msg)
|
||||||
|
except KeyError:
|
||||||
|
print("[WARNING] 数据库中存在无效的消息")
|
||||||
|
continue
|
||||||
|
|
||||||
return list(reversed(messages)) # 返回按时间正序的消息
|
return list(reversed(messages)) # 返回按时间正序的消息
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ class PromptBuilder:
|
|||||||
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
||||||
promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|||||||
14
src/plugins/chat/thinking_idea.py
Normal file
14
src/plugins/chat/thinking_idea.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#Broca's Area
|
||||||
|
# 功能:语言产生、语法处理和言语运动控制。
|
||||||
|
# 损伤后果:布洛卡失语症(表达困难,但理解保留)。
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Thinking_Idea:
|
||||||
|
def __init__(self, message_id: str):
|
||||||
|
self.messages = [] # 消息列表集合
|
||||||
|
self.current_thoughts = [] # 当前思考内容列表
|
||||||
|
self.time = time.time() # 创建时间
|
||||||
|
self.id = str(int(time.time() * 1000)) # 使用时间戳生成唯一标识ID
|
||||||
|
|
||||||
@@ -4,6 +4,8 @@ from .message import Message
|
|||||||
import jieba
|
import jieba
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from snownlp import SnowNLP
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -11,12 +13,10 @@ config = driver.config
|
|||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client = OpenAI(
|
self.llm_client = LLM_request(model=global_config.llm_normal)
|
||||||
api_key=config.siliconflow_key, base_url=config.siliconflow_base_url
|
|
||||||
)
|
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
|
||||||
|
"""识别消息主题,返回主题列表"""
|
||||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
|
||||||
"""识别消息主题"""
|
|
||||||
|
|
||||||
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:\
|
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:\
|
||||||
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。\
|
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。\
|
||||||
@@ -24,36 +24,20 @@ class TopicIdentifier:
|
|||||||
3. 这里是
|
3. 这里是
|
||||||
消息内容:{text}"""
|
消息内容:{text}"""
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
# 使用 LLM_request 类进行请求
|
||||||
model=global_config.SILICONFLOW_MODEL_V3,
|
topic, _ = await self.llm_client.generate_response(prompt)
|
||||||
messages=[{"role": "user", "content": prompt}],
|
|
||||||
temperature=0.8,
|
if not topic:
|
||||||
max_tokens=10,
|
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
|
||||||
)
|
|
||||||
|
|
||||||
if not response or not response.choices:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
|
# 直接在这里处理主题解析
|
||||||
topic = (
|
|
||||||
response.choices[0].message.content.strip()
|
|
||||||
if response.choices[0].message.content
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if topic == "无主题":
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# print(f"[主题分析结果]{text[:20]}... : {topic}")
|
|
||||||
split_topic = self.parse_topic(topic)
|
|
||||||
return split_topic
|
|
||||||
|
|
||||||
def parse_topic(self, topic: str) -> List[str]:
|
|
||||||
"""解析主题,返回主题列表"""
|
|
||||||
if not topic or topic == "无主题":
|
if not topic or topic == "无主题":
|
||||||
return []
|
return None
|
||||||
return [t.strip() for t in topic.split(",") if t.strip()]
|
|
||||||
|
# 解析主题字符串为列表
|
||||||
|
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
|
||||||
|
return topic_list if topic_list else None
|
||||||
|
|
||||||
def identify_topic_jieba(self, text: str) -> Optional[str]:
|
def identify_topic_jieba(self, text: str) -> Optional[str]:
|
||||||
"""使用jieba识别主题"""
|
"""使用jieba识别主题"""
|
||||||
@@ -239,33 +223,12 @@ class TopicIdentifier:
|
|||||||
filtered_words = []
|
filtered_words = []
|
||||||
for word in words:
|
for word in words:
|
||||||
if word not in stop_words and not word.strip() in {
|
if word not in stop_words and not word.strip() in {
|
||||||
"。",
|
'。', ',', '、', ':', ';', '!', '?', '"', '"', ''', ''',
|
||||||
",",
|
'(', ')', '【', '】', '《', '》', '…', '—', '·', '、', '~',
|
||||||
"、",
|
'~', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
|
||||||
":",
|
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
|
||||||
";",
|
';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷', '≠',
|
||||||
"!",
|
'≈', '∈', '∉', '⊆', '⊇', '⊂', '⊃', '∪', '∩', '∧', '∨'
|
||||||
"?",
|
|
||||||
'"',
|
|
||||||
'"',
|
|
||||||
""", """,
|
|
||||||
"(",
|
|
||||||
")",
|
|
||||||
"【",
|
|
||||||
"】",
|
|
||||||
"《",
|
|
||||||
"》",
|
|
||||||
"…",
|
|
||||||
"—",
|
|
||||||
"·",
|
|
||||||
"、",
|
|
||||||
"~",
|
|
||||||
"~",
|
|
||||||
"+",
|
|
||||||
"=",
|
|
||||||
"-",
|
|
||||||
"[",
|
|
||||||
"]",
|
|
||||||
}:
|
}:
|
||||||
filtered_words.append(word)
|
filtered_words.append(word)
|
||||||
|
|
||||||
@@ -280,5 +243,25 @@ class TopicIdentifier:
|
|||||||
|
|
||||||
return top_words if top_words else None
|
return top_words if top_words else None
|
||||||
|
|
||||||
|
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
|
||||||
|
"""使用 SnowNLP 进行主题识别
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): 需要识别主题的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None
|
||||||
|
"""
|
||||||
|
if not text or len(text.strip()) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = SnowNLP(text)
|
||||||
|
# 提取前3个关键词作为主题
|
||||||
|
keywords = s.keywords(3)
|
||||||
|
return keywords if keywords else None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
topic_identifier = TopicIdentifier()
|
topic_identifier = TopicIdentifier()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Dict
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
import math
|
import math
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_embedding(text):
|
def get_embedding(text):
|
||||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
"""获取文本的embedding向量"""
|
||||||
payload = {
|
llm = LLM_request(model=global_config.embedding)
|
||||||
"model": "BAAI/bge-m3",
|
return llm.get_embedding_sync(text)
|
||||||
"input": text,
|
|
||||||
"encoding_format": "float"
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.request("POST", url, json=payload, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"API请求失败: {response.status_code}")
|
|
||||||
print(f"错误信息: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response.json()['data'][0]['embedding']
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
@@ -142,14 +127,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
# 从数据库获取最近消息
|
# 从数据库获取最近消息
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"group_id": group_id},
|
||||||
{
|
# {
|
||||||
"time": 1,
|
# "time": 1,
|
||||||
"user_id": 1,
|
# "user_id": 1,
|
||||||
"user_nickname": 1,
|
# "user_nickname": 1,
|
||||||
"message_id": 1,
|
# "message_id": 1,
|
||||||
"raw_message": 1,
|
# "raw_message": 1,
|
||||||
"processed_text": 1
|
# "processed_text": 1
|
||||||
}
|
# }
|
||||||
).sort("time", -1).limit(limit))
|
).sort("time", -1).limit(limit))
|
||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
@@ -159,16 +144,20 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
message_objects = []
|
message_objects = []
|
||||||
for msg_data in recent_messages:
|
for msg_data in recent_messages:
|
||||||
msg = Message(
|
try:
|
||||||
time=msg_data["time"],
|
msg = Message(
|
||||||
user_id=msg_data["user_id"],
|
time=msg_data["time"],
|
||||||
user_nickname=msg_data.get("user_nickname", ""),
|
user_id=msg_data["user_id"],
|
||||||
message_id=msg_data["message_id"],
|
user_nickname=msg_data.get("user_nickname", ""),
|
||||||
raw_message=msg_data["raw_message"],
|
message_id=msg_data["message_id"],
|
||||||
processed_plain_text=msg_data.get("processed_text", ""),
|
raw_message=msg_data["raw_message"],
|
||||||
group_id=group_id
|
processed_plain_text=msg_data.get("processed_text", ""),
|
||||||
)
|
group_id=group_id
|
||||||
message_objects.append(msg)
|
)
|
||||||
|
message_objects.append(msg)
|
||||||
|
except KeyError:
|
||||||
|
print("[WARNING] 数据库中存在无效的消息")
|
||||||
|
continue
|
||||||
|
|
||||||
# 按时间正序排列
|
# 按时间正序排列
|
||||||
message_objects.reverse()
|
message_objects.reverse()
|
||||||
@@ -181,7 +170,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
|
|||||||
"time": 1, # 返回时间字段
|
"time": 1, # 返回时间字段
|
||||||
"user_id": 1, # 返回用户ID字段
|
"user_id": 1, # 返回用户ID字段
|
||||||
"user_nickname": 1, # 返回用户昵称字段
|
"user_nickname": 1, # 返回用户昵称字段
|
||||||
"user_cardname": 1, #返回用户群昵称
|
|
||||||
"message_id": 1, # 返回消息ID字段
|
"message_id": 1, # 返回消息ID字段
|
||||||
"detailed_plain_text": 1 # 返回处理后的文本字段
|
"detailed_plain_text": 1 # 返回处理后的文本字段
|
||||||
}
|
}
|
||||||
@@ -193,6 +181,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
|
|||||||
message_detailed_plain_text = ''
|
message_detailed_plain_text = ''
|
||||||
message_detailed_plain_text_list = []
|
message_detailed_plain_text_list = []
|
||||||
|
|
||||||
|
# 反转消息列表,使最新的消息在最后
|
||||||
|
recent_messages.reverse()
|
||||||
|
|
||||||
if combine:
|
if combine:
|
||||||
for msg_db_data in recent_messages:
|
for msg_db_data in recent_messages:
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ def get_user_nickname(user_id: int) -> str:
|
|||||||
return global_config.BOT_NICKNAME
|
return global_config.BOT_NICKNAME
|
||||||
# print(user_id)
|
# print(user_id)
|
||||||
return relationship_manager.get_name(user_id)
|
return relationship_manager.get_name(user_id)
|
||||||
|
|
||||||
def get_user_cardname(user_id: int) -> str:
|
def get_user_cardname(user_id: int) -> str:
|
||||||
if int(user_id) == int(global_config.BOT_QQ):
|
if int(user_id) == int(global_config.BOT_QQ):
|
||||||
return global_config.BOT_NICKNAME
|
return global_config.BOT_NICKNAME
|
||||||
# print(user_id)
|
# print(user_id)
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
def get_groupname(group_id: int) -> str:
|
||||||
|
return f"群{group_id}"
|
||||||
@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
|
|||||||
|
|
||||||
# 直接配置数据库连接信息
|
# 直接配置数据库连接信息
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
class KnowledgeLibrary:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import jieba
|
import jieba
|
||||||
from llm_module import LLMModel
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
import math
|
||||||
@@ -10,10 +9,76 @@ from collections import Counter
|
|||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
# from chat.config import global_config
|
from dotenv import load_dotenv
|
||||||
import sys
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
|
||||||
|
# 加载.env.dev文件
|
||||||
|
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||||
|
load_dotenv(env_path)
|
||||||
|
|
||||||
|
class LLMModel:
|
||||||
|
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.params = kwargs
|
||||||
|
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||||
|
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||||
|
|
||||||
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的chat/completions端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -158,12 +223,12 @@ class Memory_graph:
|
|||||||
def main():
|
def main():
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username=os.getenv("MONGODB_USERNAME", ""),
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password=os.getenv("MONGODB_PASSWORD", ""),
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
@@ -185,11 +250,14 @@ def main():
|
|||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||||
if query.lower() == '退出':
|
if query.lower() == '退出':
|
||||||
break
|
break
|
||||||
items_list = memory_graph.get_related_item(query)
|
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||||
if items_list:
|
if first_layer_items or second_layer_items:
|
||||||
# print(items_list)
|
print("\n第一层记忆:")
|
||||||
for memory_item in items_list:
|
for item in first_layer_items:
|
||||||
print(memory_item)
|
print(item)
|
||||||
|
print("\n第二层记忆:")
|
||||||
|
for item in second_layer_items:
|
||||||
|
print(item)
|
||||||
else:
|
else:
|
||||||
print("未找到相关记忆。")
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class LLMModel:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|||||||
@@ -259,12 +259,12 @@ config = driver.config
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= config.MONGODB_PORT,
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
#创建记忆图
|
#创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ driver = get_driver()
|
|||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
class LLM_request:
|
class LLM_request:
|
||||||
def __init__(self, model = global_config.llm_normal,**kwargs):
|
def __init__(self, model ,**kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = getattr(config, model["key"])
|
self.api_key = getattr(config, model["key"])
|
||||||
@@ -61,7 +61,7 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
@@ -126,7 +126,7 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
@@ -166,8 +166,8 @@ class LLM_request:
|
|||||||
# 发送请求到完整的chat/completions端点
|
# 发送请求到完整的chat/completions端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 2
|
||||||
base_wait_time = 15
|
base_wait_time = 6
|
||||||
|
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
@@ -191,9 +191,119 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""同步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
max_retries = 2
|
||||||
|
base_wait_time = 6
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
print(f"embedding请求失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""异步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
print(f"embedding请求失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ config = driver.config
|
|||||||
|
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
@@ -128,6 +128,10 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
def _time_diff(self, time1: str, time2: str) -> int:
|
def _time_diff(self, time1: str, time2: str) -> int:
|
||||||
"""计算两个时间字符串之间的分钟差"""
|
"""计算两个时间字符串之间的分钟差"""
|
||||||
|
if time1=="24:00":
|
||||||
|
time1="23:59"
|
||||||
|
if time2=="24:00":
|
||||||
|
time2="23:59"
|
||||||
t1 = datetime.datetime.strptime(time1, "%H:%M")
|
t1 = datetime.datetime.strptime(time1, "%H:%M")
|
||||||
t2 = datetime.datetime.strptime(time2, "%H:%M")
|
t2 = datetime.datetime.strptime(time2, "%H:%M")
|
||||||
diff = int((t2 - t1).total_seconds() / 60)
|
diff = int((t2 - t1).total_seconds() / 60)
|
||||||
@@ -165,4 +169,4 @@ class ScheduleGenerator:
|
|||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
# main()
|
# main()
|
||||||
|
|
||||||
bot_schedule = ScheduleGenerator()
|
bot_schedule = ScheduleGenerator()
|
||||||
|
|||||||
53
src/test/emotion_cal_snownlp.py
Normal file
53
src/test/emotion_cal_snownlp.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from snownlp import SnowNLP
|
||||||
|
|
||||||
|
def analyze_emotion_snownlp(text):
|
||||||
|
"""
|
||||||
|
使用SnowNLP进行中文情感分析
|
||||||
|
:param text: 输入文本
|
||||||
|
:return: 情感得分(0-1之间,越接近1越积极)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
s = SnowNLP(text)
|
||||||
|
sentiment_score = s.sentiments
|
||||||
|
|
||||||
|
# 获取文本的关键词
|
||||||
|
keywords = s.keywords(3)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'sentiment_score': sentiment_score,
|
||||||
|
'keywords': keywords,
|
||||||
|
'summary': s.summary(1) # 生成文本摘要
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"分析过程中出现错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_emotion_description_snownlp(score):
|
||||||
|
"""
|
||||||
|
将情感得分转换为描述性文字
|
||||||
|
"""
|
||||||
|
if score is None:
|
||||||
|
return "无法分析情感"
|
||||||
|
|
||||||
|
if score > 0.8:
|
||||||
|
return "非常积极"
|
||||||
|
elif score > 0.6:
|
||||||
|
return "较为积极"
|
||||||
|
elif score > 0.4:
|
||||||
|
return "中性偏积极"
|
||||||
|
elif score > 0.2:
|
||||||
|
return "中性偏消极"
|
||||||
|
else:
|
||||||
|
return "消极"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试样例
|
||||||
|
test_text = "我们学校有免费的gpt4用"
|
||||||
|
result = analyze_emotion_snownlp(test_text)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
print(f"测试文本: {test_text}")
|
||||||
|
print(f"情感得分: {result['sentiment_score']:.2f}")
|
||||||
|
print(f"情感倾向: {get_emotion_description_snownlp(result['sentiment_score'])}")
|
||||||
|
print(f"关键词: {', '.join(result['keywords'])}")
|
||||||
|
print(f"文本摘要: {result['summary'][0]}")
|
||||||
54
src/test/snownlp_demo.py
Normal file
54
src/test/snownlp_demo.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from snownlp import SnowNLP
|
||||||
|
|
||||||
|
def demo_snownlp_features(text):
|
||||||
|
"""
|
||||||
|
展示SnowNLP的主要功能
|
||||||
|
:param text: 输入文本
|
||||||
|
"""
|
||||||
|
print(f"\n=== SnowNLP功能演示 ===")
|
||||||
|
print(f"输入文本: {text}")
|
||||||
|
|
||||||
|
# 创建SnowNLP对象
|
||||||
|
s = SnowNLP(text)
|
||||||
|
|
||||||
|
# 1. 分词
|
||||||
|
print(f"\n1. 分词结果:")
|
||||||
|
print(f" {' | '.join(s.words)}")
|
||||||
|
|
||||||
|
# 2. 情感分析
|
||||||
|
print(f"\n2. 情感分析:")
|
||||||
|
sentiment = s.sentiments
|
||||||
|
print(f" 情感得分: {sentiment:.2f}")
|
||||||
|
print(f" 情感倾向: {'积极' if sentiment > 0.5 else '消极' if sentiment < 0.5 else '中性'}")
|
||||||
|
|
||||||
|
# 3. 关键词提取
|
||||||
|
print(f"\n3. 关键词提取:")
|
||||||
|
print(f" {', '.join(s.keywords(3))}")
|
||||||
|
|
||||||
|
# 4. 词性标注
|
||||||
|
print(f"\n4. 词性标注:")
|
||||||
|
print(f" {' '.join([f'{word}/{tag}' for word, tag in s.tags])}")
|
||||||
|
|
||||||
|
# 5. 拼音转换
|
||||||
|
print(f"\n5. 拼音:")
|
||||||
|
print(f" {' '.join(s.pinyin)}")
|
||||||
|
|
||||||
|
# 6. 文本摘要
|
||||||
|
if len(text) > 100: # 只对较长文本生成摘要
|
||||||
|
print(f"\n6. 文本摘要:")
|
||||||
|
print(f" {' '.join(s.summary(3))}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试用例
|
||||||
|
test_texts = [
|
||||||
|
"这家新开的餐厅很不错,菜品种类丰富,味道可口,服务态度也很好,价格实惠,强烈推荐大家来尝试!",
|
||||||
|
"这部电影剧情混乱,演技浮夸,特效粗糙,配乐难听,完全浪费了我的时间和票价。",
|
||||||
|
"""人工智能正在改变我们的生活方式。它能够帮助我们完成复杂的计算任务,
|
||||||
|
提供个性化的服务推荐,优化交通路线,辅助医疗诊断。但同时我们也要警惕
|
||||||
|
人工智能带来的问题,比如隐私安全、就业变化等。如何正确认识和利用人工智能,
|
||||||
|
是我们每个人都需要思考的问题。"""
|
||||||
|
]
|
||||||
|
|
||||||
|
for text in test_texts:
|
||||||
|
demo_snownlp_features(text)
|
||||||
|
print("\n" + "="*50)
|
||||||
488
src/test/typo.py
Normal file
488
src/test/typo.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
错别字生成器 - 流程说明
|
||||||
|
|
||||||
|
整体替换逻辑:
|
||||||
|
1. 数据准备
|
||||||
|
- 加载字频词典:使用jieba词典计算汉字使用频率
|
||||||
|
- 创建拼音映射:建立拼音到汉字的映射关系
|
||||||
|
- 加载词频信息:从jieba词典获取词语使用频率
|
||||||
|
|
||||||
|
2. 分词处理
|
||||||
|
- 使用jieba将输入句子分词
|
||||||
|
- 区分单字词和多字词
|
||||||
|
- 保留标点符号和空格
|
||||||
|
|
||||||
|
3. 词语级别替换(针对多字词)
|
||||||
|
- 触发条件:词长>1 且 随机概率<0.3
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取词语拼音
|
||||||
|
b. 生成所有可能的同音字组合
|
||||||
|
c. 过滤条件:
|
||||||
|
- 必须是jieba词典中的有效词
|
||||||
|
- 词频必须达到原词频的10%以上
|
||||||
|
- 综合评分(词频70%+字频30%)必须达到阈值
|
||||||
|
d. 按综合评分排序,选择最合适的替换词
|
||||||
|
|
||||||
|
4. 字级别替换(针对单字词或未进行整词替换的多字词)
|
||||||
|
- 单字替换概率:0.3
|
||||||
|
- 多字词中的单字替换概率:0.3 * (0.7 ^ (词长-1))
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取字的拼音
|
||||||
|
b. 声调错误处理(20%概率)
|
||||||
|
c. 获取同音字列表
|
||||||
|
d. 过滤条件:
|
||||||
|
- 字频必须达到最小阈值
|
||||||
|
- 频率差异不能过大(指数衰减计算)
|
||||||
|
e. 按频率排序选择替换字
|
||||||
|
|
||||||
|
5. 频率控制机制
|
||||||
|
- 字频控制:使用归一化的字频(0-1000范围)
|
||||||
|
- 词频控制:使用jieba词典中的词频
|
||||||
|
- 频率差异计算:使用指数衰减函数
|
||||||
|
- 最小频率阈值:确保替换字/词不会太生僻
|
||||||
|
|
||||||
|
6. 输出信息
|
||||||
|
- 原文和错字版本的对照
|
||||||
|
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
|
||||||
|
- 替换类型说明(整词替换/声调错误/同音字替换)
|
||||||
|
- 词语分析和完整拼音
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
1. 所有替换都必须使用有意义的词语
|
||||||
|
2. 替换词的使用频率不能过低
|
||||||
|
3. 多字词优先考虑整词替换
|
||||||
|
4. 考虑声调变化的情况
|
||||||
|
5. 保持标点符号和空格不变
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pypinyin import pinyin, Style
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
import jieba
|
||||||
|
import jieba.posseg as pseg
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
def load_or_create_char_frequency():
|
||||||
|
"""
|
||||||
|
加载或创建汉字频率字典
|
||||||
|
"""
|
||||||
|
cache_file = Path("char_frequency.json")
|
||||||
|
|
||||||
|
# 如果缓存文件存在,直接加载
|
||||||
|
if cache_file.exists():
|
||||||
|
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# 使用内置的词频文件
|
||||||
|
char_freq = defaultdict(int)
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
|
||||||
|
# 读取jieba的词典文件
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
word, freq = line.strip().split()[:2]
|
||||||
|
# 对词中的每个字进行频率累加
|
||||||
|
for char in word:
|
||||||
|
if is_chinese_char(char):
|
||||||
|
char_freq[char] += int(freq)
|
||||||
|
|
||||||
|
# 归一化频率值
|
||||||
|
max_freq = max(char_freq.values())
|
||||||
|
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||||
|
|
||||||
|
# 保存到缓存文件
|
||||||
|
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return normalized_freq
|
||||||
|
|
||||||
|
# 创建拼音到汉字的映射字典
|
||||||
|
def create_pinyin_dict():
|
||||||
|
"""
|
||||||
|
创建拼音到汉字的映射字典
|
||||||
|
"""
|
||||||
|
# 常用汉字范围
|
||||||
|
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||||
|
pinyin_dict = defaultdict(list)
|
||||||
|
|
||||||
|
# 为每个汉字建立拼音映射
|
||||||
|
for char in chars:
|
||||||
|
try:
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
pinyin_dict[py].append(char)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return pinyin_dict
|
||||||
|
|
||||||
|
def is_chinese_char(char):
|
||||||
|
"""
|
||||||
|
判断是否为汉字
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return '\u4e00' <= char <= '\u9fff'
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pinyin(sentence):
|
||||||
|
"""
|
||||||
|
将中文句子拆分成单个汉字并获取其拼音
|
||||||
|
:param sentence: 输入的中文句子
|
||||||
|
:return: 每个汉字及其拼音的列表
|
||||||
|
"""
|
||||||
|
# 将句子拆分成单个字符
|
||||||
|
characters = list(sentence)
|
||||||
|
|
||||||
|
# 获取每个字符的拼音
|
||||||
|
result = []
|
||||||
|
for char in characters:
|
||||||
|
# 跳过空格和非汉字字符
|
||||||
|
if char.isspace() or not is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
# 获取拼音(数字声调)
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
result.append((char, py))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取同音字,按照使用频率排序
|
||||||
|
"""
|
||||||
|
homophones = pinyin_dict[py]
|
||||||
|
# 移除原字并过滤低频字
|
||||||
|
if char in homophones:
|
||||||
|
homophones.remove(char)
|
||||||
|
|
||||||
|
# 过滤掉低频字
|
||||||
|
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
# 按照字频排序
|
||||||
|
sorted_homophones = sorted(homophones,
|
||||||
|
key=lambda x: char_frequency.get(x, 0),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# 只返回前10个同音字,避免输出过多
|
||||||
|
return sorted_homophones[:10]
|
||||||
|
|
||||||
|
def get_similar_tone_pinyin(py):
|
||||||
|
"""
|
||||||
|
获取相似声调的拼音
|
||||||
|
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
||||||
|
处理特殊情况:
|
||||||
|
1. 轻声(如 'de5' 或 'le')
|
||||||
|
2. 非数字结尾的拼音
|
||||||
|
"""
|
||||||
|
# 检查拼音是否为空或无效
|
||||||
|
if not py or len(py) < 1:
|
||||||
|
return py
|
||||||
|
|
||||||
|
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||||
|
if not py[-1].isdigit():
|
||||||
|
# 为非数字结尾的拼音添加数字声调1
|
||||||
|
return py + '1'
|
||||||
|
|
||||||
|
base = py[:-1] # 去掉声调
|
||||||
|
tone = int(py[-1]) # 获取声调
|
||||||
|
|
||||||
|
# 处理轻声(通常用5表示)或无效声调
|
||||||
|
if tone not in [1, 2, 3, 4]:
|
||||||
|
return base + str(random.choice([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
# 正常处理声调
|
||||||
|
possible_tones = [1, 2, 3, 4]
|
||||||
|
possible_tones.remove(tone) # 移除原声调
|
||||||
|
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||||
|
return base + str(new_tone)
|
||||||
|
|
||||||
|
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
||||||
|
"""
|
||||||
|
根据频率差计算替换概率
|
||||||
|
频率差越大,概率越低
|
||||||
|
:param orig_freq: 原字频率
|
||||||
|
:param target_freq: 目标字频率
|
||||||
|
:param max_freq_diff: 最大允许的频率差
|
||||||
|
:return: 0-1之间的概率值
|
||||||
|
"""
|
||||||
|
if target_freq > orig_freq:
|
||||||
|
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||||
|
|
||||||
|
freq_diff = orig_freq - target_freq
|
||||||
|
if freq_diff > max_freq_diff:
|
||||||
|
return 0.0 # 频率差太大,不替换
|
||||||
|
|
||||||
|
# 使用指数衰减函数计算概率
|
||||||
|
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||||
|
return math.exp(-3 * freq_diff / max_freq_diff)
|
||||||
|
|
||||||
|
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
获取与给定字频率相近的同音字,可能包含声调错误
|
||||||
|
"""
|
||||||
|
homophones = []
|
||||||
|
|
||||||
|
# 有20%的概率使用错误声调
|
||||||
|
if random.random() < tone_error_rate:
|
||||||
|
wrong_tone_py = get_similar_tone_pinyin(py)
|
||||||
|
homophones.extend(pinyin_dict[wrong_tone_py])
|
||||||
|
|
||||||
|
# 添加正确声调的同音字
|
||||||
|
homophones.extend(pinyin_dict[py])
|
||||||
|
|
||||||
|
if not homophones:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取原字的频率
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||||
|
freq_diff = [(h, char_frequency.get(h, 0))
|
||||||
|
for h in homophones
|
||||||
|
if h != char and char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
if not freq_diff:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算每个候选字的替换概率
|
||||||
|
candidates_with_prob = []
|
||||||
|
for h, freq in freq_diff:
|
||||||
|
prob = calculate_replacement_probability(orig_freq, freq)
|
||||||
|
if prob > 0: # 只保留有效概率的候选字
|
||||||
|
candidates_with_prob.append((h, prob))
|
||||||
|
|
||||||
|
if not candidates_with_prob:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 根据概率排序
|
||||||
|
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 返回概率最高的几个字
|
||||||
|
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||||
|
|
||||||
|
def get_word_pinyin(word):
|
||||||
|
"""
|
||||||
|
获取词语的拼音列表
|
||||||
|
"""
|
||||||
|
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||||
|
|
||||||
|
def segment_sentence(sentence):
|
||||||
|
"""
|
||||||
|
使用jieba分词,返回词语列表
|
||||||
|
"""
|
||||||
|
return list(jieba.cut(sentence))
|
||||||
|
|
||||||
|
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取整个词的同音词,只返回高频的有意义词语
|
||||||
|
:param word: 输入词语
|
||||||
|
:param pinyin_dict: 拼音字典
|
||||||
|
:param char_frequency: 字频字典
|
||||||
|
:param min_freq: 最小频率阈值
|
||||||
|
:return: 同音词列表
|
||||||
|
"""
|
||||||
|
if len(word) == 1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取词的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
word_pinyin_str = ''.join(word_pinyin)
|
||||||
|
|
||||||
|
# 创建词语频率字典
|
||||||
|
word_freq = defaultdict(float)
|
||||||
|
|
||||||
|
# 遍历所有可能的同音字组合
|
||||||
|
candidates = []
|
||||||
|
for py in word_pinyin:
|
||||||
|
chars = pinyin_dict.get(py, [])
|
||||||
|
if not chars:
|
||||||
|
return []
|
||||||
|
candidates.append(chars)
|
||||||
|
|
||||||
|
# 生成所有可能的组合
|
||||||
|
import itertools
|
||||||
|
all_combinations = itertools.product(*candidates)
|
||||||
|
|
||||||
|
# 获取jieba词典和词频信息
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
valid_words = {} # 改用字典存储词语及其频率
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
word_text = parts[0]
|
||||||
|
word_freq = float(parts[1]) # 获取词频
|
||||||
|
valid_words[word_text] = word_freq
|
||||||
|
|
||||||
|
# 获取原词的词频作为参考
|
||||||
|
original_word_freq = valid_words.get(word, 0)
|
||||||
|
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||||
|
|
||||||
|
# 过滤和计算频率
|
||||||
|
homophones = []
|
||||||
|
for combo in all_combinations:
|
||||||
|
new_word = ''.join(combo)
|
||||||
|
if new_word != word and new_word in valid_words:
|
||||||
|
new_word_freq = valid_words[new_word]
|
||||||
|
# 只保留词频达到阈值的词
|
||||||
|
if new_word_freq >= min_word_freq:
|
||||||
|
# 计算词的平均字频(考虑字频和词频)
|
||||||
|
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||||
|
# 综合评分:结合词频和字频
|
||||||
|
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
|
||||||
|
if combined_score >= min_freq:
|
||||||
|
homophones.append((new_word, combined_score))
|
||||||
|
|
||||||
|
# 按综合分数排序并限制返回数量
|
||||||
|
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||||
|
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||||
|
|
||||||
|
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
|
||||||
|
"""
|
||||||
|
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||||
|
只使用高频的有意义词语进行替换
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
typo_info = []
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
# 如果是标点符号或空格,直接添加
|
||||||
|
if all(not is_chinese_char(c) for c in word):
|
||||||
|
result.append(word)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取词语的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
|
||||||
|
# 尝试整词替换
|
||||||
|
if len(word) > 1 and random.random() < word_replace_rate:
|
||||||
|
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
|
||||||
|
if word_homophones:
|
||||||
|
typo_word = random.choice(word_homophones)
|
||||||
|
# 计算词的平均频率
|
||||||
|
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
|
||||||
|
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||||
|
|
||||||
|
# 添加到结果中
|
||||||
|
result.append(typo_word)
|
||||||
|
typo_info.append((word, typo_word,
|
||||||
|
' '.join(word_pinyin),
|
||||||
|
' '.join(get_word_pinyin(typo_word)),
|
||||||
|
orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果不进行整词替换,则进行单字替换
|
||||||
|
if len(word) == 1:
|
||||||
|
char = word
|
||||||
|
py = word_pinyin[0]
|
||||||
|
if random.random() < error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
# 处理多字词的单字替换
|
||||||
|
word_result = []
|
||||||
|
for i, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||||
|
# 词中的字替换概率降低
|
||||||
|
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
|
||||||
|
|
||||||
|
if random.random() < word_error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
word_result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
word_result.append(char)
|
||||||
|
result.append(''.join(word_result))
|
||||||
|
|
||||||
|
return ''.join(result), typo_info
|
||||||
|
|
||||||
|
def format_frequency(freq):
|
||||||
|
"""
|
||||||
|
格式化频率显示
|
||||||
|
"""
|
||||||
|
return f"{freq:.2f}"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 记录开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 首先创建拼音字典和加载字频统计
|
||||||
|
print("正在加载汉字数据库,请稍候...")
|
||||||
|
pinyin_dict = create_pinyin_dict()
|
||||||
|
char_frequency = load_or_create_char_frequency()
|
||||||
|
|
||||||
|
# 获取用户输入
|
||||||
|
sentence = input("请输入中文句子:")
|
||||||
|
|
||||||
|
# 创建包含错别字的句子
|
||||||
|
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
||||||
|
error_rate=0.3, min_freq=5,
|
||||||
|
tone_error_rate=0.2, word_replace_rate=0.3)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\n原句:", sentence)
|
||||||
|
print("错字版:", typo_sentence)
|
||||||
|
|
||||||
|
if typo_info:
|
||||||
|
print("\n错别字信息:")
|
||||||
|
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||||
|
# 判断是否为词语替换
|
||||||
|
is_word = ' ' in orig_py
|
||||||
|
if is_word:
|
||||||
|
error_type = "整词替换"
|
||||||
|
else:
|
||||||
|
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||||
|
error_type = "声调错误" if tone_error else "同音字替换"
|
||||||
|
|
||||||
|
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
||||||
|
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
||||||
|
|
||||||
|
# 获取拼音结果
|
||||||
|
result = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 打印完整拼音
|
||||||
|
print("\n完整拼音:")
|
||||||
|
print(" ".join(py for _, py in result))
|
||||||
|
|
||||||
|
# 打印词语分析
|
||||||
|
print("\n词语分析:")
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
for word in words:
|
||||||
|
if any(is_chinese_char(c) for c in word):
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
print(f"词语:{word}")
|
||||||
|
print(f"拼音:{' '.join(word_pinyin)}")
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
# 计算并打印总耗时
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
print(f"\n总耗时:{total_time:.2f}秒")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
488
src/test/typo_creator.py
Normal file
488
src/test/typo_creator.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
错别字生成器 - 流程说明
|
||||||
|
|
||||||
|
整体替换逻辑:
|
||||||
|
1. 数据准备
|
||||||
|
- 加载字频词典:使用jieba词典计算汉字使用频率
|
||||||
|
- 创建拼音映射:建立拼音到汉字的映射关系
|
||||||
|
- 加载词频信息:从jieba词典获取词语使用频率
|
||||||
|
|
||||||
|
2. 分词处理
|
||||||
|
- 使用jieba将输入句子分词
|
||||||
|
- 区分单字词和多字词
|
||||||
|
- 保留标点符号和空格
|
||||||
|
|
||||||
|
3. 词语级别替换(针对多字词)
|
||||||
|
- 触发条件:词长>1 且 随机概率<0.3
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取词语拼音
|
||||||
|
b. 生成所有可能的同音字组合
|
||||||
|
c. 过滤条件:
|
||||||
|
- 必须是jieba词典中的有效词
|
||||||
|
- 词频必须达到原词频的10%以上
|
||||||
|
- 综合评分(词频70%+字频30%)必须达到阈值
|
||||||
|
d. 按综合评分排序,选择最合适的替换词
|
||||||
|
|
||||||
|
4. 字级别替换(针对单字词或未进行整词替换的多字词)
|
||||||
|
- 单字替换概率:0.3
|
||||||
|
- 多字词中的单字替换概率:0.3 * (0.7 ^ (词长-1))
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取字的拼音
|
||||||
|
b. 声调错误处理(20%概率)
|
||||||
|
c. 获取同音字列表
|
||||||
|
d. 过滤条件:
|
||||||
|
- 字频必须达到最小阈值
|
||||||
|
- 频率差异不能过大(指数衰减计算)
|
||||||
|
e. 按频率排序选择替换字
|
||||||
|
|
||||||
|
5. 频率控制机制
|
||||||
|
- 字频控制:使用归一化的字频(0-1000范围)
|
||||||
|
- 词频控制:使用jieba词典中的词频
|
||||||
|
- 频率差异计算:使用指数衰减函数
|
||||||
|
- 最小频率阈值:确保替换字/词不会太生僻
|
||||||
|
|
||||||
|
6. 输出信息
|
||||||
|
- 原文和错字版本的对照
|
||||||
|
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
|
||||||
|
- 替换类型说明(整词替换/声调错误/同音字替换)
|
||||||
|
- 词语分析和完整拼音
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
1. 所有替换都必须使用有意义的词语
|
||||||
|
2. 替换词的使用频率不能过低
|
||||||
|
3. 多字词优先考虑整词替换
|
||||||
|
4. 考虑声调变化的情况
|
||||||
|
5. 保持标点符号和空格不变
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pypinyin import pinyin, Style
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
import jieba
|
||||||
|
import jieba.posseg as pseg
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
def load_or_create_char_frequency():
|
||||||
|
"""
|
||||||
|
加载或创建汉字频率字典
|
||||||
|
"""
|
||||||
|
cache_file = Path("char_frequency.json")
|
||||||
|
|
||||||
|
# 如果缓存文件存在,直接加载
|
||||||
|
if cache_file.exists():
|
||||||
|
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# 使用内置的词频文件
|
||||||
|
char_freq = defaultdict(int)
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
|
||||||
|
# 读取jieba的词典文件
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
word, freq = line.strip().split()[:2]
|
||||||
|
# 对词中的每个字进行频率累加
|
||||||
|
for char in word:
|
||||||
|
if is_chinese_char(char):
|
||||||
|
char_freq[char] += int(freq)
|
||||||
|
|
||||||
|
# 归一化频率值
|
||||||
|
max_freq = max(char_freq.values())
|
||||||
|
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||||
|
|
||||||
|
# 保存到缓存文件
|
||||||
|
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return normalized_freq
|
||||||
|
|
||||||
|
# 创建拼音到汉字的映射字典
|
||||||
|
def create_pinyin_dict():
|
||||||
|
"""
|
||||||
|
创建拼音到汉字的映射字典
|
||||||
|
"""
|
||||||
|
# 常用汉字范围
|
||||||
|
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||||
|
pinyin_dict = defaultdict(list)
|
||||||
|
|
||||||
|
# 为每个汉字建立拼音映射
|
||||||
|
for char in chars:
|
||||||
|
try:
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
pinyin_dict[py].append(char)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return pinyin_dict
|
||||||
|
|
||||||
|
def is_chinese_char(char):
|
||||||
|
"""
|
||||||
|
判断是否为汉字
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return '\u4e00' <= char <= '\u9fff'
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pinyin(sentence):
|
||||||
|
"""
|
||||||
|
将中文句子拆分成单个汉字并获取其拼音
|
||||||
|
:param sentence: 输入的中文句子
|
||||||
|
:return: 每个汉字及其拼音的列表
|
||||||
|
"""
|
||||||
|
# 将句子拆分成单个字符
|
||||||
|
characters = list(sentence)
|
||||||
|
|
||||||
|
# 获取每个字符的拼音
|
||||||
|
result = []
|
||||||
|
for char in characters:
|
||||||
|
# 跳过空格和非汉字字符
|
||||||
|
if char.isspace() or not is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
# 获取拼音(数字声调)
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
result.append((char, py))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取同音字,按照使用频率排序
|
||||||
|
"""
|
||||||
|
homophones = pinyin_dict[py]
|
||||||
|
# 移除原字并过滤低频字
|
||||||
|
if char in homophones:
|
||||||
|
homophones.remove(char)
|
||||||
|
|
||||||
|
# 过滤掉低频字
|
||||||
|
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
# 按照字频排序
|
||||||
|
sorted_homophones = sorted(homophones,
|
||||||
|
key=lambda x: char_frequency.get(x, 0),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# 只返回前10个同音字,避免输出过多
|
||||||
|
return sorted_homophones[:10]
|
||||||
|
|
||||||
|
def get_similar_tone_pinyin(py):
|
||||||
|
"""
|
||||||
|
获取相似声调的拼音
|
||||||
|
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
||||||
|
处理特殊情况:
|
||||||
|
1. 轻声(如 'de5' 或 'le')
|
||||||
|
2. 非数字结尾的拼音
|
||||||
|
"""
|
||||||
|
# 检查拼音是否为空或无效
|
||||||
|
if not py or len(py) < 1:
|
||||||
|
return py
|
||||||
|
|
||||||
|
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||||
|
if not py[-1].isdigit():
|
||||||
|
# 为非数字结尾的拼音添加数字声调1
|
||||||
|
return py + '1'
|
||||||
|
|
||||||
|
base = py[:-1] # 去掉声调
|
||||||
|
tone = int(py[-1]) # 获取声调
|
||||||
|
|
||||||
|
# 处理轻声(通常用5表示)或无效声调
|
||||||
|
if tone not in [1, 2, 3, 4]:
|
||||||
|
return base + str(random.choice([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
# 正常处理声调
|
||||||
|
possible_tones = [1, 2, 3, 4]
|
||||||
|
possible_tones.remove(tone) # 移除原声调
|
||||||
|
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||||
|
return base + str(new_tone)
|
||||||
|
|
||||||
|
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
||||||
|
"""
|
||||||
|
根据频率差计算替换概率
|
||||||
|
频率差越大,概率越低
|
||||||
|
:param orig_freq: 原字频率
|
||||||
|
:param target_freq: 目标字频率
|
||||||
|
:param max_freq_diff: 最大允许的频率差
|
||||||
|
:return: 0-1之间的概率值
|
||||||
|
"""
|
||||||
|
if target_freq > orig_freq:
|
||||||
|
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||||
|
|
||||||
|
freq_diff = orig_freq - target_freq
|
||||||
|
if freq_diff > max_freq_diff:
|
||||||
|
return 0.0 # 频率差太大,不替换
|
||||||
|
|
||||||
|
# 使用指数衰减函数计算概率
|
||||||
|
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||||
|
return math.exp(-3 * freq_diff / max_freq_diff)
|
||||||
|
|
||||||
|
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
获取与给定字频率相近的同音字,可能包含声调错误
|
||||||
|
"""
|
||||||
|
homophones = []
|
||||||
|
|
||||||
|
# 有20%的概率使用错误声调
|
||||||
|
if random.random() < tone_error_rate:
|
||||||
|
wrong_tone_py = get_similar_tone_pinyin(py)
|
||||||
|
homophones.extend(pinyin_dict[wrong_tone_py])
|
||||||
|
|
||||||
|
# 添加正确声调的同音字
|
||||||
|
homophones.extend(pinyin_dict[py])
|
||||||
|
|
||||||
|
if not homophones:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取原字的频率
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||||
|
freq_diff = [(h, char_frequency.get(h, 0))
|
||||||
|
for h in homophones
|
||||||
|
if h != char and char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
if not freq_diff:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算每个候选字的替换概率
|
||||||
|
candidates_with_prob = []
|
||||||
|
for h, freq in freq_diff:
|
||||||
|
prob = calculate_replacement_probability(orig_freq, freq)
|
||||||
|
if prob > 0: # 只保留有效概率的候选字
|
||||||
|
candidates_with_prob.append((h, prob))
|
||||||
|
|
||||||
|
if not candidates_with_prob:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 根据概率排序
|
||||||
|
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 返回概率最高的几个字
|
||||||
|
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||||
|
|
||||||
|
def get_word_pinyin(word):
|
||||||
|
"""
|
||||||
|
获取词语的拼音列表
|
||||||
|
"""
|
||||||
|
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||||
|
|
||||||
|
def segment_sentence(sentence):
|
||||||
|
"""
|
||||||
|
使用jieba分词,返回词语列表
|
||||||
|
"""
|
||||||
|
return list(jieba.cut(sentence))
|
||||||
|
|
||||||
|
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取整个词的同音词,只返回高频的有意义词语
|
||||||
|
:param word: 输入词语
|
||||||
|
:param pinyin_dict: 拼音字典
|
||||||
|
:param char_frequency: 字频字典
|
||||||
|
:param min_freq: 最小频率阈值
|
||||||
|
:return: 同音词列表
|
||||||
|
"""
|
||||||
|
if len(word) == 1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取词的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
word_pinyin_str = ''.join(word_pinyin)
|
||||||
|
|
||||||
|
# 创建词语频率字典
|
||||||
|
word_freq = defaultdict(float)
|
||||||
|
|
||||||
|
# 遍历所有可能的同音字组合
|
||||||
|
candidates = []
|
||||||
|
for py in word_pinyin:
|
||||||
|
chars = pinyin_dict.get(py, [])
|
||||||
|
if not chars:
|
||||||
|
return []
|
||||||
|
candidates.append(chars)
|
||||||
|
|
||||||
|
# 生成所有可能的组合
|
||||||
|
import itertools
|
||||||
|
all_combinations = itertools.product(*candidates)
|
||||||
|
|
||||||
|
# 获取jieba词典和词频信息
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
valid_words = {} # 改用字典存储词语及其频率
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
word_text = parts[0]
|
||||||
|
word_freq = float(parts[1]) # 获取词频
|
||||||
|
valid_words[word_text] = word_freq
|
||||||
|
|
||||||
|
# 获取原词的词频作为参考
|
||||||
|
original_word_freq = valid_words.get(word, 0)
|
||||||
|
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||||
|
|
||||||
|
# 过滤和计算频率
|
||||||
|
homophones = []
|
||||||
|
for combo in all_combinations:
|
||||||
|
new_word = ''.join(combo)
|
||||||
|
if new_word != word and new_word in valid_words:
|
||||||
|
new_word_freq = valid_words[new_word]
|
||||||
|
# 只保留词频达到阈值的词
|
||||||
|
if new_word_freq >= min_word_freq:
|
||||||
|
# 计算词的平均字频(考虑字频和词频)
|
||||||
|
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||||
|
# 综合评分:结合词频和字频
|
||||||
|
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
|
||||||
|
if combined_score >= min_freq:
|
||||||
|
homophones.append((new_word, combined_score))
|
||||||
|
|
||||||
|
# 按综合分数排序并限制返回数量
|
||||||
|
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||||
|
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||||
|
|
||||||
|
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
|
||||||
|
"""
|
||||||
|
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||||
|
只使用高频的有意义词语进行替换
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
typo_info = []
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
# 如果是标点符号或空格,直接添加
|
||||||
|
if all(not is_chinese_char(c) for c in word):
|
||||||
|
result.append(word)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取词语的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
|
||||||
|
# 尝试整词替换
|
||||||
|
if len(word) > 1 and random.random() < word_replace_rate:
|
||||||
|
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
|
||||||
|
if word_homophones:
|
||||||
|
typo_word = random.choice(word_homophones)
|
||||||
|
# 计算词的平均频率
|
||||||
|
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
|
||||||
|
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||||
|
|
||||||
|
# 添加到结果中
|
||||||
|
result.append(typo_word)
|
||||||
|
typo_info.append((word, typo_word,
|
||||||
|
' '.join(word_pinyin),
|
||||||
|
' '.join(get_word_pinyin(typo_word)),
|
||||||
|
orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果不进行整词替换,则进行单字替换
|
||||||
|
if len(word) == 1:
|
||||||
|
char = word
|
||||||
|
py = word_pinyin[0]
|
||||||
|
if random.random() < error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
# 处理多字词的单字替换
|
||||||
|
word_result = []
|
||||||
|
for i, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||||
|
# 词中的字替换概率降低
|
||||||
|
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
|
||||||
|
|
||||||
|
if random.random() < word_error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
word_result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
word_result.append(char)
|
||||||
|
result.append(''.join(word_result))
|
||||||
|
|
||||||
|
return ''.join(result), typo_info
|
||||||
|
|
||||||
|
def format_frequency(freq):
|
||||||
|
"""
|
||||||
|
格式化频率显示
|
||||||
|
"""
|
||||||
|
return f"{freq:.2f}"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 记录开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 首先创建拼音字典和加载字频统计
|
||||||
|
print("正在加载汉字数据库,请稍候...")
|
||||||
|
pinyin_dict = create_pinyin_dict()
|
||||||
|
char_frequency = load_or_create_char_frequency()
|
||||||
|
|
||||||
|
# 获取用户输入
|
||||||
|
sentence = input("请输入中文句子:")
|
||||||
|
|
||||||
|
# 创建包含错别字的句子
|
||||||
|
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
||||||
|
error_rate=0.3, min_freq=5,
|
||||||
|
tone_error_rate=0.2, word_replace_rate=0.3)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\n原句:", sentence)
|
||||||
|
print("错字版:", typo_sentence)
|
||||||
|
|
||||||
|
if typo_info:
|
||||||
|
print("\n错别字信息:")
|
||||||
|
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||||
|
# 判断是否为词语替换
|
||||||
|
is_word = ' ' in orig_py
|
||||||
|
if is_word:
|
||||||
|
error_type = "整词替换"
|
||||||
|
else:
|
||||||
|
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||||
|
error_type = "声调错误" if tone_error else "同音字替换"
|
||||||
|
|
||||||
|
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
||||||
|
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
||||||
|
|
||||||
|
# 获取拼音结果
|
||||||
|
result = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 打印完整拼音
|
||||||
|
print("\n完整拼音:")
|
||||||
|
print(" ".join(py for _, py in result))
|
||||||
|
|
||||||
|
# 打印词语分析
|
||||||
|
print("\n词语分析:")
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
for word in words:
|
||||||
|
if any(is_chinese_char(c) for c in word):
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
print(f"词语:{word}")
|
||||||
|
print(f"拼音:{' '.join(word_pinyin)}")
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
# 计算并打印总耗时
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
print(f"\n总耗时:{total_time:.2f}秒")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user