Merge remote-tracking branch 'upstream/debug' into feature

This commit is contained in:
tcmofashi
2025-03-04 08:18:22 +08:00
34 changed files with 13732 additions and 413 deletions

6
.dockerignore Normal file

@@ -0,0 +1,6 @@
.git
__pycache__
*.pyc
*.pyo
*.pyd
.DS_Store

26
.env

@@ -1,26 +1,2 @@
# Do not change the defaults in this file; it is tracked by the repository. Edit .env.prod instead.
ENVIRONMENT=dev
# HOST=127.0.0.1
# PORT=8080
# COMMAND_START=["/"]
# # Plugin configuration
# PLUGINS=["src2.plugins.chat"]
# # Default configuration
# MONGODB_HOST=127.0.0.1
# MONGODB_PORT=27017
# DATABASE_NAME=MegBot
# MONGODB_USERNAME = "" # empty by default
# MONGODB_PASSWORD = "" # empty by default
# MONGODB_AUTH_SOURCE = "" # empty by default
# #key and url
# CHAT_ANY_WHERE_KEY=
# SILICONFLOW_KEY=
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
# DEEP_SEEK_KEY=
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
ENVIRONMENT=.dev


@@ -1,8 +1,6 @@
HOST=127.0.0.1
PORT=8080
COMMAND_START=["/"]
# Plugin configuration
PLUGINS=["src2.plugins.chat"]
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # empty by default
MONGODB_AUTH_SOURCE = "" # empty by default
#key and url
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
# Define the base_url of the API you want to use
DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=


@@ -1,8 +1,18 @@
FROM nonebot/nb-cli:latest
WORKDIR /
COPY . /MaiMBot/
# Set the working directory
WORKDIR /MaiMBot
# Copy the dependency list first
COPY requirements.txt .
# Install dependencies; this layer stays cached until requirements.txt changes
RUN pip install --upgrade -r requirements.txt
# Then copy the project code
COPY . .
VOLUME [ "/MaiMBot/config" ]
VOLUME [ "/MaiMBot/data" ]
EXPOSE 8080
ENTRYPOINT [ "nb","run" ]
ENTRYPOINT [ "nb","run" ]

178
README.md

@@ -3,7 +3,6 @@
<div align="center">
![Python Version](https://img.shields.io/badge/Python-3.x-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
@@ -12,163 +11,33 @@
## 📝 Project Overview
**Source code repository for the MaiMai (麦麦) QQ bot**
**🍔 MaiMai is an intelligent, LLM-based QQ group-chat bot**
A QQ bot focused on group chat, built on LLMs, NapCat, NoneBot, and MongoDB
- 🤖 Built on the nonebot2 framework
- 🧠 LLMs provide the conversational ability
- 💾 MongoDB provides data persistence
- 🐧 NapCat serves as the QQ protocol endpoint
<div align="center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="MaiMai demo video">
<img src="docs/video.png" width="300" alt="MaiMai demo video">
<br>
👆 Click to watch the MaiMai demo video 👆
</a>
</div>
> ⚠️ **Warning**: the code may change at any time, and the current version is not necessarily stable
> ⚠️ **Warning**: understand the risks of running a QQ bot yourself; MaiMai sometimes gets smacked by Tencent seven or eight times a day
> ⚠️ **Warning**: MaiMai is iterating constantly, so bugs (including incoherent replies) may exist; please test for yourself
> ⚠️ **Notes**
> - The project is under active development and the code may change at any time
> - The documentation is incomplete; open an Issue or a Discussion if you have questions
> - QQ bots carry a risk of account restrictions; understand this and use with caution
> - Continuous iteration means there may be known or unknown bugs
Discussion group for MaiMai development and suggestions: 766798517 (off-topic messages are discouraged; MaiMai does not speak there)
**Discussion group**: 766798517 (for development and suggestion discussions only; deployment questions asked in the group may go unanswered, as documentation and code take priority)
## Development Plan (TODO LIST)
## 📚 Documentation
- GIF parsing and saving support
- Mini-program forwarded-link parsing
- A limit on chain-of-thought length
- Fix known bugs
- Improve the documentation
- Fix message forwarding
- Automatic config generation and validation
- Use a logger instead of print
- A dedicated class for sending messages
- Improve the emoji/sticker sending logic
## 📚 Detailed Documentation
- [Detailed project introduction and architecture](docs/doc1.md) - complete project structure, file descriptions, and core implementation details (generated by claude-3.5-sonnet)
### Installation (not fully tested and may be outdated at any time; deployment may currently hit unknown problems!)
#### Linux: deploy with Docker Compose
Fetch the ```docker-compose.yml``` file from the project root and run:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
```
After editing the configuration files, run:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
```
#### Manual setup
1. **Create a Python environment**
Conda or another virtual environment is recommended for installing dependencies, to avoid version conflicts
```bash
# Install the requirements
pip install -r requirements.txt
```
2. **Set up MongoDB**
- Install and run MongoDB
- The MaiMai bot connects to the default MongoDB port automatically; the database name is configurable
3. **Configure NapCat**
- Install NapCat, run it, and log in
- In NapCat's network settings, add a reverse WS endpoint: ws://localhost:8080/onebot/v11/ws
4. **Set up the configuration files**
- Change the variable in .env to prod
- Open the .env.prod file and fill in your API key (SiliconFlow or DeepSeek)
- Rename bot_config_toml to bot_config.toml, open it, and fill in the required fields (the bot will not run properly otherwise)
#### .env configuration reference
```ini
# Environment
ENVIRONMENT=dev # environment setting
HOST=127.0.0.1 # host address
PORT=8080 # port
# Command prefix
COMMAND_START=["/"] # command start characters
# Plugins
PLUGINS=["src2.plugins.chat"] # enabled plugin list
# MongoDB
MONGODB_HOST=127.0.0.1 # MongoDB host
MONGODB_PORT=27017 # MongoDB port
DATABASE_NAME=MegBot # database name
MONGODB_USERNAME="" # MongoDB username (optional)
MONGODB_PASSWORD="" # MongoDB password (optional)
MONGODB_AUTH_SOURCE="" # MongoDB auth source (optional)
# API settings (SILICONFLOW_KEY is effectively required: image recognition depends on it)
SILICONFLOW_KEY=
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
```
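These values are read with python-dotenv: bot.py loads `.env` first and then overlays the environment-specific file on top. A minimal sketch of that load order (mirroring bot.py):
```python
import os
from dotenv import load_dotenv

load_dotenv(".env")                          # base values, tracked by the repo
if os.path.exists(".env.prod"):
    load_dotenv(".env.prod", override=True)  # your keys, e.g. SILICONFLOW_KEY
print(os.getenv("SILICONFLOW_BASE_URL"))
```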
#### bot_config.toml configuration reference
```toml
# Database settings
[database]
host = "127.0.0.1" # MongoDB host
port = 27017 # MongoDB port
name = "MegBot" # database name
# Basic bot settings
[bot]
qq = # your bot's QQ number (required)
nickname = "麦麦" # bot nickname
# Message handling settings
[message]
min_text_length = 2 # minimum text length to respond to
max_context_size = 15 # maximum number of context messages kept
emoji_chance = 0.2 # probability of using an emoji/sticker
# Emoji feature settings
[emoji]
check_interval = 120 # emoji check interval (seconds)
register_interval = 10 # emoji registration interval (seconds)
# CQ-code settings
[cq_code]
enable_pic_translate = false # enable image translation (currently has no effect)
# Response settings
[response]
api_using = "siliconflow" # API used for replies (siliconflow/deepseek)
model_r1_probability = 0.8 # probability of using the R1 model
model_v3_probability = 0.1 # probability of using the V3 model
model_r1_distill_probability = 0.1 # probability of using the R1-distill model (ignored with the deepseek API)
# Other settings
[others]
enable_advance_output = false # enable verbose log output
# Group settings
[groups]
talk_allowed = [ # group IDs allowed to reply
# add group IDs here, comma-separated
]
talk_frequency_down = [ # group IDs with reduced reply frequency
# add group IDs here, comma-separated
]
ban_user_id = [ # QQ IDs that are never replied to
# add QQ IDs here, comma-separated
]
```
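The three model_*_probability values are weights that should sum to 1. A hypothetical sketch of how such weights can pick a model (illustration only, not the bot's actual code):
```python
import random

def pick_model(r1=0.8, v3=0.1, r1_distill=0.1):
    """Pick a model name by cumulative probability."""
    roll = random.random()
    if roll < r1:
        return "llm_reasoning"        # R1
    if roll < r1 + v3:
        return "llm_normal"           # V3
    return "llm_reasoning_minor"      # R1-distill

print(pick_model())
```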
5. **Run MaiMai**
Run this in the directory containing bot.py (activate your virtual environment first if you use one)
```bash
nb run
```
6. **Run the other components**
run_thingking.bat starts the visual reasoning view (incomplete) with the message queue and other info previews (WIP)
knowledge.bat loads the text documents under /data/raw_info into the database (not enabled)
- [Installation and configuration guide](docs/installation.md) - detailed deployment and configuration instructions
- [Project architecture](docs/doc1.md) - project structure and core implementation details
## 🎯 Features
@@ -204,6 +73,19 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
- Humor and meme features (WIP of the WIP)
- Let MaiMai play Minecraft (WIP of the WIP of the WIP)
## Development Plan (TODO LIST)
- GIF parsing and saving support
- Mini-program forwarded-link parsing
- A limit on chain-of-thought length
- Fix known bugs
- Improve the documentation
- Fix message forwarding
- Automatic config generation and validation
- Use a logger instead of print
- A dedicated class for sending messages
- Improve the emoji/sticker sending logic
## 📌 Notes
Written by a programming layman via Cursor-assisted development; a lot of the code is a mess, so please bear with it
@@ -218,3 +100,7 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
Thanks to all the contributors!
[![Contributors](https://contributors-img.web.app/image?repo=SengokuCola/MaiMBot)](https://github.com/SengokuCola/MaiMBot/graphs/contributors)
## Stargazers over time
[![Stargazers over time](https://starchart.cc/SengokuCola/MaiMBot.svg?variant=adaptive)](https://starchart.cc/SengokuCola/MaiMBot)

13
bot.py

@@ -15,25 +15,22 @@ for i, char in enumerate(text):
print(rainbow_text)
'''Easter egg'''
# First, load the base environment variables
# First, load the base environment variables from .env
if os.path.exists(".env"):
load_dotenv(".env")
logger.success("成功加载基础环境变量配置")
else:
logger.error("基础环境变量配置文件 .env 不存在")
exit(1)
# Load the matching environment config based on ENVIRONMENT
env = os.getenv("ENVIRONMENT")
env_file = f".env.{env}"
if env_file == ".env.dev" and os.path.exists(env_file):
if os.path.exists(".env.dev"):
logger.success("加载开发环境变量配置")
load_dotenv(env_file, override=True) # override=True allows overwriting existing environment variables
load_dotenv(".env.dev", override=True) # override=True allows overwriting existing environment variables
elif os.path.exists(".env.prod"):
logger.success("加载环境变量配置")
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
else:
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
logger.error(f".env对应的环境配置文件不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
exit(1)
# Fetch all environment variables

12012
char_frequency.json Normal file

File diff suppressed because it is too large.


@@ -3,59 +3,69 @@ qq = 123
nickname = "麦麦"
[message]
min_text_length = 2
max_context_size = 15
emoji_chance = 0.2
min_text_length = 2 # MaiMai only answers messages whose text length is at least this value
max_context_size = 15 # number of context messages MaiMai keeps; older ones are dropped automatically
emoji_chance = 0.2 # probability that MaiMai uses an emoji/sticker
[emoji]
check_interval = 120
register_interval = 10
check_interval = 120 # interval between emoji checks
register_interval = 10 # interval between emoji registrations
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow"
api_paid = true
model_r1_probability = 0.8
model_v3_probability = 0.1
model_r1_distill_probability = 0.1
model_r1_probability = 0.8 # probability that MaiMai picks the R1 model when answering
model_v3_probability = 0.1 # probability that MaiMai picks the V3 model when answering
model_r1_distill_probability = 0.1 # probability that MaiMai picks the R1-distill model when answering
[memory]
build_memory_interval = 300
build_memory_interval = 300 # memory-building interval, in seconds
[others]
enable_advance_output = true
enable_advance_output = true # whether to enable verbose output
enable_kuuki_read = true # whether to enable the "reading the air" feature
[groups]
talk_allowed = [
123,
123,
]
talk_frequency_down = []
ban_user_id = []
] # groups where replies are allowed
talk_frequency_down = [] # groups with reduced reply frequency
ban_user_id = [] # QQ IDs that will not be replied to
[model.llm_reasoning]
#V3
#name = "deepseek-chat"
#base_url = "DEEP_SEEK_BASE_URL"
#key = "DEEP_SEEK_KEY"
#R1
#name = "deepseek-reasoner"
#base_url = "DEEP_SEEK_BASE_URL"
#key = "DEEP_SEEK_KEY"
[model.llm_reasoning] #R1
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor]
[model.llm_reasoning_minor] #R1 distill
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal]
[model.llm_normal] #V3
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor]
[model.llm_normal_minor] #V2.5
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm]
[model.vlm] #image recognition
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
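# Editor's note: "SILICONFLOW_BASE_URL" and "SILICONFLOW_KEY" above are not literal values;
# they name environment variables from .env/.env.prod, which LLM_request resolves at
# runtime via getattr(config, model["key"]) (see utils_model.py).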

145
docs/installation.md Normal file

@@ -0,0 +1,145 @@
# 🔧 Installation and Configuration Guide
## Deployment options
If you don't know what Docker is, look for a tutorial or use manual deployment instead
### 🐳 Docker deployment (recommended, though not necessarily the latest)
1. Fetch the configuration file:
```bash
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
```
2. Start the services:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
```
3. Restart after changing the configuration:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
```
### 📦 Manual deployment
1. **Prepare the environment**
```bash
# Create a virtual environment (recommended)
python -m venv venv
venv\\Scripts\\activate # Windows
# Install dependencies
pip install -r requirements.txt
```
2. **Set up MongoDB**
- Install and start the MongoDB service
- Connects to the local 27017 port by default
3. **Configure NapCat**
- Install NapCat and log in
- Add a reverse WS endpoint: `ws://localhost:8080/onebot/v11/ws`
4. **Set up the configuration files**
- Edit the environment config: `.env.prod`
- Edit the bot config: `bot_config.toml`
5. **Start the MaiMai bot**
- Open a terminal and cd to the project directory
```bash
nb run
```
6. **Other components**
- `run_thingking.bat`: starts the visual reasoning view (incomplete)
- ~~`knowledge.bat`: loads the text documents under `/data/raw_info` into the database~~
- Run knowledge.py directly to build the knowledge base
## ⚙️ Configuration
### Environment config (.env.prod)
```ini
# API configuration: you can define your keys and base_url here
# You may define keys from other providers; this is fully customizable
SILICONFLOW_KEY=your_key
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=your_key
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
# Service configuration; keep the defaults if you don't know what this is
HOST=127.0.0.1
PORT=8080
# Database configuration; keep the defaults if you don't know what this is
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
```
### Bot config (bot_config.toml)
```toml
[bot]
qq = "your bot's QQ number"
nickname = "麦麦"
[message]
min_text_length = 2
max_context_size = 15
emoji_chance = 0.2
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
# The deepseek/siliconflow switch has been removed; each model below can be configured independently
model_r1_probability = 0.8 # reasoning-model weight
model_v3_probability = 0.1 # non-reasoning-model weight
model_r1_distill_probability = 0.1
[memory]
build_memory_interval = 300
[others]
enable_advance_output = true # enable verbose log output
[groups]
talk_allowed = [] # group IDs allowed to reply
talk_frequency_down = [] # groups with reduced reply frequency
ban_user_id = [] # QQ IDs that are never replied to
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor]
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor]
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm]
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
```
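Before starting the bot, the file can be sanity-checked with Python's standard tomllib; a minimal sketch (the config path is an assumption, adjust it to wherever your bot_config.toml lives):
```python
import tomllib  # Python 3.11+

with open("config/bot_config.toml", "rb") as f:  # hypothetical path
    cfg = tomllib.load(f)

assert cfg["bot"]["qq"], "bot.qq must be set"
r = cfg["response"]
total = (r["model_r1_probability"] + r["model_v3_probability"]
         + r["model_r1_distill_probability"])
print("model weights sum to", total)  # should be 1.0
```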
## ⚠️ Notes
- The deployment procedure is still being tested; unknown issues may exist
- Keep the API keys in your configuration files safe and never leak them
- Verify everything in a test environment before deploying to production

BIN
docs/video.png Normal file

Binary file not shown (27 KiB).


@@ -1,3 +1,4 @@
chcp 65001
call conda activate niuniu
cd .


@@ -17,12 +17,12 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
)
print("\033[1;32m[初始化数据库完成]\033[0m")


@@ -97,8 +97,13 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
topic = topic3
all_num = 0
interested_num = 0
@@ -166,7 +171,6 @@ class ChatBot:
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,


@@ -116,6 +116,9 @@ class BotConfig:
if "vlm" in model_config:
config.vlm = model_config["vlm"]
if "embedding" in model_config:
config.embedding = model_config["embedding"]
# Message configuration
if "message" in toml_dict:
@@ -138,7 +141,7 @@ class BotConfig:
if "others" in toml_dict:
others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
logger.success(f"成功加载配置文件: {config_path}")
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# If the dev config file does not exist, fall back to the default config file
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
logger.info("使用bot配置文件")
else:
logger.info("已找到开发环境配置文件")
logger.info("已找到开发bot配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""LLM configuration class"""
# Basic configuration
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
pass


@@ -8,7 +8,7 @@ from ...common.database import Database
from PIL import Image
from .config import global_config
import urllib3
from .utils_user import get_user_nickname,get_user_cardname
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode
@@ -21,50 +21,47 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# It also defines two helper properties: keywords, for extracting the message's keywords, and is_plain_text, for checking whether the message is plain text.
@dataclass
class Message:
"""Message data class"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # group name
user_id: int = None
user_nickname: str = None # user nickname
user_cardname: str=None # user's group card name
group_name: str = None # group name
message_id: int = None
raw_message: str = None
plain_text: str = None
message_based_id: int = None
reply_message: Dict = None # the replied-to message
raw_message: str = None # raw message, including unparsed CQ codes
plain_text: str = None # plain text
message_segments: List[Dict] = None # parsed message segments
processed_plain_text: str = None # stores the processed plain_text
detailed_plain_text: str = None # stores detailed, human-readable text
time: float = None
reply_message: Dict = None # stores the source message being replied to
is_emoji: bool = False # whether the message is an emoji/sticker
has_emoji: bool = False # whether the message contains an emoji/sticker
translate_cq: bool = True # whether to translate CQ codes
reply_benefits: float = 0.0
type: str = 'received' # message type: received or send
def __post_init__(self):
if self.time is None:
self.time = int(time.time())
if not self.group_name:
self.group_name = get_groupname(self.group_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id)
if not self.user_cardname:
self.user_cardname=get_user_cardname(self.user_id)
if not self.group_name:
self.group_name = self.get_groupname(self.group_id)
if not self.processed_plain_text:
if self.raw_message:
self.message_segments = self.parse_message_segments(str(self.raw_message))
@@ -244,6 +241,38 @@ class MessageSet:
return len(self.messages)
@dataclass
class Message_Sending(Message):
"""Data class for outgoing messages; inherits from Message"""
priority: int = 0 # send priority; larger numbers mean higher priority
wait_until: float = None # timestamp to wait for before sending
continue_thinking: bool = False # whether to continue thinking
def __post_init__(self):
super().__post_init__()
if self.wait_until is None:
self.wait_until = self.time
@property
def can_send(self) -> bool:
"""Check whether the message can be sent"""
return time.time() >= self.wait_until
def set_wait_time(self, seconds: float) -> None:
"""Set the send delay"""
self.wait_until = time.time() + seconds
def set_priority(self, priority: int) -> None:
"""Set the send priority"""
self.priority = priority
def __lt__(self, other):
"""Override less-than comparison, used for priority sorting"""
if not isinstance(other, Message_Sending):
return NotImplemented
return (self.priority, -self.wait_until) < (other.priority, -other.wait_until)
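# Usage sketch (editor's note, hypothetical values): with __lt__ as defined above,
# sorted() places lower-priority messages first, so the highest-priority message
# ends up last:
#   a = Message_Sending(priority=2)
#   b = Message_Sending(priority=1)
#   sorted([a, b])  # -> [b, a]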


@@ -201,7 +201,7 @@ class MessageSendControl:
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
cost_time = round(time.time(), 2) - message.time
if cost_time > 40:
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
cur_time = time.time()
await self._current_bot.send_group_msg(
group_id=group_id,



@@ -127,15 +127,15 @@ class MessageStream:
# Query the most recent messages from the database
recent_messages = list(db.db.messages.find(
{"group_id": self.group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
# "user_cardname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# # "user_cardname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(count))
if not recent_messages:
@@ -145,17 +145,21 @@ class MessageStream:
from .message import Message
messages = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
return list(reversed(messages)) # 返回按时间正序的消息


@@ -118,7 +118,7 @@ class PromptBuilder:
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
promt_info_prompt = '你有一些[知识],在上面可以参考。'
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")


@@ -0,0 +1,14 @@
#Broca's Area
# Function: language production, grammar processing, and speech motor control.
# Damage: Broca's aphasia (impaired expression with preserved comprehension).
import time
class Thinking_Idea:
def __init__(self, message_id: str):
self.messages = [] # list of message collections
self.current_thoughts = [] # current thoughts
self.time = time.time() # creation time
self.id = str(int(time.time() * 1000)) # unique ID from a millisecond timestamp


@@ -4,6 +4,8 @@ from .message import Message
import jieba
from nonebot import get_driver
from .config import global_config
from snownlp import SnowNLP
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -11,12 +13,10 @@ config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=config.siliconflow_key, base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:
"""Identify the message topic"""
self.llm_client = LLM_request(model=global_config.llm_normal)
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""Identify the message topic; returns a list of topics"""
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:\
1. 主题通常2-4个字必须简短要求精准概括不要太具体。\
@@ -24,36 +24,20 @@ class TopicIdentifier:
3. 这里是
消息内容:{text}"""
response = self.client.chat.completions.create(
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10,
)
if not response or not response.choices:
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
# Use the LLM_request class for the request
topic, _ = await self.llm_client.generate_response(prompt)
if not topic:
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
return None
# Take the first choice's message content from the OpenAI API response and strip leading/trailing whitespace
topic = (
response.choices[0].message.content.strip()
if response.choices[0].message.content
else None
)
if topic == "无主题":
return None
else:
# print(f"[主题分析结果]{text[:20]}... : {topic}")
split_topic = self.parse_topic(topic)
return split_topic
def parse_topic(self, topic: str) -> List[str]:
"""Parse the topic string into a list of topics"""
# Handle topic parsing right here
if not topic or topic == "无主题":
return []
return [t.strip() for t in topic.split(",") if t.strip()]
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""Identify topics with jieba"""
@@ -239,33 +223,12 @@ class TopicIdentifier:
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
",",
"。",
"、",
"!",
"?",
";",
":",
'"',
'"',
"'", "'",
"(",
")",
"《",
"》",
"【",
"】",
"…",
"—",
"·",
"~",
"+",
"=",
"-",
"[",
"]",
',', '。', '、', '!', '?', ';', ':', '"', '"', "'", "'",
'(', ')', '《', '》', '【', '】', '…', '—', '·', '~', '~',
'¥', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
';', ':', "'", '"', '(', ')', '?', '!', '±', '×', '÷',
',', '.', ';', ':', '?', '!', '(', ')', '[', ']', '~'
}:
filtered_words.append(word)
filtered_words.append(word)
@@ -280,5 +243,25 @@ class TopicIdentifier:
return top_words if top_words else None
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
"""Identify topics with SnowNLP
Args:
text (str): the text whose topics should be identified
Returns:
Optional[List[str]]: the identified topic keywords, or None if identification fails
"""
if not text or len(text.strip()) == 0:
return None
try:
s = SnowNLP(text)
# Extract the top 3 keywords as the topic
keywords = s.keywords(3)
return keywords if keywords else None
except Exception as e:
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
return None
topic_identifier = TopicIdentifier()
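# Usage sketch (editor's note): the jieba and SnowNLP identifiers are synchronous,
# while the LLM variant must be awaited:
#   topic_identifier.identify_topic_jieba(text)
#   topic_identifier.identify_topic_snownlp(text)
#   await topic_identifier.identify_topic_llm(text)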


@@ -10,6 +10,7 @@ from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return False
def get_embedding(text):
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
response = requests.request("POST", url, json=payload, headers=headers)
if response.status_code != 200:
print(f"API请求失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
return response.json()['data'][0]['embedding']
"""Get the embedding vector for a text"""
llm = LLM_request(model=global_config.embedding)
return llm.get_embedding_sync(text)
def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
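# cosine similarity = (v1 . v2) / (||v1|| * ||v2||); ranges from -1 to 1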
@@ -142,14 +127,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# Fetch recent messages from the database
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit))
if not recent_messages:
@@ -159,16 +144,20 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
from .message import Message
message_objects = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
# Reverse into ascending time order
message_objects.reverse()
@@ -181,7 +170,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
"time": 1, # include the time field
"user_id": 1, # include the user ID field
"user_nickname": 1, # include the user nickname field
"user_cardname": 1, # include the user's group card name
"message_id": 1, # include the message ID field
"detailed_plain_text": 1 # include the processed text field
}
@@ -193,6 +181,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
message_detailed_plain_text = ''
message_detailed_plain_text_list = []
# Reverse the message list so the newest message is last
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:


@@ -6,8 +6,12 @@ def get_user_nickname(user_id: int) -> str:
return global_config.BOT_NICKNAME
# print(user_id)
return relationship_manager.get_name(user_id)
def get_user_cardname(user_id: int) -> str:
if int(user_id) == int(global_config.BOT_QQ):
return global_config.BOT_NICKNAME
# print(user_id)
return ''
return ''
def get_groupname(group_id: int) -> str:
return f"{group_id}"


@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
# Configure the database connection directly
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class KnowledgeLibrary:


@@ -2,7 +2,6 @@
import os
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
@@ -10,10 +9,76 @@ from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
from dotenv import load_dotenv
import sys
import asyncio
import aiohttp
from typing import Tuple
sys.path.append("C:/GitHub/MaiMBot") # add the project root to the Python path
from src.common.database import Database # use the correct import syntax
# Load the .env.dev file
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class LLMModel:
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""Generate the model's response from the given prompt"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# Build the request body
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# Send the request to the full chat/completions endpoint
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # exponential backoff
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # check other response statuses
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
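# Editor's note: with base_wait_time = 15 and max_retries = 3, the 429 retry waits
# are 15 s, 30 s, and 60 s (15 * 2**retry).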
class Memory_graph:
def __init__(self):
@@ -158,12 +223,12 @@ class Memory_graph:
def main():
# Initialize the database
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
@@ -185,11 +250,14 @@ def main():
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
for item in second_layer_items:
print(item)
else:
print("未找到相关记忆。")


@@ -66,7 +66,7 @@ class LLMModel:
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""


@@ -259,12 +259,12 @@ config = driver.config
start_time = time.time()
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= config.MONGODB_PORT,
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
# Create the memory graph
memory_graph = Memory_graph()


@@ -9,7 +9,7 @@ driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
def __init__(self, model ,**kwargs):
# Convert the uppercase config keys to lowercase and fetch the actual values from config
try:
self.api_key = getattr(config, model["key"])
@@ -61,7 +61,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -126,7 +126,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -166,8 +166,8 @@ class LLM_request:
# Send the request to the full chat/completions endpoint
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
@@ -191,9 +191,119 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""Synchronously get the embedding vector for a text
Args:
text: the text to embed
model: the model name; defaults to "BAAI/bge-m3"
Returns:
list: the embedding vector, or None on failure
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status()
result = response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""Asynchronously get the embedding vector for a text
Args:
text: the text to embed
model: the model name; defaults to "BAAI/bge-m3"
Returns:
list: the embedding vector, or None on failure
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status()
result = await response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None
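# Usage sketch (editor's note; the model dict keys name env-var-backed config
# attributes, as resolved in __init__ via getattr(config, ...)):
#   llm = LLM_request(model={"name": "BAAI/bge-m3",
#                            "base_url": "SILICONFLOW_BASE_URL",
#                            "key": "SILICONFLOW_KEY"})
#   vec = llm.get_embedding_sync("你好")   # -> list[float] or None
#   vec = await llm.get_embedding("你好")  # async variant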


@@ -11,12 +11,12 @@ config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator:
@@ -128,6 +128,10 @@ class ScheduleGenerator:
def _time_diff(self, time1: str, time2: str) -> int:
"""Compute the difference in minutes between two time strings"""
if time1=="24:00":
time1="23:59"
if time2=="24:00":
time2="23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
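# Example: _time_diff("23:00", "24:00") -> 59, since "24:00" is clamped to "23:59".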
@@ -165,4 +169,4 @@ class ScheduleGenerator:
# if __name__ == "__main__":
# main()
bot_schedule = ScheduleGenerator()
bot_schedule = ScheduleGenerator()


@@ -0,0 +1,53 @@
from snownlp import SnowNLP
def analyze_emotion_snownlp(text):
"""
Chinese sentiment analysis with SnowNLP
:param text: input text
:return: sentiment score (between 0 and 1; closer to 1 is more positive)
"""
try:
s = SnowNLP(text)
sentiment_score = s.sentiments
# Get the text's keywords
keywords = s.keywords(3)
return {
'sentiment_score': sentiment_score,
'keywords': keywords,
'summary': s.summary(1) # generate a text summary
}
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_snownlp(score):
"""
Convert a sentiment score into a descriptive label
"""
if score is None:
return "无法分析情感"
if score > 0.8:
return "非常积极"
elif score > 0.6:
return "较为积极"
elif score > 0.4:
return "中性偏积极"
elif score > 0.2:
return "中性偏消极"
else:
return "消极"
if __name__ == "__main__":
# Test sample
test_text = "我们学校有免费的gpt4用"
result = analyze_emotion_snownlp(test_text)
if result:
print(f"测试文本: {test_text}")
print(f"情感得分: {result['sentiment_score']:.2f}")
print(f"情感倾向: {get_emotion_description_snownlp(result['sentiment_score'])}")
print(f"关键词: {', '.join(result['keywords'])}")
print(f"文本摘要: {result['summary'][0]}")

54
src/test/snownlp_demo.py Normal file

@@ -0,0 +1,54 @@
from snownlp import SnowNLP
def demo_snownlp_features(text):
"""
Demonstrate SnowNLP's main features
:param text: input text
"""
print(f"\n=== SnowNLP功能演示 ===")
print(f"输入文本: {text}")
# Create the SnowNLP object
s = SnowNLP(text)
# 1. Word segmentation
print(f"\n1. 分词结果:")
print(f" {' | '.join(s.words)}")
# 2. Sentiment analysis
print(f"\n2. 情感分析:")
sentiment = s.sentiments
print(f" 情感得分: {sentiment:.2f}")
print(f" 情感倾向: {'积极' if sentiment > 0.5 else '消极' if sentiment < 0.5 else '中性'}")
# 3. Keyword extraction
print(f"\n3. 关键词提取:")
print(f" {', '.join(s.keywords(3))}")
# 4. Part-of-speech tagging
print(f"\n4. 词性标注:")
print(f" {' '.join([f'{word}/{tag}' for word, tag in s.tags])}")
# 5. Pinyin conversion
print(f"\n5. 拼音:")
print(f" {' '.join(s.pinyin)}")
# 6. Text summary
if len(text) > 100: # only summarize longer texts
print(f"\n6. 文本摘要:")
print(f" {' '.join(s.summary(3))}")
if __name__ == "__main__":
# Test cases
test_texts = [
"这家新开的餐厅很不错,菜品种类丰富,味道可口,服务态度也很好,价格实惠,强烈推荐大家来尝试!",
"这部电影剧情混乱,演技浮夸,特效粗糙,配乐难听,完全浪费了我的时间和票价。",
"""人工智能正在改变我们的生活方式。它能够帮助我们完成复杂的计算任务,
提供个性化的服务推荐,优化交通路线,辅助医疗诊断。但同时我们也要警惕
人工智能带来的问题,比如隐私安全、就业变化等。如何正确认识和利用人工智能,
是我们每个人都需要思考的问题。"""
]
for text in test_texts:
demo_snownlp_features(text)
print("\n" + "="*50)

488
src/test/typo.py Normal file

@@ -0,0 +1,488 @@
"""
Typo generator - process overview
Overall replacement logic:
1. Data preparation
- Load the character-frequency dictionary: compute character usage frequencies from the jieba dictionary
- Build the pinyin mapping: map each pinyin to its characters
- Load word-frequency information: read word usage frequencies from the jieba dictionary
2. Segmentation
- Segment the input sentence with jieba
- Distinguish single-character words from multi-character words
- Keep punctuation and spaces
3. Word-level replacement (multi-character words)
- Trigger condition: word length > 1 and random probability < 0.3
- Replacement flow:
a. Get the word's pinyin
b. Generate all possible homophone combinations
c. Filter conditions:
- Must be a valid word in the jieba dictionary
- Word frequency must be at least 10% of the original word's
- The combined score (70% word frequency + 30% character frequency) must reach the threshold
d. Sort by combined score and pick the most suitable replacement word
4. Character-level replacement (single-character words, or multi-character words not replaced as a whole)
- Single-character replacement probability: 0.3
- Per-character probability inside multi-character words: 0.3 * (0.7 ^ (word length - 1))
- Replacement flow:
a. Get the character's pinyin
b. Tone-error handling (20% probability)
c. Get the homophone list
d. Filter conditions:
- Character frequency must reach the minimum threshold
- The frequency difference must not be too large (exponential decay)
e. Sort by frequency and pick the replacement character
5. Frequency control
- Character frequency: normalized character frequencies (0-1000 range)
- Word frequency: word frequencies from the jieba dictionary
- Frequency-difference calculation: an exponential decay function
- Minimum frequency threshold: ensures replacements are not too obscure
6. Output
- The original text next to the typo version
- Details for each replacement (original char/word, replacement, pinyin, frequency)
- Replacement type (whole-word replacement / tone error / homophone replacement)
- Word analysis and the full pinyin
Notes:
1. Every replacement must use a meaningful word
2. Replacement words must not be too infrequent
3. Whole-word replacement is preferred for multi-character words
4. Tone changes are taken into account
5. Punctuation and spaces are kept unchanged
"""
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
import time
def load_or_create_char_frequency():
"""
Load or build the character-frequency dictionary
"""
cache_file = Path("char_frequency.json")
# If the cache file exists, load it directly
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# Use the built-in word-frequency file
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# Read jieba's dictionary file
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# Accumulate the frequency of each character in the word
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# Normalize the frequency values
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# Save to the cache file
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# Build the pinyin-to-character mapping
def create_pinyin_dict():
"""
Build the pinyin-to-character mapping
"""
# Range of common Chinese characters
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# Build the pinyin mapping for each character
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
Check whether a character is a Chinese character
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
Split a Chinese sentence into single characters and get their pinyin
:param sentence: the input Chinese sentence
:return: a list of (character, pinyin) pairs
"""
# Split the sentence into single characters
characters = list(sentence)
# Get each character's pinyin
result = []
for char in characters:
# Skip spaces and non-Chinese characters
if char.isspace() or not is_chinese_char(char):
continue
# Get the pinyin (numeric tone)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
Get homophones, sorted by usage frequency
"""
homophones = pinyin_dict[py]
# Remove the original character and filter low-frequency ones
if char in homophones:
homophones.remove(char)
# Filter out low-frequency characters
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# Sort by character frequency
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# Return only the top 10 homophones to avoid excessive output
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
Get a pinyin with a similar but different tone
For example, 'ni3' may return 'ni2' or 'ni4'
Special cases handled:
1. Neutral tone (e.g. 'de5' or 'le')
2. Pinyin that does not end in a digit
"""
# Check for empty or invalid pinyin
if not py or len(py) < 1:
return py
# If the last character is not a digit, it may be a neutral tone or another special case
if not py[-1].isdigit():
# Append numeric tone 1 to pinyin without one
return py + '1'
base = py[:-1] # strip the tone
tone = int(py[-1]) # get the tone
# Handle the neutral tone (usually 5) or invalid tones
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# Normal tone handling
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # remove the original tone
new_tone = random.choice(possible_tones) # pick a new tone at random
return base + str(new_tone)
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
Compute the replacement probability from the frequency difference
The larger the difference, the lower the probability
:param orig_freq: frequency of the original character
:param target_freq: frequency of the target character
:param max_freq_diff: maximum allowed frequency difference
:return: a probability between 0 and 1
"""
if target_freq > orig_freq:
return 1.0 # if the replacement is more frequent, keep full probability
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # the difference is too large; do not replace
# Compute the probability with an exponential decay function
# Probability is 1 at difference 0 and approaches 0 at max_freq_diff
return math.exp(-3 * freq_diff / max_freq_diff)
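# Example: orig_freq=100, target_freq=50 -> freq_diff=50,
# probability = exp(-3 * 50 / 200) = exp(-0.75) ≈ 0.47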
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
Get homophones with a frequency similar to the given character's, possibly with tone errors
"""
homophones = []
# With 20% probability, use a wrong tone
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# Add homophones with the correct tone
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# Get the original character's frequency
orig_freq = char_frequency.get(char, 0)
# Compute each homophone's frequency difference from the original and filter low-frequency ones
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# Compute each candidate's replacement probability
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # keep only candidates with a usable probability
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# Sort by probability
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# Return the few highest-probability characters
return [char for char, _ in candidates_with_prob[:num_candidates]]
def get_word_pinyin(word):
"""
Get the pinyin list for a word
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def segment_sentence(sentence):
"""
Segment a sentence with jieba; returns the list of words
"""
return list(jieba.cut(sentence))
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
"""
Get homophones for a whole word; returns only high-frequency, meaningful words
:param word: the input word
:param pinyin_dict: the pinyin dictionary
:param char_frequency: the character-frequency dictionary
:param min_freq: the minimum frequency threshold
:return: a list of homophone words
"""
if len(word) == 1:
return []
# Get the word's pinyin
word_pinyin = get_word_pinyin(word)
word_pinyin_str = ''.join(word_pinyin)
# Word-frequency dictionary
word_freq = defaultdict(float)
# Iterate over all possible homophone combinations
candidates = []
for py in word_pinyin:
chars = pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# Generate all possible combinations
import itertools
all_combinations = itertools.product(*candidates)
# Load jieba's dictionary and word-frequency information
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # store words and their frequencies in a dict
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
word_freq = float(parts[1]) # word frequency
valid_words[word_text] = word_freq
# Use the original word's frequency as a reference
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # minimum word frequency: 10% of the original's
# Filter and compute frequencies
homophones = []
for combo in all_combinations:
new_word = ''.join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# Keep only words whose frequency reaches the threshold
if new_word_freq >= min_word_freq:
# Average character frequency of the word (combining character and word frequency)
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
# Combined score: word frequency plus character frequency
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= min_freq:
homophones.append((new_word, combined_score))
# Sort by combined score and cap the number of results
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # return at most the top 5
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
"""
Create a sentence containing homophone typos; supports word-level and character-level replacement
Only high-frequency, meaningful words are used as replacements
"""
result = []
typo_info = []
# Segment the sentence
words = segment_sentence(sentence)
for word in words:
# Append punctuation and spaces unchanged
if all(not is_chinese_char(c) for c in word):
result.append(word)
continue
# Get the word's pinyin
word_pinyin = get_word_pinyin(word)
# Try whole-word replacement
if len(word) > 1 and random.random() < word_replace_rate:
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
if word_homophones:
typo_word = random.choice(word_homophones)
# Average frequency of the word
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# Append to the result
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# If no whole-word replacement happened, do character-level replacement
if len(word) == 1:
char = word
py = word_pinyin[0]
if random.random() < error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
result.append(char)
else:
# Character replacement inside multi-character words
word_result = []
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# Characters inside a word are replaced with reduced probability
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
word_result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_frequency(freq):
"""
Format a frequency value for display
"""
return f"{freq:.2f}"
def main():
# Record the start time
start_time = time.time()
# First build the pinyin dictionary and load the character-frequency statistics
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# Read the user's input
sentence = input("请输入中文句子:")
# Create the typo version of the sentence
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
error_rate=0.3, min_freq=5,
tone_error_rate=0.2, word_replace_rate=0.3)
# Print the results
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# Determine whether this was a whole-word replacement
is_word = ' ' in orig_py
if is_word:
error_type = "整词替换"
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# Get the pinyin result
result = get_pinyin(sentence)
# Print the full pinyin
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# Print the word analysis
print("\n词语分析:")
words = segment_sentence(sentence)
for word in words:
if any(is_chinese_char(c) for c in word):
word_pinyin = get_word_pinyin(word)
print(f"词语:{word}")
print(f"拼音:{' '.join(word_pinyin)}")
print("---")
# Compute and print the total runtime
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()

488
src/test/typo_creator.py Normal file

File diff omitted: the content is identical to src/test/typo.py above.