Merge pull request #36 from SengokuCola/debug

v0.4.1
SengokuCola
2025-03-03 19:38:10 +08:00
committed by GitHub
45 changed files with 14353 additions and 1030 deletions

6
.dockerignore Normal file

@@ -0,0 +1,6 @@
.git
__pycache__
*.pyc
*.pyo
*.pyd
.DS_Store

26
.env Normal file
View File

@@ -0,0 +1,26 @@
# Do not modify these defaults; this file is indexed by the repository. Edit .env.prod instead
ENVIRONMENT=dev
# HOST=127.0.0.1
# PORT=8080
# COMMAND_START=["/"]
# # Plugin configuration
# PLUGINS=["src2.plugins.chat"]
# # Default configuration
# MONGODB_HOST=127.0.0.1
# MONGODB_PORT=27017
# DATABASE_NAME=MegBot
# MONGODB_USERNAME = "" # empty by default
# MONGODB_PASSWORD = "" # empty by default
# MONGODB_AUTH_SOURCE = "" # empty by default
# # Keys and URLs
# CHAT_ANY_WHERE_KEY=
# SILICONFLOW_KEY=
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
# DEEP_SEEK_KEY=
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1


@@ -1,4 +1,3 @@
ENVIRONMENT=dev
HOST=127.0.0.1
PORT=8080
@@ -11,15 +10,17 @@ PLUGINS=["src2.plugins.chat"]
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
MONGODB_USERNAME = "" # empty by default
MONGODB_PASSWORD = "" # empty by default
MONGODB_AUTH_SOURCE = "" # empty by default
# API settings
SILICONFLOW_KEY=
# Keys and URLs
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=

10
.gitignore vendored

@@ -2,19 +2,15 @@ data/
mongodb/
NapCat.Framework.Windows.Once/
log/
src/plugins/memory
config/bot_config.toml
/test
message_queue_content.txt
message_queue_content.bat
message_queue_window.bat
message_queue_window.txt
reasoning_content.txt
reasoning_content.bat
reasoning_window.bat
queue_update.txt
memory_graph.gml
.env.*
config/bot_config_dev.toml
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -146,7 +142,6 @@ celerybeat.pid
*.sage.py
# Environments
.env
.venv
env/
venv/
@@ -187,3 +182,4 @@ cython_debug/
# PyPI configuration file
.pypirc
.env


@@ -1,10 +1,18 @@
FROM nonebot/nb-cli:latest
WORKDIR /
COPY . /MaiMBot/
# Set the working directory
WORKDIR /MaiMBot
RUN mv config/env.example config/.env \
&& mv config/bot_config_toml config/bot_config.toml
# Copy the dependency list first
COPY requirements.txt .
# Install dependencies; this layer stays cached until requirements.txt changes
RUN pip install --upgrade -r requirements.txt
# Then copy the project code
COPY . .
VOLUME [ "/MaiMBot/config" ]
VOLUME [ "/MaiMBot/data" ]
EXPOSE 8080
ENTRYPOINT [ "nb","run" ]

176
README.md

@@ -3,7 +3,6 @@
<div align="center">
![Python Version](https://img.shields.io/badge/Python-3.x-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot)
![Status](https://img.shields.io/badge/状态-开发中-yellow)
@@ -12,165 +11,33 @@
## 📝 Project Overview
**Source code repository for the 麦麦 (MaiMai) QQ bot**
**🍔 麦麦 is an LLM-powered intelligent QQ group-chat bot**
A QQ bot focused on group chat, built on LLMs, NapCat, NoneBot, and MongoDB
- 🤖 Built on the NoneBot2 framework
- 🧠 Conversation powered by an LLM
- 💾 Data persistence via MongoDB
- 🐧 NapCat as the QQ protocol client
<div align="center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="https://i0.hdslb.com/bfs/archive/7d9fa0a88e8a1aa01b92b8a5a743a2671c0e1798.jpg" width="500" alt="麦麦演示视频">
<img src="docs/video.png" width="300" alt="麦麦演示视频">
<br>
👆 Click to watch the 麦麦 demo video 👆
</a>
</div>
> ⚠️ **Warning**: the code may change at any time, and the current version is not necessarily stable
> ⚠️ **Warning**: please understand the risks of running a QQ bot yourself; 麦麦 sometimes gets struck by Tencent seven or eight times a day
> ⚠️ **Warning**: 麦麦 is iterating constantly, so bugs (including incoherent replies) may exist; please test for yourself
> ⚠️ **Notes**
> - The project is under active development and the code may change at any time
> - The documentation is incomplete; feel free to open an Issue or a Discussion
> - QQ bots carry a risk of account restrictions; understand the risks and use with caution
> - Due to continuous iteration, known or unknown bugs may exist
Discussion group for 麦麦 development and suggestions (off-topic posts discouraged); 麦麦 itself does not speak there!
**Chat group**: 766798517 (for development and suggestion discussions only)
## Development Roadmap (TODO)
## 📚 Documentation
- GIF parsing and saving support
- Mini-program forwarded-link parsing
- Limit the length of the reasoning chain
- Fix known bugs
- Improve the documentation
- Fix message forwarding
- Automatic config generation and validation
- Use a real logger instead of print
- A dedicated class for sending messages
<div align="center">
<img src="docs/qq.png" width="300" />
</div>
## 📚 Detailed Documentation
- [Project overview and architecture](docs/doc1.md) - full project structure, file descriptions, and core implementation details (generated by claude-3.5-sonnet)
### Installation (not fully tested; deployment may currently hit unknown problems!)
#### Linux: deploy with Docker Compose
Grab the `docker-compose.yml` file from the project root and run:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
```
After editing the configuration files, run:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
```
#### Manual setup
1. **Create a Python environment**
Conda or another virtual environment is recommended for installing dependencies, to avoid version conflicts
```bash
# Install requirements
pip install -r requirements.txt
```
2. **Set up MongoDB**
- Install and run MongoDB
- 麦麦 connects automatically (the default MongoDB port and database name are configurable)
3. **Configure NapCat**
- Install NapCat, run it, and log in
- In NapCat's network settings, add a reverse WS endpoint: ws://localhost:8080/onebot/v11/ws
4. **Fill in the configuration files**
- Open the .env file and enter your API key (SiliconFlow or DeepSeek)
- Open bot_config.toml and fill in the required fields, otherwise the bot cannot run properly
#### .env configuration reference
```ini
# Environment
ENVIRONMENT=dev # environment setting
HOST=127.0.0.1 # host address
PORT=8080 # port
# Command prefix
COMMAND_START=["/"] # command start characters
# Plugins
PLUGINS=["src2.plugins.chat"] # enabled plugins
# MongoDB
MONGODB_HOST=127.0.0.1 # MongoDB host
MONGODB_PORT=27017 # MongoDB port
DATABASE_NAME=MegBot # database name
MONGODB_USERNAME="" # MongoDB username (optional)
MONGODB_PASSWORD="" # MongoDB password (optional)
MONGODB_AUTH_SOURCE="" # MongoDB auth source (optional)
# API settings (SILICONFLOW_KEY is effectively required: image recognition needs it)
SILICONFLOW_KEY=
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
```
#### bot_config.toml configuration reference
```toml
# Database settings
[database]
host = "127.0.0.1" # MongoDB host
port = 27017 # MongoDB port
name = "MegBot" # database name
# Basic bot settings
[bot]
qq = # your bot's QQ number (required)
nickname = "麦麦" # bot nickname
# Message handling
[message]
min_text_length = 2 # minimum text length to respond to
max_context_size = 15 # maximum number of context messages kept
emoji_chance = 0.2 # probability of sending a sticker
# Sticker features
[emoji]
check_interval = 120 # sticker check interval (seconds)
register_interval = 10 # sticker registration interval (seconds)
# CQ code settings
[cq_code]
enable_pic_translate = false # enable image translation (no effect)
# Response settings
[response]
api_using = "siliconflow" # API used for replies (siliconflow/deepseek)
model_r1_probability = 0.8 # probability of using the R1 model
model_v3_probability = 0.1 # probability of using the V3 model
model_r1_distill_probability = 0.1 # probability of using the R1-distill model (ignored for the deepseek API)
# Other settings
[others]
enable_advance_output = false # enable verbose log output
# Group settings
[groups]
talk_allowed = [ # groups the bot may reply in
# add group numbers here, comma-separated
]
talk_frequency_down = [ # groups with reduced reply frequency
# add group numbers here, comma-separated
]
ban_user_id = [ # QQ numbers the bot must not reply to
# add QQ numbers here, comma-separated
]
```
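For intuition, the three `model_*_probability` values partition the interval [0, 1): each reply draws one random number and picks a model by where it lands. A minimal sketch of that selection (mirroring the logic in the llm_generator module changed in this PR; the values are the defaults above):
```python
import random

p_r1, p_v3 = 0.8, 0.1  # the remaining 0.1 goes to the R1-distill model

def pick_model() -> str:
    rand = random.random()
    if rand < p_r1:
        return "r1"
    elif rand < p_r1 + p_v3:
        return "v3"
    return "r1_distill"
```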
5. **Run 麦麦**
Run this in the directory containing bot.py (activate your virtual environment first if you use one)
```bash
nb run
```
6. **Run the other components**
run_thingking.bat launches the visual reasoning view (unfinished) plus previews of the message queue and other info (WIP)
knowledge.bat loads the text documents under /data/raw_info into the database (not enabled)
- [Installation and configuration guide](docs/installation.md) - detailed deployment and configuration instructions
- [Architecture notes](docs/doc1.md) - project structure and core implementation details
## 🎯 Features
@@ -206,6 +73,19 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
- Humor and meme features (WIP of a WIP)
- Let 麦麦 play Minecraft (WIP of a WIP of a WIP)
## Development Roadmap (TODO)
- GIF parsing and saving support
- Mini-program forwarded-link parsing
- Limit the length of the reasoning chain
- Fix known bugs
- Improve the documentation
- Fix message forwarding
- Automatic config generation and validation
- Use a real logger instead of print
- A dedicated class for sending messages
- Improve the sticker-sending logic
## 📌 Disclaimer
A total programming layman doing Cursor-oriented development; a lot of the code is a mess, please bear with it

61
bot.py

@@ -4,28 +4,50 @@ from nonebot.adapters.onebot.v11 import Adapter
from dotenv import load_dotenv
from loguru import logger
# Load global environment variables
root_dir = os.path.dirname(os.path.abspath(__file__))
env_path=os.path.join(root_dir, "config",'.env')
'''Easter egg'''
from colorama import init, Fore
init()
text = "多年以后面对行刑队张三将会回想起他2023年在会议上讨论人工智能的那个下午"
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
rainbow_text = ""
for i, char in enumerate(text):
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
print(rainbow_text)
'''Easter egg'''
logger.info(f"尝试从 {env_path} 加载环境变量配置")
if os.path.exists(env_path):
load_dotenv(env_path)
logger.success("成功加载环境变量配置")
# Load the base environment variables first
if os.path.exists(".env"):
load_dotenv(".env")
logger.success("成功加载基础环境变量配置")
else:
logger.error(f"环境变量配置文件不存在: {env_path}")
logger.error("基础环境变量配置文件 .env 不存在")
exit(1)
# Load the matching env file according to ENVIRONMENT
env = os.getenv("ENVIRONMENT")
env_file = f".env.{env}"
# Initialize NoneBot
nonebot.init(
# NapCat uses port 8080 by default
websocket_port=8080,
# Set the log level
log_level="INFO",
# Set the superusers
superusers={"你的QQ号"},
# TODO: written this way, environment variables are ignored; needs improvement https://nonebot.dev/docs/appendices/config
_env_file=env_path
)
if env_file == ".env.dev" and os.path.exists(env_file):
logger.success("加载开发环境变量配置")
load_dotenv(env_file, override=True) # override=True allows overriding existing environment variables
elif os.path.exists(".env.prod"):
logger.success("加载环境变量配置")
load_dotenv(".env.prod", override=True) # override=True allows overriding existing environment variables
else:
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
exit(1)
# Collect all environment variables
env_config = {key: os.getenv(key) for key in os.environ}
# Base configuration
base_config = {
"websocket_port": int(env_config.get("PORT", 8080)),
"host": env_config.get("HOST", "127.0.0.1"),
"log_level": "INFO",
}
# Merge the configs
nonebot.init(**base_config, **env_config)
# Register the adapter
driver = nonebot.get_driver()
@@ -35,4 +57,5 @@ driver.register_adapter(Adapter)
nonebot.load_plugins("src/plugins")
if __name__ == "__main__":
nonebot.run()
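Context for the later diffs: everything passed into `nonebot.init(**env_config)` becomes an attribute on the driver's config object, which is why the rest of this PR replaces `os.getenv(...)` with `driver.config` lookups. A rough sketch of the consuming side (attribute names taken from the diffs below):
```python
from nonebot import get_driver

driver = get_driver()
config = driver.config

# e.g. MONGODB_HOST=127.0.0.1 from .env.prod surfaces as:
mongodb_host = config.mongodb_host
siliconflow_key = config.siliconflow_key
```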

12012
char_frequency.json Normal file

File diff suppressed because it is too large

46
config/auto_format.py Normal file

@@ -0,0 +1,46 @@
import tomli
import tomli_w
import sys
from pathlib import Path
import os
def sync_configs():
# Read both config files
try:
with open('bot_config_dev.toml', 'rb') as f: # tomli requires reading in binary mode
dev_config = tomli.load(f)
with open('bot_config.toml', 'rb') as f:
prod_config = tomli.load(f)
except FileNotFoundError as e:
print(f"错误:找不到配置文件 - {e}")
sys.exit(1)
except tomli.TOMLDecodeError as e:
print(f"错误TOML格式解析失败 - {e}")
sys.exit(1)
# Recursively merge the configs
def merge_configs(source, target):
for key, value in source.items():
if key not in target:
target[key] = value
elif isinstance(value, dict) and isinstance(target[key], dict):
merge_configs(value, target[key])
# Merge new keys from the dev config into the prod config
merge_configs(dev_config, prod_config)
# Save the updated config
try:
with open('bot_config.toml', 'wb') as f: # tomli_w requires writing in binary mode
tomli_w.dump(prod_config, f)
print("配置文件同步完成!")
except Exception as e:
print(f"错误:保存配置文件失败 - {e}")
sys.exit(1)
if __name__ == '__main__':
# Make sure we run from the script's directory
script_dir = Path(__file__).parent
os.chdir(script_dir)
sync_configs()
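To make the merge semantics concrete, here is the same logic restated standalone with hypothetical values: keys that exist only in the dev config are copied into the prod config, while values already set in prod are preserved:
```python
def merge_configs(source: dict, target: dict) -> None:
    """Copy keys missing from target; recurse into dicts present in both."""
    for key, value in source.items():
        if key not in target:
            target[key] = value
        elif isinstance(value, dict) and isinstance(target[key], dict):
            merge_configs(value, target[key])

dev = {"groups": {"talk_allowed": [111], "new_option": True}}   # hypothetical
prod = {"groups": {"talk_allowed": [222]}}

merge_configs(dev, prod)
assert prod == {"groups": {"talk_allowed": [222], "new_option": True}}
```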

61
config/bot_config.toml Normal file

@@ -0,0 +1,61 @@
[bot]
qq = 123
nickname = "麦麦"
[message]
min_text_length = 2
max_context_size = 15
emoji_chance = 0.2
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow"
api_paid = true
model_r1_probability = 0.8
model_v3_probability = 0.1
model_r1_distill_probability = 0.1
[memory]
build_memory_interval = 300
[others]
enable_advance_output = true
[groups]
talk_allowed = [
123,
123,
]
talk_frequency_down = []
ban_user_id = []
[model.llm_reasoning]
name = "Pro/deepseek-ai/DeepSeek-R1"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_reasoning_minor]
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal]
name = "Pro/deepseek-ai/DeepSeek-V3"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.llm_normal_minor]
name = "deepseek-ai/DeepSeek-V2.5"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"
[model.vlm]
name = "deepseek-ai/deepseek-vl2"
base_url = "SILICONFLOW_BASE_URL"
key = "SILICONFLOW_KEY"


@@ -1,45 +0,0 @@
[bot]
qq = 123456 # your bot's QQ number
nickname = "麦麦" # the name you want the bot to be called
[message]
min_text_length = 2 # 麦麦 only answers messages whose text length is at least this value
max_context_size = 15 # number of context messages 麦麦 keeps; older ones are dropped automatically
emoji_chance = 0.2 # probability that 麦麦 sends a sticker
[emoji]
check_interval = 120
register_interval = 10
[cq_code]
enable_pic_translate = false
[response]
api_using = "siliconflow" # which LLM API to use (siliconflow or deepseek); siliconflow is recommended because the vision API currently only supports its deepseek-vl2 model
model_r1_probability = 0.8 # probability that 麦麦 answers with the R1 model
model_v3_probability = 0.1 # probability that 麦麦 answers with the V3 model
model_r1_distill_probability = 0.1 # probability that 麦麦 answers with the R1-distill model
[memory]
build_memory_interval = 300 # memory-building interval
[others]
enable_advance_output = true # verbose logging: true to enable, false to disable
[groups]
talk_allowed = [
123456,12345678
] # groups the bot may reply in
talk_frequency_down = [
123456,12345678
] # groups with reduced reply frequency
ban_user_id = [
123456,12345678
] # QQ numbers the bot must not reply to


@@ -41,9 +41,9 @@ services:
volumes:
- maimbotCONFIG:/MaiMBot/config
- maimbotDATA:/MaiMBot/data
- ./.env.prod:/MaiMBot/.env.prod
image: sengokucola/maimbot:latest
volumes:
maimbotCONFIG:
maimbotDATA:
@@ -52,3 +52,4 @@ volumes:
mongodb:
mongodbCONFIG:

102
docs/installation.md Normal file

@@ -0,0 +1,102 @@
# 🔧 Installation and Configuration Guide
## Deployment Options
### 🐳 Docker deployment (recommended)
1. Fetch the compose file:
```bash
wget https://raw.githubusercontent.com/SengokuCola/MaiMBot/main/docker-compose.yml
```
2. Start the services:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose up -d
```
3. Restart after editing the configuration:
```bash
NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
```
### 📦 Manual deployment
1. **Prepare the environment**
```bash
# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate # Linux
venv\\Scripts\\activate # Windows
# Install dependencies
pip install -r requirements.txt
```
2. **Set up MongoDB**
- Install and start the MongoDB service
- Connects to local port 27017 by default
3. **Configure NapCat**
- Install NapCat and log in
- Add a reverse WS endpoint: `ws://localhost:8080/onebot/v11/ws`
4. **Set up the configuration files**
- Copy and edit the environment config: `.env.prod`
- Copy and edit the bot config: `bot_config.toml`
5. **Start the bot**
```bash
nb run
```
6. **Other components**
- `run_thingking.bat`: launches the visual reasoning view (unfinished) and the message-queue preview
- `knowledge.bat`: loads the text documents under `/data/raw_info` into the database
## ⚙️ Configuration Reference
### Environment config (.env.prod)
```ini
# API settings (required)
SILICONFLOW_KEY=your_key
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=your_key
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
# Service settings
HOST=127.0.0.1
PORT=8080
# Database settings
MONGODB_HOST=127.0.0.1
MONGODB_PORT=27017
DATABASE_NAME=MegBot
```
### Bot config (bot_config.toml)
```toml
[bot]
qq = "<your bot's QQ number>"
nickname = "麦麦"
[message]
max_context_size = 15
emoji_chance = 0.2
[response]
api_using = "siliconflow" # or "deepseek"
[others]
enable_advance_output = false # enable verbose log output
[groups]
talk_allowed = [] # groups the bot may reply in
talk_frequency_down = [] # groups with reduced reply frequency
ban_user_id = [] # QQ numbers the bot must not reply to
```
## ⚠️ Notes
- The deployment flow is still being tested; unknown issues may remain
- Keep the API keys in your configuration files private
- Run in a test environment first and move to production once verified

Binary file not shown.


BIN
docs/video.png Normal file

Binary file not shown.


Binary file not shown.


@@ -3,3 +3,4 @@ cd .
REM Run the nb run command
nb run
pause


@@ -7,6 +7,23 @@ import threading
import queue
import sys
import os
from dotenv import load_dotenv
# Directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))
# Project root directory
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
# Load the environment variables
if os.path.exists(os.path.join(root_dir, '.env.dev')):
load_dotenv(os.path.join(root_dir, '.env.dev'))
print("成功加载开发环境配置")
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
load_dotenv(os.path.join(root_dir, '.env.prod'))
print("成功加载生产环境配置")
else:
print("未找到环境配置文件")
sys.exit(1)
from pymongo import MongoClient
from typing import Optional
@@ -14,14 +31,23 @@ from typing import Optional
class Database:
_instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str):
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
if username and password:
self.client = MongoClient(
host=host,
port=port,
username=username,
password=password,
authSource=auth_source or 'admin'
)
else:
self.client = MongoClient(host, port)
self.db = self.client[db_name]
@classmethod
def initialize(cls, host: str, port: int, db_name: str) -> "Database":
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
if cls._instance is None:
cls._instance = cls(host, port, db_name)
cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance
@classmethod

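A usage sketch of the extended initializer (hypothetical values, matching how the plugins call it elsewhere in this PR): empty credentials fall through to an unauthenticated connection, and `auth_source` is consulted only when credentials are supplied, defaulting to "admin":
```python
db = Database.initialize(
    host="127.0.0.1",
    port=27017,
    db_name="MegBot",
    username="",      # falsy -> plain MongoClient(host, port)
    password="",
    auth_source="",   # used only with credentials; falls back to "admin"
)
```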

@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .willing_manager import willing_manager
# Get the driver
driver = get_driver()
config = driver.config
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
)
print("\033[1;32m[初始化数据库完成]\033[0m")
@@ -37,7 +39,7 @@ emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
# Create the bot instance
chat_bot = ChatBot(global_config)
chat_bot = ChatBot()
# Register the message handler
group_msg = on_message()
# Create the scheduled tasks
@@ -50,6 +52,7 @@ async def start_background_tasks():
"""启动后台任务"""
# 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
await bot_schedule.initialize()
bot_schedule.print_schedule()
@driver.on_startup
@@ -88,7 +91,7 @@ async def monitor_relationships():
async def build_memory_task():
"""每30秒执行一次记忆构建"""
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
hippocampus.build_memory(chat_size=12)
await hippocampus.build_memory(chat_size=30)
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")


@@ -2,7 +2,7 @@ from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessa
from .message import Message,MessageSet
from .config import BotConfig, global_config
from .storage import MessageStorage
from .llm_generator import LLMResponseGenerator
from .llm_generator import ResponseGenerator
from .message_stream import MessageStream, MessageStreamContainer
from .topic_identifier import topic_identifier
from random import random, choice
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
from ..memory_system.memory import memory_graph
class ChatBot:
def __init__(self, config: BotConfig):
self.config = config
def __init__(self):
self.storage = MessageStorage()
self.gpt = LLMResponseGenerator(config)
self.gpt = ResponseGenerator()
self.bot = None # bot 实例引用
self._started = False
@@ -39,11 +38,11 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
if event.group_id not in self.config.talk_allowed_groups:
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # update the bot instance
if event.user_id in self.config.ban_user_id:
if event.user_id in global_config.ban_user_id:
return
# Print the raw message content
@@ -120,7 +119,7 @@ class ChatBot:
event.group_id,
topic[0] if topic else None,
is_mentioned,
self.config,
global_config,
event.user_id,
message.is_emoji,
interested_rate
@@ -144,9 +143,13 @@ class ChatBot:
# If a reply was generated, send and record it
'''
Handling of the generated reply
'''
if response:
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
accu_typing_time = 0
for msg in response:
print(f"当前消息: {msg}")
@@ -157,7 +160,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
user_id=self.config.BOT_QQ,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
@@ -174,7 +177,7 @@ class ChatBot:
bot_response_time = tinking_time_point
if random() < self.config.emoji_chance:
if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
if emoji_path:
emoji_cq = CQCode.create_emoji_cq(emoji_path)
@@ -186,7 +189,7 @@ class ChatBot:
bot_message = Message(
group_id=event.group_id,
user_id=self.config.BOT_QQ,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,


@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, Set
import os
from nonebot.log import logger, default_format
@@ -7,6 +7,7 @@ import configparser
import tomli
import sys
from loguru import logger
from nonebot import get_driver
@@ -31,28 +32,36 @@ class BotConfig:
EMOJI_CHECK_INTERVAL: int = 120 # sticker check interval (minutes)
EMOJI_REGISTER_INTERVAL: int = 10 # sticker registration interval (minutes)
# Model configs
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
vlm: Dict[str, str] = field(default_factory=lambda: {})
API_USING: str = "siliconflow" # which API to use
API_PAID: bool = False # whether to use the paid API tier
MODEL_R1_PROBABILITY: float = 0.8 # probability of the R1 model
MODEL_V3_PROBABILITY: float = 0.1 # probability of the V3 model
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # probability of the R1-distill model
enable_advance_output: bool = False # enable verbose output
enable_kuuki_read: bool = True # enable the "air-reading" (reply-gating) feature
@staticmethod
def get_default_config_path() -> str:
"""获取默认配置文件路径"""
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
return os.path.join(config_dir, 'bot_config.toml')
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
if config_path is None:
config_path = cls.get_default_config_path()
logger.info(f"使用默认配置文件路径: {config_path}")
config = cls()
if os.path.exists(config_path):
with open(config_path, "rb") as f:
@@ -80,6 +89,26 @@ class BotConfig:
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
config.API_USING = response_config.get("api_using", config.API_USING)
config.API_PAID = response_config.get("api_paid", config.API_PAID)
# Load the model configs
if "model" in toml_dict:
model_config = toml_dict["model"]
if "llm_reasoning" in model_config:
config.llm_reasoning = model_config["llm_reasoning"]
if "llm_reasoning_minor" in model_config:
config.llm_reasoning_minor = model_config["llm_reasoning_minor"]
if "llm_normal" in model_config:
config.llm_normal = model_config["llm_normal"]
if "llm_normal_minor" in model_config:
config.llm_normal_minor = model_config["llm_normal_minor"]
if "vlm" in model_config:
config.vlm = model_config["vlm"]
# Message config
if "message" in toml_dict:
@@ -108,13 +137,21 @@ class BotConfig:
return config
# Resolve the config file path
bot_config_path = BotConfig.get_default_config_path()
config_dir = os.path.dirname(bot_config_path)
env_path = os.path.join(config_dir, '.env')
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
bot_config_floder_path = BotConfig.get_config_dir()
print(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# Fall back to the default config file if the dev config does not exist
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
else:
logger.info("已找到开发环境配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""机器人配置类"""
@@ -125,12 +162,14 @@ class LLMConfig:
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
pass


@@ -7,15 +7,20 @@ from PIL import Image
import os
from random import random
from nonebot.adapters.onebot.v11 import Bot
from .config import global_config, llm_config
from .config import global_config
import time
import asyncio
from .utils_image import storage_image,storage_emoji
from .utils_user import get_user_nickname
from ..models.utils_model import LLM_request
# Parses the various CQ codes
# Contains the CQCode class
import urllib3
from urllib3.util import create_urllib3_context
from nonebot import get_driver
driver = get_driver()
config = driver.config
# Special handling for TLS 1.3 https://github.com/psf/requests/issues/6616
ctx = create_urllib3_context()
@@ -53,6 +58,11 @@ class CQCode:
translated_plain_text: Optional[str] = None
reply_message: Dict = None # stores the replied-to message
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def translate(self):
"""根据CQ码类型进行相应的翻译处理"""
@@ -157,7 +167,7 @@ class CQCode:
# Decode the base64 string into bytes
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return self.get_image_description(base64_str)
return self.get_emoji_description(base64_str)
else:
return '[表情包]'
@@ -177,93 +187,23 @@ class CQCode:
def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
}
payload = {
"model": "deepseek-ai/deepseek-vl2",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
"max_tokens": 50,
"temperature": 0.4
}
response = requests.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
headers=headers,
json=payload,
timeout=30
)
if response.status_code == 200:
result_json = response.json()
if "choices" in result_json and len(result_json["choices"]) > 0:
description = result_json["choices"][0]["message"]["content"]
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
return f"[表情包:{description}]"
raise ValueError(f"AI接口调用失败: {response.text}")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[表情包]"
def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
}
payload = {
"model": "deepseek-ai/deepseek-vl2",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.6
}
response = requests.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
headers=headers,
json=payload,
timeout=30
)
if response.status_code == 200:
result_json = response.json()
if "choices" in result_json and len(result_json["choices"]) > 0:
description = result_json["choices"][0]["message"]["content"]
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
return f"[图片:{description}]"
raise ValueError(f"AI接口调用失败: {response.text}")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
def translate_forward(self) -> str:
"""处理转发消息"""
@@ -345,7 +285,7 @@ class CQCode:
# Create the Message object
from .message import Message
if self.reply_message == None:
print(f"\033[1;31m[错误]\033[0m 回复消息为空")
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
if self.reply_message.sender.user_id:


@@ -10,10 +10,16 @@ import hashlib
from datetime import datetime
import base64
import shutil
from .config import global_config, llm_config
import asyncio
import time
from nonebot import get_driver
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
class EmojiManager:
_instance = None
@@ -39,6 +45,7 @@ class EmojiManager:
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
@@ -83,55 +90,23 @@ class EmojiManager:
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
async def _get_emotion_from_text(self, text: str) -> List[str]:
"""从文本中识别情感关键词使用DeepSeek API进行分析
"""从文本中识别情感关键词
Args:
text: 输入文本
Returns:
List[str]: 匹配到的情感标签列表
"""
try:
# Prepare the request payload
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
}
prompt = f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签不要输出其他任何内容。'
payload = {
"model": "deepseek-ai/DeepSeek-V3",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签不要输出其他任何内容。'
}
]
}
],
"max_tokens": 50,
"temperature": 0.3
}
content, _ = await self.llm.generate_response(prompt)
emotion = content.strip().lower()
async with aiohttp.ClientSession() as session:
async with session.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
headers=headers,
json=payload
) as response:
if response.status != 200:
print(f"\033[1;31m[错误]\033[0m API请求失败: {await response.text()}")
return ['neutral']
result = json.loads(await response.text())
if "choices" in result and len(result["choices"]) > 0:
emotion = result["choices"][0]["message"]["content"].strip().lower()
# Make sure the returned tag is valid
if emotion in self.EMOTION_KEYWORDS:
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
return [emotion] # return a list with the single emotion tag
return [emotion]
return ['neutral'] # fall back to neutral when the emotion cannot be identified
return ['neutral']
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
@@ -246,52 +221,20 @@ class EmojiManager:
async def _get_emoji_tag(self, image_base64: str) -> str:
"""获取表情包的标签"""
async with aiohttp.ClientSession() as session:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
}
try:
prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签不要输出其他任何内容只输出情感标签就好'
payload = {
"model": "deepseek-ai/deepseek-vl2",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签不要输出其他任何内容只输出情感标签就好'
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
"max_tokens": 60,
"temperature": 0.3
}
async with session.post(
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
headers=headers,
json=payload
) as response:
if response.status == 200:
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
tag_result = result["choices"][0]["message"]["content"].strip().lower()
content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
tag_result = content.strip().lower()
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
for tag_match in valid_tags:
if tag_match in tag_result or tag_match == tag_result:
return tag_match
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_match}, 跳过")
else:
print(f"\033[1;31m[错误]\033[0m 获取标签失败, 状态码: {response.status}")
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
return "skip" # 默认标签


@@ -1,205 +1,126 @@
from typing import Dict, Any, List, Optional, Union, Tuple
from openai import OpenAI
import asyncio
import requests
from functools import partial
from .message import Message
from .config import BotConfig, global_config
from .config import global_config
from ...common.database import Database
import random
import time
import os
import numpy as np
from .relationship_manager import relationship_manager
from ..schedule.schedule_generator import bot_schedule
from .prompt_builder import prompt_builder
from .config import llm_config, global_config
from .config import global_config
from .utils import process_llm_response
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
class LLMResponseGenerator:
def __init__(self, config: BotConfig):
self.config = config
if self.config.API_USING == "siliconflow":
self.client = OpenAI(
api_key=llm_config.SILICONFLOW_API_KEY,
base_url=llm_config.SILICONFLOW_BASE_URL
)
elif self.config.API_USING == "deepseek":
self.client = OpenAI(
api_key=llm_config.DEEP_SEEK_API_KEY,
base_url=llm_config.DEEP_SEEK_BASE_URL
)
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
self.db = Database.get_instance()
# Currently selected model type
self.current_model_type = 'r1' # defaults to R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# Read the model probabilities from global_config
model_r1_probability = global_config.MODEL_R1_PROBABILITY
model_v3_probability = global_config.MODEL_V3_PROBABILITY
model_r1_distill_probability = global_config.MODEL_R1_DISTILL_PROBABILITY
# Draw a random number and pick a model according to the probabilities
# Read the model probabilities from global_config and pick a model
rand = random.random()
if rand < model_r1_probability:
if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1'
elif rand < model_r1_probability + model_v3_probability:
current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = 'v3'
current_model = self.model_v3
else:
self.current_model_type = 'r1_distill' # defaults to R1-Distill
self.current_model_type = 'r1_distill'
current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
if self.current_model_type == 'r1':
model_response = await self._generate_r1_response(message)
elif self.current_model_type == 'v3':
model_response = await self._generate_v3_response(message)
else:
model_response = await self._generate_r1_distill_response(message)
# Print the emotion tag
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response, emotion = await self._process_response(model_response)
model_response = await self._generate_response_with_model(message, current_model)
if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response, emotion = await self._process_response(model_response)
if model_response:
print(f"'{model_response}' 获取到的情感标签为:{emotion}")
return model_response, emotion
return None, []
async def _generate_base_response(
self,
message: Message,
model_name: str,
model_params: Optional[Dict[str, Any]] = None
) -> Optional[str]:
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
"""使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}"
# Fetch the relationship value
if relationship_manager.get_relationship(message.user_id):
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
if relationship_value != 0.0:
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
else:
relationship_value = 0.0
# Build the prompt
prompt = prompt_builder._build_prompt(
prompt, prompt_check = prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
relationship_value=relationship_value,
group_id=message.group_id
)
# Default parameters
default_params = {
"model": model_name,
"messages": [{"role": "user", "content": prompt}],
"stream": False,
"max_tokens": 1024,
"temperature": 0.7
}
# Apply overrides
if model_params:
default_params.update(model_params)
def create_completion():
return self.client.chat.completions.create(**default_params)
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, create_completion)
# Validate the response
if not response:
print("请求未返回任何内容")
# "Air-reading" (reply-gating) module
if global_config.enable_kuuki_read:
content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
if 'yes' not in content_check.lower() and random.random() < 0.3:
self._save_to_db(
message=message,
sender_name=sender_name,
prompt=prompt,
prompt_check=prompt_check,
content="",
content_check=content_check,
reasoning_content="",
reasoning_content_check=reasoning_content_check
)
return None
if not response.choices or not response.choices[0].message.content:
print("请求返回的内容无效:", response)
return None
content = response.choices[0].message.content
# Extract the reasoning content
reasoning_content = ""
if hasattr(response.choices[0].message, "reasoning"):
reasoning_content = response.choices[0].message.reasoning or reasoning_content
elif hasattr(response.choices[0].message, "reasoning_content"):
reasoning_content = response.choices[0].message.reasoning_content or reasoning_content
# Generate the reply
content, reasoning_content = await model.generate_response(prompt)
# Save to the database
self._save_to_db(
message=message,
sender_name=sender_name,
prompt=prompt,
prompt_check=prompt_check,
content=content,
content_check=content_check if global_config.enable_kuuki_read else "",
reasoning_content=reasoning_content,
reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
)
return content
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': model_name,
'model': self.current_model_type,
'reasoning_check': reasoning_content_check,
'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
'model_params': default_params
'prompt_check': prompt_check
})
return content
async def _generate_r1_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-R1 模型生成回复"""
if self.config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-reasoner",
{"temperature": 0.7, "max_tokens": 1024}
)
else:
return await self._generate_base_response(
message,
"Pro/deepseek-ai/DeepSeek-R1",
{"temperature": 0.7, "max_tokens": 1024}
)
async def _generate_v3_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-V3 模型生成回复"""
if self.config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-chat",
{"temperature": 0.8, "max_tokens": 1024}
)
else:
return await self._generate_base_response(
message,
"Pro/deepseek-ai/DeepSeek-V3",
{"temperature": 0.8, "max_tokens": 1024}
)
async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-R1-Distill-Qwen-32B 模型生成回复"""
return await self._generate_base_response(
message,
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
{"temperature": 0.7, "max_tokens": 1024}
)
async def _get_group_chat_context(self, message: Message) -> str:
"""获取群聊上下文"""
recent_messages = self.db.db.messages.find(
{"group_id": message.group_id}
).sort("time", -1).limit(15)
messages_list = list(recent_messages)[::-1]
group_chat = ""
for msg_dict in messages_list:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(msg_dict['time']))
display_name = msg_dict.get('user_nickname', f"用户{msg_dict['user_id']}")
content = msg_dict.get('processed_plain_text', msg_dict['plain_text'])
group_chat += f"[{time_str}] {display_name}: {content}\n"
return group_chat
async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签"""
try:
@@ -209,33 +130,12 @@ class LLMResponseGenerator:
输出:
'''
messages = [{"role": "user", "content": prompt}]
loop = asyncio.get_event_loop()
if self.config.API_USING == "deepseek":
model = "deepseek-chat"
else:
model = "Pro/deepseek-ai/DeepSeek-V3"
create_completion = partial(
self.client.chat.completions.create,
model=model,
messages=messages,
stream=False,
max_tokens=30,
temperature=0.6
)
response = await loop.run_in_executor(None, create_completion)
if response.choices[0].message.content:
# Make sure a list is returned
emotion_tag = response.choices[0].message.content.strip()
return [emotion_tag] # wrap the single tag in a list
return ["neutral"] # default when no emotion tag could be obtained
content, _ = await self.model_v3.generate_response(prompt)
return [content.strip()] if content else ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
return ["neutral"] # 发生错误时返回默认值
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签"""
@@ -243,10 +143,6 @@ class LLMResponseGenerator:
return None, []
emotion_tags = await self._get_emotion_tags(content)
processed_response = process_llm_response(content)
return processed_response, emotion_tags
# Create the global instance
llm_response = LLMResponseGenerator(global_config)
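Pieced together from this diff, the new "air-reading" (读空气) gate behaves roughly as follows (a simplified sketch, not the literal code): the cheaper V3 model is asked whether a reply fits the conversation, and a "no" verdict suppresses the reply only 30% of the time:
```python
import random

async def passes_kuuki_check(model_v3, prompt_check: str) -> bool:
    # generate_response returns (content, reasoning_content)
    content_check, _ = await model_v3.generate_response(prompt_check)
    # Mirrors: if 'yes' not in content_check.lower() and random.random() < 0.3 -> skip
    return 'yes' in content_check.lower() or random.random() >= 0.3
```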


@@ -66,15 +66,21 @@ class PromptBuilder:
overlapping_second_layer.update(overlap)
# Merge all the needed memories
if all_first_layer_items:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
if overlapping_second_layer:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
# if all_first_layer_items:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
# if overlapping_second_layer:
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
all_memories = all_first_layer_items + list(overlapping_second_layer)
# Deduplicate with a set
# Randomly pick up to 2 memories from each source (when available)
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
if all_memories: # only pick a random item when the list is non-empty
random_item = choice(all_memories)
# Merge and deduplicate
all_memories = list(set(selected_first_layer + selected_second_layer))
if all_memories:
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
random_item = " ".join(all_memories)
memory_prompt = f"看到这些聊天,你想起来{random_item}\n"
else:
memory_prompt = "" # 如果没有记忆,则返回空字符串
@@ -112,7 +118,7 @@ class PromptBuilder:
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
promt_info_prompt = '你有一些[知识],在上面可以参考。'
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")
@@ -144,9 +150,9 @@ class PromptBuilder:
prompt_personality = ''
personality_choice = random.random()
if personality_choice < 4/6: # persona 1
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}曾经是一个学习地质的女大学生,现在学习心理学和脑科学你会刷贴吧你正在浏览qq群,{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}现在学习心理学和脑科学你会刷贴吧你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
请注意把握群里的聊天内容,不要回复的太有条理,可以有个性。'''
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < 1: # persona 2
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
@@ -181,7 +187,21 @@ class PromptBuilder:
prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n"
return prompt
'''Air-reading prompt handling'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < 4/6: # persona 1
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}曾经是一个学习地质的女大学生现在学习心理学和脑科学你会刷贴吧你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < 1: # persona 2
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response
def get_prompt_info(self,message:str,threshold:float):
related_info = ''


@@ -1,14 +1,18 @@
from typing import Optional, Dict, List
from openai import OpenAI
from .message import Message
from .config import global_config, llm_config
import jieba
from nonebot import get_driver
from .config import global_config
driver = get_driver()
config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=llm_config.SILICONFLOW_API_KEY,
base_url=llm_config.SILICONFLOW_BASE_URL
api_key=config.siliconflow_key,
base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:
@@ -21,7 +25,7 @@ class TopicIdentifier:
消息内容:{text}"""
response = self.client.chat.completions.create(
model="Pro/deepseek-ai/DeepSeek-V3",
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10


@@ -4,11 +4,15 @@ from typing import List
from .message import Message
import requests
import numpy as np
from .config import llm_config, global_config
from .config import global_config
import re
from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
@@ -64,7 +68,7 @@ def get_embedding(text):
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
@@ -181,6 +185,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
message_detailed_plain_text = ''
message_detailed_plain_text_list = []
# Reverse the message list so the newest message comes last
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:


@@ -7,6 +7,10 @@ from ...common.database import Database
import zlib # 用于 CRC32
import base64
from .config import global_config
from nonebot import get_driver
driver = get_driver()
config = driver.config
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
# Connect to the database
db = Database(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# Check whether an image with the same hash already exists


@@ -58,8 +58,8 @@ class WillingManager:
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / 3.5
if is_mentioned_bot and user_id == int(1026294844):
reply_probability = 1
# if is_mentioned_bot and user_id == int(1026294844):
# reply_probability = 1
return reply_probability


@@ -3,6 +3,10 @@ import sys
import numpy as np
import requests
import time
from nonebot import get_driver
driver = get_driver()
config = driver.config
# Add the project root to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -13,12 +17,12 @@ from src.plugins.chat.config import llm_config
# Configure the database connection directly
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
class KnowledgeLibrary:


@@ -168,10 +168,12 @@ def main():
memory_graph.load_graph_from_db()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
# visualize_graph(memory_graph, color_by_memory=False)
visualize_graph_lite(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# visualize_graph(memory_graph, color_by_memory=True)
visualize_graph_lite(memory_graph, color_by_memory=True)
# memory_graph.save_graph_to_db()
@@ -262,7 +264,89 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# Set a font that can render Chinese
plt.rcParams['font.sans-serif'] = ['SimHei'] # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False # display minus signs correctly
G = memory_graph.G
# Create a copy of the graph for visualization
H = G.copy()
# Drop nodes with at most two memories or at most two connections
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 2 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# Bail out if no nodes remain after filtering
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# Save the graph locally
nx.write_gml(H, "memory_graph.gml") # save in GML format
# Color nodes by connection count or memory count
node_colors = []
nodes = list(H.nodes()) # the actual node list of the graph
if color_by_memory:
# Count the memories per node
memory_counts = []
for node in nodes:
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
count = 1 if memory_items else 0
memory_counts.append(count)
max_memories = max(memory_counts) if memory_counts else 1
for count in memory_counts:
# Color scheme: red means many memories, blue means few
if max_memories > 0:
intensity = min(1.0, count / max_memories)
color = (intensity, 0, 1.0 - intensity) # gradient from blue to red
else:
color = (0, 0, 1) # blue when there are no memories
node_colors.append(color)
else:
# Fall back to the original coloring by connection count
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
color = (red, 0, blue)
else:
color = (0, 0, 1)
node_colors.append(color)
# Draw the graph
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
font_size=10,
font_family='SimHei',
font_weight='bold')
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()


@@ -1,62 +0,0 @@
import os
import requests
from typing import Tuple, Union
import time
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# Build the request body
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# Send the request to the full chat/completions endpoint
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15 # base wait time (seconds)
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # exponential backoff
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # check other response statuses
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except requests.exceptions.RequestException as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""


@@ -2,22 +2,28 @@ import os
import requests
from typing import Tuple, Union
import time
from ..chat.config import BotConfig
from nonebot import get_driver
import aiohttp
import asyncio
from src.plugins.chat.config import BotConfig, global_config
driver = get_driver()
config = driver.config
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.api_key = config.siliconflow_key
self.base_url = config.siliconflow_base_url
if not self.api_key or not self.base_url:
raise ValueError("环境变量未正确加载SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
def generate_response(self, prompt: str) -> Tuple[str, str]:
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
@@ -40,28 +46,28 @@ class LLMModel:
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data)
if response.status_code == 429:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # exponential backoff
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # check other response statuses
result = response.json()
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except requests.exceptions.RequestException as e:
except Exception as e:
if retry < max_retries - 1: # if retries remain
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""


@@ -1,19 +1,16 @@
# -*- coding: utf-8 -*-
import os
import jieba
from .llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
from collections import Counter
import datetime
import random
import time
from ..chat.config import global_config
import sys
from ...common.database import Database # correct import syntax
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # graph structure from networkx
@@ -169,8 +166,8 @@ class Memory_graph:
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
@@ -193,13 +190,29 @@ class Hippocampus:
chat_text.append(chat_)
return chat_text
def build_memory(self,chat_size=12):
async def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
topic_prompt = find_topic(input_text, topic_num)
topic_response = await self.llm_model.generate_response(topic_prompt)
# Check whether topic_response is a tuple
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # assume the first element is the string we need
else:
topics = topic_response.split(",")
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # store (topic, memory) tuples
return compressed_memory
async def build_memory(self,chat_size=12):
# Sampling frequency for recent messages
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
for i, input_text in enumerate(memory_sample, 1):
# Progress visualization
progress = (i / len(memory_sample)) * 100
@@ -207,12 +220,10 @@ class Hippocampus:
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = self.memory_compress(input_text, 2.5)
# 延时防止访问超频
# time.sleep(5)
first_memory = await self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
@@ -223,28 +234,9 @@ class Hippocampus:
for other_split_topic in topics:
if split_topic != other_split_topic:
self.memory_graph.connect_dot(split_topic, other_split_topic)
self.memory_graph.save_graph_to_db()
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
# print(topic_num)
topic_prompt = find_topic(input_text, topic_num)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
# print(topics)
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
print(f"空消息 跳过")
self.memory_graph.save_graph_to_db()
def segment_text(text):
@@ -260,16 +252,19 @@ def topic_what(text, topic):
return prompt
from nonebot import get_driver
driver = get_driver()
config = driver.config
start_time = time.time()
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
#创建记忆图
memory_graph = Memory_graph()

View File

@@ -14,6 +14,37 @@ sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
from src.plugins.memory_system.llm_module import LLMModel
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
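# Worked example (illustrative): for the text "aaab",
# p('a') = 3/4 and p('b') = 1/4, so
# entropy = -(0.75 * log2(0.75) + 0.25 * log2(0.25)) ≈ 0.811 bits.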
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录"""
chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
return chat_text
return ''
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -102,7 +133,8 @@ class Memory_graph:
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
@@ -110,6 +142,7 @@ class Memory_graph:
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
for record in chat_record:
if record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
@@ -187,156 +220,81 @@ class Memory_graph:
for edge in edges:
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
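# (Assumption) Edge documents read back above appear to be stored as
#   {"source": "<topic A>", "target": "<topic B>", "num": <co-occurrence count>}
# with "num" defaulting to 1 when the field is missing.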
def calculate_information_content(text):
# 海马体
class Hippocampus:
def __init__(self,memory_graph:Memory_graph):
self.memory_graph = memory_graph
self.llm_model = LLMModel()
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
"""计算文本的信息量(熵)"""
# 统计字符频率
char_count = Counter(text)
total_chars = len(text)
# 计算熵
entropy = 0
for count in char_count.values():
probability = count / total_chars
entropy -= probability * math.log2(probability)
return entropy
# Database.initialize(
# global_config.MONGODB_HOST,
# global_config.MONGODB_PORT,
# global_config.DATABASE_NAME
# )
# memory_graph = Memory_graph()
# llm_model = LLMModel()
# llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
# memory_graph.load_graph_from_db()
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
memory_graph = Memory_graph()
# 创建LLM模型实例
llm_model = LLMModel()
llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
# 使用当前时间戳进行测试
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
current_timestamp = datetime.datetime.now().timestamp()
chat_text = []
#短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): # 按 near 配置的采样次数循环
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('mid')): # 按 mid 配置的采样次数循环
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
for _ in range(time_frequency.get('far')): # 按 far 配置的采样次数循环
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
chat_text.append(chat_)
return chat_text
chat_size =25
def build_memory(self,chat_size=12):
#最近消息获取频率
time_frequency = {'near':1,'mid':2,'far':2}
memory_sample = self.get_memory_sample(chat_size,time_frequency)
for _ in range(30): # 循环10次
random_time = current_timestamp - random.randint(1, 3600*10) # 随机时间
print(f"随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
chat_ = memory_graph.get_random_chat_from_db(chat_size, random_time)
chat_text.append(chat_) # 拼接所有text
# time.sleep(1)
for i, input_text in enumerate(chat_text, 1):
progress = (i / len(chat_text)) * 100
#加载进度可视化
for i, input_text in enumerate(memory_sample, 1):
progress = (i / len(memory_sample)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(chat_text))
filled_length = int(bar_length * i // len(memory_sample))
bar = '█' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(chat_text)})")
# print(input_text)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# print(f"第{i}条消息: {input_text}")
if input_text:
# 生成压缩后记忆
first_memory = set()
first_memory = memory_compress(input_text, llm_model_small, llm_model_small, rate=2.5)
# time.sleep(5)
first_memory = self.memory_compress(input_text, 2.5)
#将记忆加入到图谱中
for topic, memory in first_memory:
topics = segment_text(topic)
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
for split_topic in topics:
memory_graph.add_dot(split_topic,memory)
self.memory_graph.add_dot(split_topic,memory)
for split_topic in topics:
for other_split_topic in topics:
if split_topic != other_split_topic:
memory_graph.connect_dot(split_topic, other_split_topic)
# memory_graph.store_memory()
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
memory_graph.save_graph_to_db()
# memory_graph.load_graph_from_db()
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
self.memory_graph.connect_dot(split_topic, other_split_topic)
else:
print("未找到相关记忆。")
print(f"空消息 跳过")
while True:
query = input("请输入问题:")
self.memory_graph.save_graph_to_db()
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
def memory_compress(input_text, llm_model, llm_model_small, rate=1):
def memory_compress(self, input_text, rate=1):
information_content = calculate_information_content(input_text)
print(f"文本的信息量(熵): {information_content:.4f} bits")
topic_num = max(1, min(5, int(information_content * rate / 4)))
print(topic_num)
topic_prompt = find_topic(input_text, topic_num)
topic_response = llm_model.generate_response(topic_prompt)
topic_response = self.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
compressed_memory = set()
for topic in topics:
topic_what_prompt = topic_what(input_text,topic)
topic_what_response = llm_model_small.generate_response(topic_what_prompt)
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
return compressed_memory
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
@@ -356,18 +314,37 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
memory_items = H.nodes[node].get('memory_items', [])
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
degree = H.degree(node)
if memory_count <= 1 or degree <= 2:
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
nx.write_gml(G, "memory_graph.gml") # 保存为 GML 格式
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
# 根据连接条数或记忆数量设置节点颜色
node_colors = []
nodes = list(G.nodes()) # 获取图中实际的节点列表
nodes = list(H.nodes()) # 获取图中实际的节点列表
if color_by_memory:
# 计算每个节点的记忆数量
memory_counts = []
for node in nodes:
memory_items = G.nodes[node].get('memory_items', [])
memory_items = H.nodes[node].get('memory_items', [])
if isinstance(memory_items, list):
count = len(memory_items)
else:
@@ -385,9 +362,9 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
node_colors.append(color)
else:
# 使用原来的连接数量着色方案
max_degree = max(G.degree(), key=lambda x: x[1])[1] if G.degree() else 1
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
for node in nodes:
degree = G.degree(node)
degree = H.degree(node)
if max_degree > 0:
red = min(1.0, degree / max_degree)
blue = 1.0 - red
@@ -398,8 +375,8 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, k=1, iterations=50)
nx.draw(G, pos,
pos = nx.spring_layout(H, k=1, iterations=50)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=2000,
@@ -411,6 +388,71 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
start_time = time.time()
# 创建记忆图
memory_graph = Memory_graph()
# 加载数据库中存储的记忆图
memory_graph.load_graph_from_db()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
hippocampus.build_memory(chat_size=25)
# 展示两种不同的可视化方式
print("\n按连接数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=False)
print("\n按记忆数量着色的图谱:")
visualize_graph(memory_graph, color_by_memory=True)
# 交互式查询
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
for memory_item in items_list:
print(memory_item)
else:
print("未找到相关记忆。")
while True:
query = input("请输入问题:")
if query.lower() == '退出':
break
topic_prompt = find_topic(query, 3)
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
# 检查 topic_response 是否为元组
if isinstance(topic_response, tuple):
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
else:
topics = topic_response.split(",")
print(topics)
for keyword in topics:
items_list = memory_graph.get_related_item(keyword)
if items_list:
print(items_list)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,199 @@
import aiohttp
import asyncio
import requests
import time
from typing import Tuple, Union
from nonebot import get_driver
from ..chat.config import global_config
driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
# 按 model 字典里给出的键名从 config 中读取实际的 API key 与 base_url
try:
self.api_key = getattr(config, model["key"])
self.base_url = getattr(config, model["base_url"])
except AttributeError as e:
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}")
self.model_name = model["name"]
self.params = kwargs
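# (Assumption) Plausible shape of the model config dict consumed above; the
# actual values live in the bot's config, and the exact attribute casing
# depends on how nonebot exposes the .env entries:
# global_config.llm_normal = {
#     "name": "Pro/deepseek-ai/DeepSeek-V3",  # model name sent to the API
#     "key": "SILICONFLOW_KEY",               # config attribute holding the API key
#     "base_url": "SILICONFLOW_BASE_URL",     # config attribute holding the endpoint
# }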
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的异步响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""同步方法:根据输入的提示和图片生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""

View File

@@ -1,37 +1,47 @@
import datetime
import os
from typing import List, Dict
from .schedule_llm_module import LLMModel
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
from src.plugins.chat.config import global_config
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
)
class ScheduleGenerator:
def __init__(self):
if global_config.API_USING == "siliconflow":
self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3")
elif global_config.API_USING == "deepseek":
self.llm_scheduler = LLMModel(model_name="deepseek-chat",api_using="deepseek")
#根据global_config.llm_normal这一字典配置指定模型
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model = global_config.llm_normal,temperature=0.9)
self.db = Database.get_instance()
self.today_schedule_text = ""
self.today_schedule = {}
self.tomorrow_schedule_text = ""
self.tomorrow_schedule = {}
self.yesterday_schedule_text = ""
self.yesterday_schedule = {}
async def initialize(self):
today = datetime.datetime.now()
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
self.today_schedule_text, self.today_schedule = self.generate_daily_schedule(target_date=today)
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today)
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,read_only=True)
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
self.tomorrow_schedule_text, self.tomorrow_schedule = self.generate_daily_schedule(target_date=tomorrow,read_only=True)
self.yesterday_schedule_text, self.yesterday_schedule = self.generate_daily_schedule(target_date=yesterday,read_only=True)
def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
if target_date is None:
target_date = datetime.datetime.now()
@@ -55,7 +65,7 @@ class ScheduleGenerator:
3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床""""
schedule_text, _ = self.llm_scheduler.generate_response(prompt)
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
# print(self.schedule_text)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
else:

View File

@@ -1,59 +0,0 @@
import os
import requests
from typing import Tuple, Union
class LLMModel:
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1",api_using=None, **kwargs):
if api_using == "deepseek":
self.api_key = os.getenv("DEEP_SEEK_KEY")
self.base_url = os.getenv("DEEP_SEEK_BASE_URL")
if model_name != "Pro/deepseek-ai/DeepSeek-R1":
self.model_name = model_name
else:
self.model_name = "deepseek-reasoner"
else:
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
self.model_name = model_name
self.params = kwargs
def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.9,
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
try:
response = requests.post(api_url, headers=headers, json=data)
response.raise_for_status() # 检查响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content # 返回内容和推理内容
return "没有返回结果", "" # 返回两个值
except requests.exceptions.RequestException as e:
return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串
# 示例用法
if __name__ == "__main__":
model = LLMModel() # 默认使用 DeepSeek-V3 模型
prompt = "你好,你喜欢我吗?"
result, reasoning = model.generate_response(prompt)
print("回复内容:", result)
print("推理内容:", reasoning)

70
src/test/emotion_cal.py Normal file
View File

@@ -0,0 +1,70 @@
from textblob import TextBlob
import jieba
from translate import Translator
def analyze_emotion(text):
"""
分析文本的情感,返回情感极性和主观性得分
:param text: 输入文本
:return: (情感极性, 主观性) 元组
情感极性: -1(非常消极) 到 1(非常积极)
主观性: 0(客观) 到 1(主观)
"""
try:
# 创建翻译器
translator = Translator(to_lang="en", from_lang="zh")
# 如果是中文文本,先翻译成英文
# 因为TextBlob的情感分析主要基于英文
translated_text = translator.translate(text)
# 创建TextBlob对象
blob = TextBlob(translated_text)
# 获取情感极性和主观性
polarity = blob.sentiment.polarity
subjectivity = blob.sentiment.subjectivity
return polarity, subjectivity
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None, None
def get_emotion_description(polarity, subjectivity):
"""
根据情感极性和主观性生成描述性文字
"""
if polarity is None or subjectivity is None:
return "无法分析情感"
# 情感极性描述
if polarity > 0.5:
emotion = "非常积极"
elif polarity > 0:
emotion = "较为积极"
elif polarity == 0:
emotion = "中性"
elif polarity > -0.5:
emotion = "较为消极"
else:
emotion = "非常消极"
# 主观性描述
if subjectivity > 0.7:
subj = "非常主观"
elif subjectivity > 0.3:
subj = "较为主观"
else:
subj = "较为客观"
return f"情感倾向: {emotion}, 表达方式: {subj}"
if __name__ == "__main__":
# 测试样例
test_text = "今天天气真好,我感到非常开心!"
polarity, subjectivity = analyze_emotion(test_text)
print(f"测试文本: {test_text}")
print(f"情感极性: {polarity:.2f}")
print(f"主观性得分: {subjectivity:.2f}")
print(get_emotion_description(polarity, subjectivity))
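For reference, a minimal direct TextBlob call on already-English text (skipping the translation step; the sample sentence is illustrative):

from textblob import TextBlob

blob = TextBlob("The weather is great and I feel very happy!")
print(blob.sentiment.polarity)      # -1 (very negative) .. 1 (very positive)
print(blob.sentiment.subjectivity)  # 0 (objective) .. 1 (subjective)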

View File

@@ -0,0 +1,74 @@
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
def setup_bert_analyzer():
"""
设置中文BERT情感分析器
"""
# 使用专门针对中文情感分析的模型
model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
try:
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 创建情感分析pipeline
analyzer = pipeline("sentiment-analysis",
model=model,
tokenizer=tokenizer)
return analyzer
except Exception as e:
print(f"模型加载错误: {str(e)}")
return None
def analyze_emotion_bert(text, analyzer):
"""
使用BERT模型进行中文情感分析
"""
try:
if not analyzer:
return None
# 进行情感分析
result = analyzer(text)[0]
return {
'label': result['label'],
'score': result['score']
}
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_bert(result):
"""
将BERT的情感分析结果转换为描述性文字
"""
if not result:
return "无法分析情感"
label = "积极" if result['label'] == 'positive' else "消极"
confidence = result['score']
if confidence > 0.9:
strength = "强烈"
elif confidence > 0.7:
strength = "明显"
else:
strength = "轻微"
return f"{strength}{label}"
if __name__ == "__main__":
# 初始化分析器
analyzer = setup_bert_analyzer()
# 测试样例
test_text = "这个产品质量很好,使用起来非常方便,推荐购买!"
result = analyze_emotion_bert(test_text, analyzer)
print(f"测试文本: {test_text}")
if result:
print(f"情感倾向: {get_emotion_description_bert(result)}")
print(f"置信度: {result['score']:.2f}")

View File

@@ -0,0 +1,62 @@
import hanlp
def analyze_emotion_hanlp(text):
"""
使用HanLP进行中文情感分析
"""
try:
# 使用更基础的模型
tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
# 分词
words = tokenizer(text)
# 简单的情感词典方法
positive_words = {'好', '棒', '优秀', '喜欢', '开心', '快乐', '美味', '推荐', '优质', '满意'}
negative_words = {'差', '坏', '烂', '讨厌', '失望', '难受', '恶心', '不满', '差劲', '垃圾'}
# 计算情感得分
score = 0
for word in words:
if word in positive_words:
score += 1
elif word in negative_words:
score -= 1
# 归一化得分
if score > 0:
return 1
elif score < 0:
return 0
else:
return 0.5
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_hanlp(score):
"""
将HanLP的情感分析结果转换为描述性文字
"""
if score is None:
return "无法分析情感"
elif score == 1:
return "积极"
elif score == 0:
return "消极"
else:
return "中性"
if __name__ == "__main__":
# 测试样例
test_texts = [
"这家餐厅的服务态度很好,菜品也很美味!",
"这个产品质量太差了,一点都不值这个价",
"今天天气不错,但是工作很累"
]
for test_text in test_texts:
result = analyze_emotion_hanlp(test_text)
print(f"\n测试文本: {test_text}")
print(f"情感倾向: {get_emotion_description_hanlp(result)}")

View File

@@ -0,0 +1,53 @@
from snownlp import SnowNLP
def analyze_emotion_snownlp(text):
"""
使用SnowNLP进行中文情感分析
:param text: 输入文本
:return: 情感得分(0-1之间越接近1越积极)
"""
try:
s = SnowNLP(text)
sentiment_score = s.sentiments
# 获取文本的关键词
keywords = s.keywords(3)
return {
'sentiment_score': sentiment_score,
'keywords': keywords,
'summary': s.summary(1) # 生成文本摘要
}
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_snownlp(score):
"""
将情感得分转换为描述性文字
"""
if score is None:
return "无法分析情感"
if score > 0.8:
return "非常积极"
elif score > 0.6:
return "较为积极"
elif score > 0.4:
return "中性偏积极"
elif score > 0.2:
return "中性偏消极"
else:
return "消极"
if __name__ == "__main__":
# 测试样例
test_text = "我们学校有免费的gpt4用"
result = analyze_emotion_snownlp(test_text)
if result:
print(f"测试文本: {test_text}")
print(f"情感得分: {result['sentiment_score']:.2f}")
print(f"情感倾向: {get_emotion_description_snownlp(result['sentiment_score'])}")
print(f"关键词: {', '.join(result['keywords'])}")
print(f"文本摘要: {result['summary'][0]}")

54
src/test/snownlp_demo.py Normal file
View File

@@ -0,0 +1,54 @@
from snownlp import SnowNLP
def demo_snownlp_features(text):
"""
展示SnowNLP的主要功能
:param text: 输入文本
"""
print(f"\n=== SnowNLP功能演示 ===")
print(f"输入文本: {text}")
# 创建SnowNLP对象
s = SnowNLP(text)
# 1. 分词
print(f"\n1. 分词结果:")
print(f" {' | '.join(s.words)}")
# 2. 情感分析
print(f"\n2. 情感分析:")
sentiment = s.sentiments
print(f" 情感得分: {sentiment:.2f}")
print(f" 情感倾向: {'积极' if sentiment > 0.5 else '消极' if sentiment < 0.5 else '中性'}")
# 3. 关键词提取
print(f"\n3. 关键词提取:")
print(f" {', '.join(s.keywords(3))}")
# 4. 词性标注
print(f"\n4. 词性标注:")
print(f" {' '.join([f'{word}/{tag}' for word, tag in s.tags])}")
# 5. 拼音转换
print(f"\n5. 拼音:")
print(f" {' '.join(s.pinyin)}")
# 6. 文本摘要
if len(text) > 100: # 只对较长文本生成摘要
print(f"\n6. 文本摘要:")
print(f" {' '.join(s.summary(3))}")
if __name__ == "__main__":
# 测试用例
test_texts = [
"这家新开的餐厅很不错,菜品种类丰富,味道可口,服务态度也很好,价格实惠,强烈推荐大家来尝试!",
"这部电影剧情混乱,演技浮夸,特效粗糙,配乐难听,完全浪费了我的时间和票价。",
"""人工智能正在改变我们的生活方式。它能够帮助我们完成复杂的计算任务,
提供个性化的服务推荐,优化交通路线,辅助医疗诊断。但同时我们也要警惕
人工智能带来的问题,比如隐私安全、就业变化等。如何正确认识和利用人工智能,
是我们每个人都需要思考的问题。"""
]
for text in test_texts:
demo_snownlp_features(text)
print("\n" + "="*50)

488
src/test/typo.py Normal file
View File

@@ -0,0 +1,488 @@
"""
错别字生成器 - 流程说明
整体替换逻辑:
1. 数据准备
- 加载字频词典使用jieba词典计算汉字使用频率
- 创建拼音映射:建立拼音到汉字的映射关系
- 加载词频信息从jieba词典获取词语使用频率
2. 分词处理
- 使用jieba将输入句子分词
- 区分单字词和多字词
- 保留标点符号和空格
3. 词语级别替换(针对多字词)
- 触发条件:词长>1 且 随机概率<0.3
- 替换流程:
a. 获取词语拼音
b. 生成所有可能的同音字组合
c. 过滤条件:
- 必须是jieba词典中的有效词
- 词频必须达到原词频的10%以上
- 综合评分(词频70%+字频30%)必须达到阈值
d. 按综合评分排序,选择最合适的替换词
4. 字级别替换(针对单字词或未进行整词替换的多字词)
- 单字替换概率0.3
- 多字词中的单字替换概率0.3 * (0.7 ^ (词长-1))
- 替换流程:
a. 获取字的拼音
b. 声调错误处理20%概率)
c. 获取同音字列表
d. 过滤条件:
- 字频必须达到最小阈值
- 频率差异不能过大(指数衰减计算)
e. 按频率排序选择替换字
5. 频率控制机制
- 字频控制使用归一化的字频0-1000范围
- 词频控制使用jieba词典中的词频
- 频率差异计算:使用指数衰减函数
- 最小频率阈值:确保替换字/词不会太生僻
6. 输出信息
- 原文和错字版本的对照
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
- 替换类型说明(整词替换/声调错误/同音字替换)
- 词语分析和完整拼音
注意事项:
1. 所有替换都必须使用有意义的词语
2. 替换词的使用频率不能过低
3. 多字词优先考虑整词替换
4. 考虑声调变化的情况
5. 保持标点符号和空格不变
"""
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
import time
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = pinyin_dict[py]
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2' 或 'ni4'
处理特殊情况:
1. 轻声(如 'de5''le'
2. 非数字结尾的拼音
"""
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + '1'
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
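# Examples of the tone handling above (illustrative):
#   'ni3' -> one of 'ni1' / 'ni2' / 'ni4'  (the original tone is excluded)
#   'de'  -> 'de1'                         (no trailing digit: treated as tone 1)
#   'le5' -> 'le' + a random tone in 1-4   (neutral tone 5 is replaced)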
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
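# Worked example (illustrative): with orig_freq = 100, target_freq = 50 and
# max_freq_diff = 200, freq_diff = 50 and the probability is
# exp(-3 * 50 / 200) = exp(-0.75) ≈ 0.47; a target more frequent than the
# original returns 1.0, and freq_diff > 200 returns 0.0.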
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def get_word_pinyin(word):
"""
获取词语的拼音列表
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def segment_sentence(sentence):
"""
使用jieba分词返回词语列表
"""
return list(jieba.cut(sentence))
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
"""
获取整个词的同音词,只返回高频的有意义词语
:param word: 输入词语
:param pinyin_dict: 拼音字典
:param char_frequency: 字频字典
:param min_freq: 最小频率阈值
:return: 同音词列表
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = get_word_pinyin(word)
word_pinyin_str = ''.join(word_pinyin)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
chars = pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
freq_value = float(parts[1]) # 获取词频(避免与外层变量重名)
valid_words[word_text] = freq_value
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
new_word = ''.join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# 只保留词频达到阈值的词
if new_word_freq >= min_word_freq:
# 计算词的平均字频(考虑字频和词频)
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
# 综合评分:结合词频和字频
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
"""
创建包含同音字错误的句子,支持词语级别和字级别的替换
只使用高频的有意义词语进行替换
"""
result = []
typo_info = []
# 分词
words = segment_sentence(sentence)
for word in words:
# 如果是标点符号或空格,直接添加
if all(not is_chinese_char(c) for c in word):
result.append(word)
continue
# 获取词语的拼音
word_pinyin = get_word_pinyin(word)
# 尝试整词替换
if len(word) > 1 and random.random() < word_replace_rate:
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
if word_homophones:
typo_word = random.choice(word_homophones)
# 计算词的平均频率
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# 如果不进行整词替换,则进行单字替换
if len(word) == 1:
char = word
py = word_pinyin[0]
if random.random() < error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
result.append(char)
else:
# 处理多字词的单字替换
word_result = []
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
word_result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 记录开始时间
start_time = time.time()
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
error_rate=0.3, min_freq=5,
tone_error_rate=0.2, word_replace_rate=0.3)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换
is_word = ' ' in orig_py
if is_word:
error_type = "整词替换"
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印词语分析
print("\n词语分析:")
words = segment_sentence(sentence)
for word in words:
if any(is_chinese_char(c) for c in word):
word_pinyin = get_word_pinyin(word)
print(f"词语:{word}")
print(f"拼音:{' '.join(word_pinyin)}")
print("---")
# 计算并打印总耗时
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()

301
src/test/typo_word.py Normal file
View File

@@ -0,0 +1,301 @@
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = pinyin_dict[py]
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2' 或 'ni4'
"""
# 防御:轻声或无声调数字的拼音(如 'de'、'le5')会让 int()/remove() 抛异常,与 typo.py 中的处理保持一致
if not py or not py[-1].isdigit():
return (py + '1') if py else py
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2):
"""
创建包含同音字错误的句子,保留原文标点符号
"""
result = []
typo_info = []
# 获取每个字的拼音
chars_with_pinyin = get_pinyin(sentence)
# 创建原字到拼音的映射,用于跟踪已处理的字符
processed_chars = {char: py for char, py in chars_with_pinyin}
# 遍历原句中的每个字符
char_index = 0
for i, char in enumerate(sentence):
if char.isspace():
# 保留空格
result.append(char)
elif char in processed_chars:
# 处理汉字
py = processed_chars[char]
# 基础错误率
if random.random() < error_rate:
# 获取频率相近的同音字(可能包含声调错误)
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
# 随机选择一个替换字
typo_char = random.choice(similar_chars)
# 获取替换字的频率
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
# 计算实际替换概率
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
# 根据频率差进行概率替换
if random.random() < replace_prob:
result.append(typo_char)
# 获取替换字的实际拼音
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
else:
result.append(char)
else:
result.append(char)
else:
result.append(char)
char_index += 1
else:
# 保留非汉字字符(标点符号等)
result.append(char)
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
min_freq=5, tone_error_rate=0.2)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原字:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"错字:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印所有可能的同音字
print("\n每个字的所有同音字(按频率排序,仅显示频率>=5的字")
for char, py in result:
homophones = get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5)
char_freq = char_frequency.get(char, 0)
print(f"{char}: {py} [频率:{format_frequency(char_freq)}]")
if homophones:
homophone_info = []
for h in homophones:
h_freq = char_frequency.get(h, 0)
homophone_info.append(f"{h}[{format_frequency(h_freq)}]")
print(f"同音字: {''.join(homophone_info)}")
else:
print("没有找到频率>=5的同音字")
if __name__ == "__main__":
main()