26
.env
26
.env
@@ -1,26 +0,0 @@
|
|||||||
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
|
||||||
ENVIRONMENT=prod
|
|
||||||
# HOST=127.0.0.1
|
|
||||||
# PORT=8080
|
|
||||||
|
|
||||||
# COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# # 插件配置
|
|
||||||
# PLUGINS=["src2.plugins.chat"]
|
|
||||||
|
|
||||||
# # 默认配置
|
|
||||||
# MONGODB_HOST=127.0.0.1
|
|
||||||
# MONGODB_PORT=27017
|
|
||||||
# DATABASE_NAME=MegBot
|
|
||||||
|
|
||||||
# MONGODB_USERNAME = "" # 默认空值
|
|
||||||
# MONGODB_PASSWORD = "" # 默认空值
|
|
||||||
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
|
||||||
|
|
||||||
# #key and url
|
|
||||||
# CHAT_ANY_WHERE_KEY=
|
|
||||||
# SILICONFLOW_KEY=
|
|
||||||
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
|
||||||
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
|
||||||
# DEEP_SEEK_KEY=
|
|
||||||
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
|
||||||
30
.github/workflows/docker-image.yml
vendored
30
.github/workflows/docker-image.yml
vendored
@@ -3,10 +3,11 @@ name: Docker Build and Push
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main # 推送到main分支时触发
|
- main
|
||||||
|
- debug # 新增 debug 分支触发
|
||||||
tags:
|
tags:
|
||||||
- 'v*' # 推送v开头的tag时触发(例如v1.0.0)
|
- 'v*'
|
||||||
workflow_dispatch: # 允许手动触发
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-and-push:
|
build-and-push:
|
||||||
@@ -24,13 +25,24 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Determine Image Tags
|
||||||
|
id: tags
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
|
||||||
|
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }},${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
|
||||||
|
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
|
||||||
|
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:main,${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest" >> $GITHUB_OUTPUT
|
||||||
|
elif [ "${{ github.ref }}" == "refs/heads/debug" ]; then
|
||||||
|
echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:debug" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Build and Push Docker Image
|
- name: Build and Push Docker Image
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: . # Docker构建上下文路径
|
context: .
|
||||||
file: ./Dockerfile # Dockerfile路径
|
file: ./Dockerfile
|
||||||
platforms: linux/amd64,linux/arm64 # 支持arm架构
|
platforms: linux/amd64,linux/arm64
|
||||||
tags: |
|
tags: ${{ steps.tags.outputs.tags }}
|
||||||
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:${{ github.ref_name }}
|
|
||||||
${{ secrets.DOCKERHUB_USERNAME }}/maimbot:latest
|
|
||||||
push: true
|
push: true
|
||||||
|
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache
|
||||||
|
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot:buildcache,mode=max
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -3,15 +3,17 @@ mongodb/
|
|||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
log/
|
log/
|
||||||
/test
|
/test
|
||||||
|
/src/test
|
||||||
message_queue_content.txt
|
message_queue_content.txt
|
||||||
message_queue_content.bat
|
message_queue_content.bat
|
||||||
message_queue_window.bat
|
message_queue_window.bat
|
||||||
message_queue_window.txt
|
message_queue_window.txt
|
||||||
queue_update.txt
|
queue_update.txt
|
||||||
memory_graph.gml
|
memory_graph.gml
|
||||||
|
.env
|
||||||
.env.*
|
.env.*
|
||||||
config/bot_config_dev.toml
|
config/bot_config_dev.toml
|
||||||
|
config/bot_config.toml
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
@@ -183,3 +185,6 @@ cython_debug/
|
|||||||
# PyPI configuration file
|
# PyPI configuration file
|
||||||
.pypirc
|
.pypirc
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
# jieba
|
||||||
|
jieba.cache
|
||||||
|
|||||||
38
README.md
38
README.md
@@ -32,7 +32,7 @@
|
|||||||
> - QQ机器人存在被限制风险,请自行了解,谨慎使用
|
> - QQ机器人存在被限制风险,请自行了解,谨慎使用
|
||||||
> - 由于持续迭代,可能存在一些已知或未知的bug
|
> - 由于持续迭代,可能存在一些已知或未知的bug
|
||||||
|
|
||||||
**交流群**: 766798517(仅用于开发和建议相关讨论)
|
**交流群**: 766798517(仅用于开发和建议相关讨论)不建议在群内询问部署问题,我不一定有空回复,会优先写文档和代码
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
|
|
||||||
@@ -42,22 +42,22 @@
|
|||||||
## 🎯 功能介绍
|
## 🎯 功能介绍
|
||||||
|
|
||||||
### 💬 聊天功能
|
### 💬 聊天功能
|
||||||
- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言,目前有bug,所以现在只会检测主题,不会进行存储
|
- 支持关键词检索主动发言:对消息的话题topic进行识别,如果检测到麦麦存储过的话题就会主动进行发言
|
||||||
- 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置
|
- 支持bot名字呼唤发言:检测到"麦麦"会主动发言,可配置
|
||||||
- 使用硅基流动的api进行回复生成,可随机使用R1,V3,R1-distill等模型,未来将加入官网api支持
|
- 支持多模型,多厂商自定义配置
|
||||||
- 动态的prompt构建器,更拟人
|
- 动态的prompt构建器,更拟人
|
||||||
- 支持图片,转发消息,回复消息的识别
|
- 支持图片,转发消息,回复消息的识别
|
||||||
- 错别字和多条回复功能:麦麦可以随机生成错别字,会多条发送回复以及对消息进行reply
|
- 错别字和多条回复功能:麦麦可以随机生成错别字,会多条发送回复以及对消息进行reply
|
||||||
|
|
||||||
### 😊 表情包功能
|
### 😊 表情包功能
|
||||||
- 支持根据发言内容发送对应情绪的表情包:未完善,可以用
|
- 支持根据发言内容发送对应情绪的表情包
|
||||||
- 会自动偷群友的表情包(未完善,暂时禁用)目前有bug
|
- 会自动偷群友的表情包
|
||||||
|
|
||||||
### 📅 日程功能
|
### 📅 日程功能
|
||||||
- 麦麦会自动生成一天的日程,实现更拟人的回复
|
- 麦麦会自动生成一天的日程,实现更拟人的回复
|
||||||
|
|
||||||
### 🧠 记忆功能
|
### 🧠 记忆功能
|
||||||
- 对聊天记录进行概括存储,在需要时调用,没写完
|
- 对聊天记录进行概括存储,在需要时调用,待完善
|
||||||
|
|
||||||
### 📚 知识库功能
|
### 📚 知识库功能
|
||||||
- 基于embedding模型的知识库,手动放入txt会自动识别,写完了,暂时禁用
|
- 基于embedding模型的知识库,手动放入txt会自动识别,写完了,暂时禁用
|
||||||
@@ -66,25 +66,27 @@
|
|||||||
- 针对每个用户创建"关系",可以对不同用户进行个性化回复,目前只有极其简单的好感度(WIP)
|
- 针对每个用户创建"关系",可以对不同用户进行个性化回复,目前只有极其简单的好感度(WIP)
|
||||||
- 针对每个群创建"群印象",可以对不同群进行个性化回复(WIP)
|
- 针对每个群创建"群印象",可以对不同群进行个性化回复(WIP)
|
||||||
|
|
||||||
## 🚧 开发中功能
|
|
||||||
|
|
||||||
|
## 开发计划TODO:LIST
|
||||||
- 人格功能:WIP
|
- 人格功能:WIP
|
||||||
- 群氛围功能:WIP
|
- 群氛围功能:WIP
|
||||||
- 图片发送,转发功能:WIP
|
- 图片发送,转发功能:WIP
|
||||||
- 幽默和meme功能:WIP的WIP
|
- 幽默和meme功能:WIP的WIP
|
||||||
- 让麦麦玩mc:WIP的WIP的WIP
|
- 让麦麦玩mc:WIP的WIP的WIP
|
||||||
|
|
||||||
## 开发计划TODO:LIST
|
|
||||||
|
|
||||||
- 兼容gif的解析和保存
|
- 兼容gif的解析和保存
|
||||||
- 小程序转发链接解析
|
- 小程序转发链接解析
|
||||||
- 对思考链长度限制
|
- 对思考链长度限制
|
||||||
- 修复已知bug
|
- 修复已知bug
|
||||||
- 完善文档
|
- ~~完善文档~~
|
||||||
- 修复转发
|
- 修复转发
|
||||||
- config自动生成和检测
|
- ~~config自动生成和检测~~
|
||||||
- log别用print
|
- ~~log别用print~~
|
||||||
- 给发送消息写专门的类
|
- ~~给发送消息写专门的类~~
|
||||||
- 改进表情包发送逻辑
|
- 改进表情包发送逻辑
|
||||||
|
- 自动生成的回复逻辑,例如自生成的回复方向,回复风格
|
||||||
|
- 采用截断生成加快麦麦的反应速度
|
||||||
|
- 改进发送消息的触发:
|
||||||
|
|
||||||
## 📌 注意事项
|
## 📌 注意事项
|
||||||
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
纯编程外行,面向cursor编程,很多代码史一样多多包涵
|
||||||
@@ -99,4 +101,10 @@
|
|||||||
|
|
||||||
感谢各位大佬!
|
感谢各位大佬!
|
||||||
|
|
||||||
[](https://github.com/SengokuCola/MaiMBot/graphs/contributors)
|
<a href="https://github.com/SengokuCola/MaiMBot/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=SengokuCola/MaiMBot&time=true" />
|
||||||
|
</a>
|
||||||
|
|
||||||
|
|
||||||
|
## Stargazers over time
|
||||||
|
[](https://starchart.cc/SengokuCola/MaiMBot)
|
||||||
50
bot.py
50
bot.py
@@ -6,6 +6,7 @@ from loguru import logger
|
|||||||
|
|
||||||
'''彩蛋'''
|
'''彩蛋'''
|
||||||
from colorama import init, Fore
|
from colorama import init, Fore
|
||||||
|
|
||||||
init()
|
init()
|
||||||
text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
text = "多年以后,面对行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||||
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
||||||
@@ -15,25 +16,47 @@ for i, char in enumerate(text):
|
|||||||
print(rainbow_text)
|
print(rainbow_text)
|
||||||
'''彩蛋'''
|
'''彩蛋'''
|
||||||
|
|
||||||
# 首先加载基础环境变量
|
# 初次启动检测
|
||||||
|
if not os.path.exists("config/bot_config.toml") or not os.path.exists(".env"):
|
||||||
|
logger.info("检测到bot_config.toml不存在,正在从模板复制")
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.copy("config/bot_config_template.toml", "config/bot_config.toml")
|
||||||
|
logger.info("复制完成,请修改config/bot_config.toml和.env.prod中的配置后重新启动")
|
||||||
|
|
||||||
|
# 初始化.env 默认ENVIRONMENT=prod
|
||||||
|
if not os.path.exists(".env"):
|
||||||
|
with open(".env", "w") as f:
|
||||||
|
f.write("ENVIRONMENT=prod")
|
||||||
|
|
||||||
|
# 检测.env.prod文件是否存在
|
||||||
|
if not os.path.exists(".env.prod"):
|
||||||
|
logger.error("检测到.env.prod文件不存在")
|
||||||
|
shutil.copy("template.env", "./.env.prod")
|
||||||
|
|
||||||
|
# 首先加载基础环境变量.env
|
||||||
if os.path.exists(".env"):
|
if os.path.exists(".env"):
|
||||||
load_dotenv(".env")
|
load_dotenv(".env")
|
||||||
logger.success("成功加载基础环境变量配置")
|
logger.success("成功加载基础环境变量配置")
|
||||||
else:
|
|
||||||
logger.error("基础环境变量配置文件 .env 不存在")
|
|
||||||
exit(1)
|
|
||||||
# 根据 ENVIRONMENT 加载对应的环境配置
|
|
||||||
env = os.getenv("ENVIRONMENT")
|
|
||||||
env_file = f".env.{env}"
|
|
||||||
|
|
||||||
if env_file == ".env.dev" and os.path.exists(env_file):
|
# 根据 ENVIRONMENT 加载对应的环境配置
|
||||||
logger.success("加载开发环境变量配置")
|
if os.getenv("ENVIRONMENT") == "prod":
|
||||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
logger.success("加载生产环境变量配置")
|
||||||
elif os.path.exists(".env.prod"):
|
|
||||||
logger.success("加载环境变量配置")
|
|
||||||
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
|
elif os.getenv("ENVIRONMENT") == "dev":
|
||||||
|
logger.success("加载开发环境变量配置")
|
||||||
|
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
|
elif os.path.exists(f".env.{os.getenv('ENVIRONMENT')}"):
|
||||||
|
logger.success(f"加载{os.getenv('ENVIRONMENT')}环境变量配置")
|
||||||
|
load_dotenv(f".env.{os.getenv('ENVIRONMENT')}", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
else:
|
else:
|
||||||
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
logger.error(f"ENVIRONMENT配置错误,请检查.env文件中的ENVIRONMENT变量对应的.env.{os.getenv('ENVIRONMENT')}是否存在")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# 检测Key是否存在
|
||||||
|
if not os.getenv("SILICONFLOW_KEY"):
|
||||||
|
logger.error("缺失必要的API KEY")
|
||||||
|
logger.error(f"请至少在.env.{os.getenv('ENVIRONMENT')}文件中填写SILICONFLOW_KEY后重新启动")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# 获取所有环境变量
|
# 获取所有环境变量
|
||||||
@@ -57,5 +80,4 @@ driver.register_adapter(Adapter)
|
|||||||
nonebot.load_plugins("src/plugins")
|
nonebot.load_plugins("src/plugins")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
nonebot.run()
|
nonebot.run()
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
[bot]
|
|
||||||
qq = 123
|
|
||||||
nickname = "麦麦"
|
|
||||||
|
|
||||||
[message]
|
|
||||||
min_text_length = 2
|
|
||||||
max_context_size = 15
|
|
||||||
emoji_chance = 0.2
|
|
||||||
|
|
||||||
[emoji]
|
|
||||||
check_interval = 120
|
|
||||||
register_interval = 10
|
|
||||||
|
|
||||||
[cq_code]
|
|
||||||
enable_pic_translate = false
|
|
||||||
|
|
||||||
[response]
|
|
||||||
api_using = "siliconflow"
|
|
||||||
api_paid = true
|
|
||||||
model_r1_probability = 0.8
|
|
||||||
model_v3_probability = 0.1
|
|
||||||
model_r1_distill_probability = 0.1
|
|
||||||
|
|
||||||
[memory]
|
|
||||||
build_memory_interval = 300
|
|
||||||
|
|
||||||
[others]
|
|
||||||
enable_advance_output = true
|
|
||||||
|
|
||||||
[groups]
|
|
||||||
talk_allowed = [
|
|
||||||
123,
|
|
||||||
123,
|
|
||||||
]
|
|
||||||
talk_frequency_down = []
|
|
||||||
ban_user_id = []
|
|
||||||
|
|
||||||
[model.llm_reasoning]
|
|
||||||
name = "Pro/deepseek-ai/DeepSeek-R1"
|
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
|
||||||
key = "SILICONFLOW_KEY"
|
|
||||||
|
|
||||||
[model.llm_reasoning_minor]
|
|
||||||
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
|
||||||
key = "SILICONFLOW_KEY"
|
|
||||||
|
|
||||||
[model.llm_normal]
|
|
||||||
name = "Pro/deepseek-ai/DeepSeek-V3"
|
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
|
||||||
key = "SILICONFLOW_KEY"
|
|
||||||
|
|
||||||
[model.llm_normal_minor]
|
|
||||||
name = "deepseek-ai/DeepSeek-V2.5"
|
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
|
||||||
key = "SILICONFLOW_KEY"
|
|
||||||
|
|
||||||
[model.vlm]
|
|
||||||
name = "deepseek-ai/deepseek-vl2"
|
|
||||||
base_url = "SILICONFLOW_BASE_URL"
|
|
||||||
key = "SILICONFLOW_KEY"
|
|
||||||
98
config/bot_config_template.toml
Normal file
98
config/bot_config_template.toml
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
[bot]
|
||||||
|
qq = 123
|
||||||
|
nickname = "麦麦"
|
||||||
|
|
||||||
|
[personality]
|
||||||
|
prompt_personality = [
|
||||||
|
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", # 贴吧人格
|
||||||
|
"是一个女大学生,你有黑色头发,你会刷小红书" # 小红书人格
|
||||||
|
]
|
||||||
|
prompt_schedule = "一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"
|
||||||
|
|
||||||
|
[message]
|
||||||
|
min_text_length = 2 # 与麦麦聊天时麦麦只会回答文本大于等于此数的消息
|
||||||
|
max_context_size = 15 # 麦麦获得的上文数量
|
||||||
|
emoji_chance = 0.2 # 麦麦使用表情包的概率
|
||||||
|
ban_words = [
|
||||||
|
# "403","张三"
|
||||||
|
]
|
||||||
|
|
||||||
|
[emoji]
|
||||||
|
check_interval = 120 # 检查表情包的时间间隔
|
||||||
|
register_interval = 10 # 注册表情包的时间间隔
|
||||||
|
|
||||||
|
[cq_code]
|
||||||
|
enable_pic_translate = false
|
||||||
|
|
||||||
|
[response]
|
||||||
|
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
||||||
|
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
||||||
|
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300 # 记忆构建间隔 单位秒
|
||||||
|
forget_memory_interval = 300 # 记忆遗忘间隔 单位秒
|
||||||
|
|
||||||
|
[others]
|
||||||
|
enable_advance_output = true # 是否启用高级输出
|
||||||
|
enable_kuuki_read = true # 是否启用读空气功能
|
||||||
|
|
||||||
|
[groups]
|
||||||
|
talk_allowed = [
|
||||||
|
123,
|
||||||
|
123,
|
||||||
|
] #可以回复消息的群
|
||||||
|
talk_frequency_down = [] #降低回复频率的群
|
||||||
|
ban_user_id = [] #禁止回复消息的QQ号
|
||||||
|
|
||||||
|
|
||||||
|
#V3
|
||||||
|
#name = "deepseek-chat"
|
||||||
|
#base_url = "DEEP_SEEK_BASE_URL"
|
||||||
|
#key = "DEEP_SEEK_KEY"
|
||||||
|
|
||||||
|
#R1
|
||||||
|
#name = "deepseek-reasoner"
|
||||||
|
#base_url = "DEEP_SEEK_BASE_URL"
|
||||||
|
#key = "DEEP_SEEK_KEY"
|
||||||
|
|
||||||
|
#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env.prod自定义的宏,使用自定义模型则选择定位相似的模型自己填写
|
||||||
|
|
||||||
|
[model.llm_reasoning] #R1
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning_minor] #R1蒸馏
|
||||||
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal] #V3
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal_minor] #V2.5
|
||||||
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.vlm] #图像识别
|
||||||
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.embedding] #嵌入
|
||||||
|
name = "BAAI/bge-m3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
# 主题提取,jieba和snownlp不用api,llm需要api
|
||||||
|
[topic]
|
||||||
|
topic_extract='snownlp' # 只支持jieba,snownlp,llm三种选项
|
||||||
|
|
||||||
|
[topic.llm_topic]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
### 🐳 Docker部署(推荐)
|
如果你不知道Docker是什么,建议寻找相关教程或使用手动部署
|
||||||
|
|
||||||
|
### 🐳 Docker部署(推荐,但不一定是最新)
|
||||||
|
|
||||||
1. 获取配置文件:
|
1. 获取配置文件:
|
||||||
```bash
|
```bash
|
||||||
@@ -25,9 +27,7 @@ NAPCAT_UID=$(id -u) NAPCAT_GID=$(id -g) docker compose restart
|
|||||||
```bash
|
```bash
|
||||||
# 创建虚拟环境(推荐)
|
# 创建虚拟环境(推荐)
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
source venv/bin/activate # Linux
|
|
||||||
venv\\Scripts\\activate # Windows
|
venv\\Scripts\\activate # Windows
|
||||||
|
|
||||||
# 安装依赖
|
# 安装依赖
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
@@ -41,33 +41,37 @@ pip install -r requirements.txt
|
|||||||
- 添加反向WS:`ws://localhost:8080/onebot/v11/ws`
|
- 添加反向WS:`ws://localhost:8080/onebot/v11/ws`
|
||||||
|
|
||||||
4. **配置文件设置**
|
4. **配置文件设置**
|
||||||
- 复制并修改环境配置:`.env.prod`
|
- 修改环境配置文件:`.env.prod`
|
||||||
- 复制并修改机器人配置:`bot_config.toml`
|
- 修改机器人配置文件:`bot_config.toml`
|
||||||
|
|
||||||
5. **启动服务**
|
5. **启动麦麦机器人**
|
||||||
|
- 打开命令行,cd到对应路径
|
||||||
```bash
|
```bash
|
||||||
nb run
|
nb run
|
||||||
```
|
```
|
||||||
|
|
||||||
6. **其他组件**
|
6. **其他组件**
|
||||||
- `run_thingking.bat`: 启动可视化推理界面(未完善)和消息队列预览
|
- `run_thingking.bat`: 启动可视化推理界面(未完善)
|
||||||
- `knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库
|
|
||||||
|
- ~~`knowledge.bat`: 将`/data/raw_info`下的文本文档载入数据库~~
|
||||||
|
- 直接运行 knowledge.py生成知识库
|
||||||
|
|
||||||
## ⚙️ 配置说明
|
## ⚙️ 配置说明
|
||||||
|
|
||||||
### 环境配置 (.env.prod)
|
### 环境配置 (.env.prod)
|
||||||
```ini
|
```ini
|
||||||
# API配置(必填)
|
# API配置,你可以在这里定义你的密钥和base_url
|
||||||
|
# 你可以选择定义其他服务商提供的KEY,完全可以自定义
|
||||||
SILICONFLOW_KEY=your_key
|
SILICONFLOW_KEY=your_key
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_KEY=your_key
|
DEEP_SEEK_KEY=your_key
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
# 服务配置
|
# 服务配置,如果你不知道这是什么,保持默认
|
||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置,如果你不知道这是什么,保持默认
|
||||||
MONGODB_HOST=127.0.0.1
|
MONGODB_HOST=127.0.0.1
|
||||||
MONGODB_PORT=27017
|
MONGODB_PORT=27017
|
||||||
DATABASE_NAME=MegBot
|
DATABASE_NAME=MegBot
|
||||||
@@ -80,19 +84,58 @@ qq = "你的机器人QQ号"
|
|||||||
nickname = "麦麦"
|
nickname = "麦麦"
|
||||||
|
|
||||||
[message]
|
[message]
|
||||||
|
min_text_length = 2
|
||||||
max_context_size = 15
|
max_context_size = 15
|
||||||
emoji_chance = 0.2
|
emoji_chance = 0.2
|
||||||
|
|
||||||
|
[emoji]
|
||||||
|
check_interval = 120
|
||||||
|
register_interval = 10
|
||||||
|
|
||||||
|
[cq_code]
|
||||||
|
enable_pic_translate = false
|
||||||
|
|
||||||
[response]
|
[response]
|
||||||
api_using = "siliconflow" # 或 "deepseek"
|
#现已移除deepseek或硅基流动选项,可以直接切换分别配置任意模型
|
||||||
|
model_r1_probability = 0.8 #推理模型权重
|
||||||
|
model_v3_probability = 0.1 #非推理模型权重
|
||||||
|
model_r1_distill_probability = 0.1
|
||||||
|
|
||||||
|
[memory]
|
||||||
|
build_memory_interval = 300
|
||||||
|
|
||||||
[others]
|
[others]
|
||||||
enable_advance_output = false # 是否启用详细日志输出
|
enable_advance_output = true # 是否启用详细日志输出
|
||||||
|
|
||||||
[groups]
|
[groups]
|
||||||
talk_allowed = [] # 允许回复的群号列表
|
talk_allowed = [] # 允许回复的群号列表
|
||||||
talk_frequency_down = [] # 降低回复频率的群号列表
|
talk_frequency_down = [] # 降低回复频率的群号列表
|
||||||
ban_user_id = [] # 禁止回复的用户QQ号列表
|
ban_user_id = [] # 禁止回复的用户QQ号列表
|
||||||
|
|
||||||
|
[model.llm_reasoning]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-R1"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_reasoning_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal]
|
||||||
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.llm_normal_minor]
|
||||||
|
name = "deepseek-ai/DeepSeek-V2.5"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
|
|
||||||
|
[model.vlm]
|
||||||
|
name = "deepseek-ai/deepseek-vl2"
|
||||||
|
base_url = "SILICONFLOW_BASE_URL"
|
||||||
|
key = "SILICONFLOW_KEY"
|
||||||
```
|
```
|
||||||
|
|
||||||
## ⚠️ 注意事项
|
## ⚠️ 注意事项
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
@echo off
|
|
||||||
echo 正在查找并结束所有 MongoDB 进程...
|
|
||||||
taskkill /F /IM mongod.exe
|
|
||||||
taskkill /F /IM mongo.exe
|
|
||||||
echo MongoDB 进程已结束
|
|
||||||
pause
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
chcp 65001
|
||||||
call conda activate niuniu
|
call conda activate niuniu
|
||||||
cd .
|
cd .
|
||||||
|
|
||||||
|
|||||||
11
setup.py
Normal file
11
setup.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="maimai-bot",
|
||||||
|
version="0.1",
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=[
|
||||||
|
'python-dotenv',
|
||||||
|
'pymongo',
|
||||||
|
],
|
||||||
|
)
|
||||||
@@ -25,3 +25,24 @@ class Database:
|
|||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
raise RuntimeError("Database not initialized")
|
raise RuntimeError("Database not initialized")
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
|
#测试用
|
||||||
|
|
||||||
|
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
||||||
|
# 先随机获取一条消息
|
||||||
|
random_message = list(self.db.messages.aggregate([
|
||||||
|
{"$match": {"group_id": group_id}},
|
||||||
|
{"$sample": {"size": 1}}
|
||||||
|
]))[0]
|
||||||
|
|
||||||
|
# 获取该消息之后的消息
|
||||||
|
subsequent_messages = list(self.db.messages.find({
|
||||||
|
"group_id": group_id,
|
||||||
|
"time": {"$gt": random_message["time"]}
|
||||||
|
}).sort("time", 1).limit(limit))
|
||||||
|
|
||||||
|
# 将随机消息和后续消息合并
|
||||||
|
messages = [random_message] + subsequent_messages
|
||||||
|
|
||||||
|
return messages
|
||||||
165
src/plugins/chat/Segment_builder.py
Normal file
165
src/plugins/chat/Segment_builder.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
from typing import Dict, List, Union, Optional, Any
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
|
||||||
|
"""
|
||||||
|
OneBot v11 Message Segment Builder
|
||||||
|
|
||||||
|
This module provides classes for building message segments that conform to the
|
||||||
|
OneBot v11 standard. These segments can be used to construct complex messages
|
||||||
|
for sending through bots that implement the OneBot interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Segment:
|
||||||
|
"""Base class for all message segments."""
|
||||||
|
|
||||||
|
def __init__(self, type_: str, data: Dict[str, Any]):
|
||||||
|
self.type = type_
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert the segment to a dictionary format."""
|
||||||
|
return {
|
||||||
|
"type": self.type,
|
||||||
|
"data": self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Text(Segment):
|
||||||
|
"""Text message segment."""
|
||||||
|
|
||||||
|
def __init__(self, text: str):
|
||||||
|
super().__init__("text", {"text": text})
|
||||||
|
|
||||||
|
|
||||||
|
class Face(Segment):
|
||||||
|
"""Face/emoji message segment."""
|
||||||
|
|
||||||
|
def __init__(self, face_id: int):
|
||||||
|
super().__init__("face", {"id": str(face_id)})
|
||||||
|
|
||||||
|
|
||||||
|
class Image(Segment):
|
||||||
|
"""Image message segment."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_url(cls, url: str) -> 'Image':
|
||||||
|
"""Create an Image segment from a URL."""
|
||||||
|
return cls(url=url)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_path(cls, path: str) -> 'Image':
|
||||||
|
"""Create an Image segment from a file path."""
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
file_b64 = base64.b64encode(f.read()).decode('utf-8')
|
||||||
|
return cls(file=f"base64://{file_b64}")
|
||||||
|
|
||||||
|
def __init__(self, file: str = None, url: str = None, cache: bool = True):
|
||||||
|
data = {}
|
||||||
|
if file:
|
||||||
|
data["file"] = file
|
||||||
|
if url:
|
||||||
|
data["url"] = url
|
||||||
|
if not cache:
|
||||||
|
data["cache"] = "0"
|
||||||
|
super().__init__("image", data)
|
||||||
|
|
||||||
|
|
||||||
|
class At(Segment):
|
||||||
|
"""@Someone message segment."""
|
||||||
|
|
||||||
|
def __init__(self, user_id: Union[int, str]):
|
||||||
|
data = {"qq": str(user_id)}
|
||||||
|
super().__init__("at", data)
|
||||||
|
|
||||||
|
|
||||||
|
class Record(Segment):
|
||||||
|
"""Voice message segment."""
|
||||||
|
|
||||||
|
def __init__(self, file: str, magic: bool = False, cache: bool = True):
|
||||||
|
data = {"file": file}
|
||||||
|
if magic:
|
||||||
|
data["magic"] = "1"
|
||||||
|
if not cache:
|
||||||
|
data["cache"] = "0"
|
||||||
|
super().__init__("record", data)
|
||||||
|
|
||||||
|
|
||||||
|
class Video(Segment):
|
||||||
|
"""Video message segment."""
|
||||||
|
|
||||||
|
def __init__(self, file: str):
|
||||||
|
super().__init__("video", {"file": file})
|
||||||
|
|
||||||
|
|
||||||
|
class Reply(Segment):
|
||||||
|
"""Reply message segment."""
|
||||||
|
|
||||||
|
def __init__(self, message_id: int):
|
||||||
|
super().__init__("reply", {"id": str(message_id)})
|
||||||
|
|
||||||
|
|
||||||
|
class MessageBuilder:
|
||||||
|
"""Helper class for building complex messages."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.segments: List[Segment] = []
|
||||||
|
|
||||||
|
def text(self, text: str) -> 'MessageBuilder':
|
||||||
|
"""Add a text segment."""
|
||||||
|
self.segments.append(Text(text))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def face(self, face_id: int) -> 'MessageBuilder':
|
||||||
|
"""Add a face/emoji segment."""
|
||||||
|
self.segments.append(Face(face_id))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def image(self, file: str = None) -> 'MessageBuilder':
|
||||||
|
"""Add an image segment."""
|
||||||
|
self.segments.append(Image(file=file))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def at(self, user_id: Union[int, str]) -> 'MessageBuilder':
|
||||||
|
"""Add an @someone segment."""
|
||||||
|
self.segments.append(At(user_id))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def record(self, file: str, magic: bool = False) -> 'MessageBuilder':
|
||||||
|
"""Add a voice record segment."""
|
||||||
|
self.segments.append(Record(file, magic))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def video(self, file: str) -> 'MessageBuilder':
|
||||||
|
"""Add a video segment."""
|
||||||
|
self.segments.append(Video(file))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reply(self, message_id: int) -> 'MessageBuilder':
|
||||||
|
"""Add a reply segment."""
|
||||||
|
self.segments.append(Reply(message_id))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Build the message into a list of segment dictionaries."""
|
||||||
|
return [segment.to_dict() for segment in self.segments]
|
||||||
|
|
||||||
|
|
||||||
|
'''Convenience functions
|
||||||
|
def text(content: str) -> Dict[str, Any]:
|
||||||
|
"""Create a text message segment."""
|
||||||
|
return Text(content).to_dict()
|
||||||
|
|
||||||
|
def image_url(url: str) -> Dict[str, Any]:
|
||||||
|
"""Create an image message segment from URL."""
|
||||||
|
return Image.from_url(url).to_dict()
|
||||||
|
|
||||||
|
def image_path(path: str) -> Dict[str, Any]:
|
||||||
|
"""Create an image message segment from file path."""
|
||||||
|
return Image.from_path(path).to_dict()
|
||||||
|
|
||||||
|
def at(user_id: Union[int, str]) -> Dict[str, Any]:
|
||||||
|
"""Create an @someone message segment."""
|
||||||
|
return At(user_id).to_dict()'''
|
||||||
@@ -10,6 +10,10 @@ import random
|
|||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
from nonebot.rule import to_me
|
||||||
|
from .bot import chat_bot
|
||||||
|
from .emoji_manager import emoji_manager
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
# 获取驱动器
|
# 获取驱动器
|
||||||
@@ -17,12 +21,12 @@ driver = get_driver()
|
|||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source= config.mongodb_auth_source
|
auth_source= config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
|
|
||||||
@@ -30,8 +34,9 @@ print("\033[1;32m[初始化数据库完成]\033[0m")
|
|||||||
# 导入其他模块
|
# 导入其他模块
|
||||||
from .bot import ChatBot
|
from .bot import ChatBot
|
||||||
from .emoji_manager import emoji_manager
|
from .emoji_manager import emoji_manager
|
||||||
from .message_send_control import message_sender
|
# from .message_send_control import message_sender
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
|
from .message_sender import message_manager,message_sender
|
||||||
from ..memory_system.memory import memory_graph,hippocampus
|
from ..memory_system.memory import memory_graph,hippocampus
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
@@ -40,8 +45,8 @@ emoji_manager.initialize()
|
|||||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
||||||
# 创建机器人实例
|
# 创建机器人实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
# 注册消息处理器
|
# 注册群消息处理器
|
||||||
group_msg = on_message()
|
group_msg = on_message(priority=5)
|
||||||
# 创建定时任务
|
# 创建定时任务
|
||||||
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
||||||
|
|
||||||
@@ -66,10 +71,13 @@ async def init_relationships():
|
|||||||
async def _(bot: Bot):
|
async def _(bot: Bot):
|
||||||
"""Bot连接成功时的处理"""
|
"""Bot连接成功时的处理"""
|
||||||
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
|
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
|
||||||
message_sender.set_bot(bot)
|
|
||||||
asyncio.create_task(message_sender.start_processor(bot))
|
|
||||||
await willing_manager.ensure_started()
|
await willing_manager.ensure_started()
|
||||||
|
|
||||||
|
|
||||||
|
message_sender.set_bot(bot)
|
||||||
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
|
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
|
||||||
|
asyncio.create_task(message_manager.start_processor())
|
||||||
|
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
|
||||||
|
|
||||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
||||||
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
||||||
@@ -79,19 +87,27 @@ async def _(bot: Bot):
|
|||||||
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
||||||
await chat_bot.handle_message(event, bot)
|
await chat_bot.handle_message(event, bot)
|
||||||
|
|
||||||
'''
|
|
||||||
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
|
|
||||||
async def monitor_relationships():
|
|
||||||
"""每15秒打印一次关系数据"""
|
|
||||||
relationship_manager.print_all_relationships()
|
|
||||||
'''
|
|
||||||
|
|
||||||
# 添加build_memory定时任务
|
# 添加build_memory定时任务
|
||||||
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
||||||
async def build_memory_task():
|
async def build_memory_task():
|
||||||
"""每30秒执行一次记忆构建"""
|
"""每30秒执行一次记忆构建"""
|
||||||
print("\033[1;32m[记忆构建]\033[0m 开始构建记忆...")
|
print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
|
||||||
await hippocampus.build_memory(chat_size=30)
|
start_time = time.time()
|
||||||
print("\033[1;32m[记忆构建]\033[0m 记忆构建完成")
|
await hippocampus.operation_build_memory(chat_size=20)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
|
||||||
|
|
||||||
|
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
|
||||||
|
async def forget_memory_task():
|
||||||
|
"""每30秒执行一次记忆构建"""
|
||||||
|
# print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
|
||||||
|
# await hippocampus.operation_forget_topic(percentage=0.1)
|
||||||
|
# print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
|
||||||
|
|
||||||
|
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
|
||||||
|
async def merge_memory_task():
|
||||||
|
"""每30秒执行一次记忆构建"""
|
||||||
|
# print("\033[1;32m[记忆整合]\033[0m 开始整合")
|
||||||
|
# await hippocampus.operation_merge_memory(percentage=0.1)
|
||||||
|
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,22 @@
|
|||||||
from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot
|
from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot
|
||||||
from .message import Message,MessageSet
|
from .message import Message, MessageSet, Message_Sending
|
||||||
from .config import BotConfig, global_config
|
from .config import BotConfig, global_config
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message_stream import MessageStream, MessageStreamContainer
|
# from .message_stream import MessageStream, MessageStreamContainer
|
||||||
from .topic_identifier import topic_identifier
|
from .topic_identifier import topic_identifier
|
||||||
from random import random, choice
|
from random import random, choice
|
||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
from .cq_code import CQCode # 导入CQCode模块
|
from .cq_code import CQCode # 导入CQCode模块
|
||||||
from .message_send_control import message_sender # 导入消息发送控制器
|
from .message_sender import message_manager # 导入新的消息管理器
|
||||||
from .message import Message_Thinking # 导入 Message_Thinking 类
|
from .message import Message_Thinking # 导入 Message_Thinking 类
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
||||||
from ..memory_system.memory import memory_graph
|
from ..memory_system.memory import memory_graph
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -25,16 +26,13 @@ class ChatBot:
|
|||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
self.emoji_chance = 0.2 # 发送表情包的基础概率
|
self.emoji_chance = 0.2 # 发送表情包的基础概率
|
||||||
self.message_streams = MessageStreamContainer()
|
# self.message_streams = MessageStreamContainer()
|
||||||
self.message_sender = message_sender
|
|
||||||
|
|
||||||
async def _ensure_started(self):
|
async def _ensure_started(self):
|
||||||
"""确保所有任务已启动"""
|
"""确保所有任务已启动"""
|
||||||
if not self._started:
|
if not self._started:
|
||||||
# 只保留必要的任务
|
|
||||||
self._started = True
|
self._started = True
|
||||||
|
|
||||||
|
|
||||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||||
"""处理收到的群消息"""
|
"""处理收到的群消息"""
|
||||||
|
|
||||||
@@ -45,59 +43,40 @@ class ChatBot:
|
|||||||
if event.user_id in global_config.ban_user_id:
|
if event.user_id in global_config.ban_user_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 打印原始消息内容
|
|
||||||
'''
|
|
||||||
print(f"\n\033[1;33m[消息详情]\033[0m")
|
|
||||||
# print(f"- 原始消息: {str(event.raw_message)}")
|
|
||||||
print(f"- post_type: {event.post_type}")
|
|
||||||
print(f"- sub_type: {event.sub_type}")
|
|
||||||
print(f"- user_id: {event.user_id}")
|
|
||||||
print(f"- message_type: {event.message_type}")
|
|
||||||
# print(f"- message_id: {event.message_id}")
|
|
||||||
# print(f"- message: {event.message}")
|
|
||||||
print(f"- original_message: {event.original_message}")
|
|
||||||
print(f"- raw_message: {event.raw_message}")
|
|
||||||
# print(f"- font: {event.font}")
|
|
||||||
print(f"- sender: {event.sender}")
|
|
||||||
# print(f"- to_me: {event.to_me}")
|
|
||||||
|
|
||||||
if event.reply:
|
|
||||||
print(f"\n\033[1;33m[回复消息详情]\033[0m")
|
|
||||||
# print(f"- message_id: {event.reply.message_id}")
|
|
||||||
print(f"- message_type: {event.reply.message_type}")
|
|
||||||
print(f"- sender: {event.reply.sender}")
|
|
||||||
# print(f"- time: {event.reply.time}")
|
|
||||||
print(f"- message: {event.reply.message}")
|
|
||||||
print(f"- raw_message: {event.reply.raw_message}")
|
|
||||||
# print(f"- original_message: {event.reply.original_message}")
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||||
|
|
||||||
|
|
||||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||||
# print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
|
|
||||||
|
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
message_id=event.message_id,
|
message_id=event.message_id,
|
||||||
|
user_cardname=sender_info['card'],
|
||||||
raw_message=str(event.original_message),
|
raw_message=str(event.original_message),
|
||||||
plain_text=event.get_plaintext(),
|
plain_text=event.get_plaintext(),
|
||||||
reply_message=event.reply,
|
reply_message=event.reply,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 过滤词
|
||||||
|
for word in global_config.ban_words:
|
||||||
|
if word in message.detailed_plain_text:
|
||||||
|
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||||
|
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||||
|
return
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||||
|
|
||||||
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
|
||||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
|
||||||
|
topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||||
|
|
||||||
|
|
||||||
|
# topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||||
|
# topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||||
|
# topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
|
||||||
|
logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||||
|
|
||||||
all_num = 0
|
all_num = 0
|
||||||
interested_num = 0
|
interested_num = 0
|
||||||
@@ -110,10 +89,8 @@ class ChatBot:
|
|||||||
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
|
print(f"\033[1;32m[前额叶]\033[0m 对|{current_topic}|有印象")
|
||||||
interested_rate = interested_num / all_num if all_num > 0 else 0
|
interested_rate = interested_num / all_num if all_num > 0 else 0
|
||||||
|
|
||||||
|
|
||||||
await self.storage.store_message(message, topic[0] if topic else None)
|
await self.storage.store_message(message, topic[0] if topic else None)
|
||||||
|
|
||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||||
reply_probability = willing_manager.change_reply_willing_received(
|
reply_probability = willing_manager.change_reply_willing_received(
|
||||||
event.group_id,
|
event.group_id,
|
||||||
@@ -127,54 +104,71 @@ class ChatBot:
|
|||||||
current_willing = willing_manager.get_willing(event.group_id)
|
current_willing = willing_manager.get_willing(event.group_id)
|
||||||
|
|
||||||
|
|
||||||
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability:.1f}]\033[0m")
|
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
||||||
|
|
||||||
response = ""
|
response = ""
|
||||||
# 创建思考消息
|
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
|
|
||||||
|
|
||||||
tinking_time_point = round(time.time(), 2)
|
tinking_time_point = round(time.time(), 2)
|
||||||
think_id = 'mt' + str(tinking_time_point)
|
think_id = 'mt' + str(tinking_time_point)
|
||||||
thinking_message = Message_Thinking(message=message,message_id=think_id)
|
thinking_message = Message_Thinking(message=message,message_id=think_id)
|
||||||
message_sender.send_temp_container.add_message(thinking_message)
|
|
||||||
|
message_manager.add_message(thinking_message)
|
||||||
|
|
||||||
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
||||||
|
|
||||||
response, emotion = await self.gpt.generate_response(message)
|
response, emotion = await self.gpt.generate_response(message)
|
||||||
|
|
||||||
# 如果生成了回复,发送并记录
|
# if response is None:
|
||||||
|
# thinking_message.interupt=True
|
||||||
'''
|
|
||||||
生成回复后的内容
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
|
# print(f"\033[1;32m[思考结束]\033[0m 思考结束,已得到回复,开始回复")
|
||||||
|
# 找到并删除对应的thinking消息
|
||||||
|
container = message_manager.get_container(event.group_id)
|
||||||
|
thinking_message = None
|
||||||
|
# 找到message,删除
|
||||||
|
for msg in container.messages:
|
||||||
|
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
|
||||||
|
thinking_message = msg
|
||||||
|
container.messages.remove(msg)
|
||||||
|
print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
|
||||||
|
break
|
||||||
|
|
||||||
|
#记录开始思考的时间,避免从思考到回复的时间太久
|
||||||
|
thinking_start_time = thinking_message.thinking_start_time
|
||||||
|
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
||||||
|
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
|
|
||||||
|
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
|
||||||
for msg in response:
|
for msg in response:
|
||||||
print(f"当前消息: {msg}")
|
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
|
||||||
|
#通过时间改变时间戳
|
||||||
typing_time = calculate_typing_time(msg)
|
typing_time = calculate_typing_time(msg)
|
||||||
accu_typing_time += typing_time
|
accu_typing_time += typing_time
|
||||||
timepoint = tinking_time_point + accu_typing_time
|
timepoint = tinking_time_point + accu_typing_time
|
||||||
# print(f"\033[1;32m[调试]\033[0m 消息: {msg},添加!, 累计打字时间: {accu_typing_time:.2f}秒")
|
|
||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message_Sending(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
message_based_id=event.message_id,
|
|
||||||
raw_message=msg,
|
raw_message=msg,
|
||||||
plain_text=msg,
|
plain_text=msg,
|
||||||
processed_plain_text=msg,
|
processed_plain_text=msg,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
group_name=message.group_name,
|
group_name=message.group_name,
|
||||||
time=timepoint
|
time=timepoint, #记录了回复生成的时间
|
||||||
|
thinking_start_time=thinking_start_time, #记录了思考开始的时间
|
||||||
|
reply_message_id=message.message_id
|
||||||
)
|
)
|
||||||
message_set.add_message(bot_message)
|
message_set.add_message(bot_message)
|
||||||
|
|
||||||
message_sender.send_temp_container.update_thinking_message(message_set)
|
#message_set 可以直接加入 message_manager
|
||||||
|
print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
|
||||||
|
message_manager.add_message(message_set)
|
||||||
|
|
||||||
bot_response_time = tinking_time_point
|
bot_response_time = tinking_time_point
|
||||||
if random() < global_config.emoji_chance:
|
if random() < global_config.emoji_chance:
|
||||||
@@ -187,7 +181,7 @@ class ChatBot:
|
|||||||
else:
|
else:
|
||||||
bot_response_time = bot_response_time + 1
|
bot_response_time = bot_response_time + 1
|
||||||
|
|
||||||
bot_message = Message(
|
bot_message = Message_Sending(
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
message_id=0,
|
message_id=0,
|
||||||
@@ -198,9 +192,13 @@ class ChatBot:
|
|||||||
group_name=message.group_name,
|
group_name=message.group_name,
|
||||||
time=bot_response_time,
|
time=bot_response_time,
|
||||||
is_emoji=True,
|
is_emoji=True,
|
||||||
translate_cq=False
|
translate_cq=False,
|
||||||
|
thinking_start_time=thinking_start_time,
|
||||||
|
# reply_message_id=message.message_id
|
||||||
)
|
)
|
||||||
message_sender.send_temp_container.add_message(bot_message)
|
message_manager.add_message(bot_message)
|
||||||
|
|
||||||
# 如果收到新消息,提高回复意愿
|
|
||||||
willing_manager.change_reply_willing_after_sent(event.group_id)
|
willing_manager.change_reply_willing_after_sent(event.group_id)
|
||||||
|
|
||||||
|
# 创建全局ChatBot实例
|
||||||
|
chat_bot = ChatBot()
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Any, Optional, Set
|
from typing import Dict, Any, Optional, Set
|
||||||
import os
|
import os
|
||||||
from nonebot.log import logger, default_format
|
|
||||||
import logging
|
|
||||||
import configparser
|
import configparser
|
||||||
import tomli
|
import tomli
|
||||||
import sys
|
import sys
|
||||||
@@ -28,17 +26,25 @@ class BotConfig:
|
|||||||
talk_frequency_down_groups = set()
|
talk_frequency_down_groups = set()
|
||||||
ban_user_id = set()
|
ban_user_id = set()
|
||||||
|
|
||||||
build_memory_interval: int = 60 # 记忆构建间隔(秒)
|
build_memory_interval: int = 30 # 记忆构建间隔(秒)
|
||||||
|
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
|
||||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||||
|
|
||||||
|
ban_words = set()
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
|
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
|
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
embedding: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
# 主题提取配置
|
||||||
|
topic_extract: str = 'snownlp' # 只支持jieba,snownlp,llm
|
||||||
|
llm_topic_extract: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
API_USING: str = "siliconflow" # 使用的API
|
API_USING: str = "siliconflow" # 使用的API
|
||||||
API_PAID: bool = False # 是否使用付费API
|
API_PAID: bool = False # 是否使用付费API
|
||||||
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
||||||
@@ -48,6 +54,13 @@ class BotConfig:
|
|||||||
enable_advance_output: bool = False # 是否启用高级输出
|
enable_advance_output: bool = False # 是否启用高级输出
|
||||||
enable_kuuki_read: bool = True # 是否启用读空气功能
|
enable_kuuki_read: bool = True # 是否启用读空气功能
|
||||||
|
|
||||||
|
# 默认人设
|
||||||
|
PROMPT_PERSONALITY=[
|
||||||
|
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
|
||||||
|
"是一个女大学生,你有黑色头发,你会刷小红书"
|
||||||
|
]
|
||||||
|
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_dir() -> str:
|
def get_config_dir() -> str:
|
||||||
"""获取配置文件目录"""
|
"""获取配置文件目录"""
|
||||||
@@ -67,6 +80,15 @@ class BotConfig:
|
|||||||
with open(config_path, "rb") as f:
|
with open(config_path, "rb") as f:
|
||||||
toml_dict = tomli.load(f)
|
toml_dict = tomli.load(f)
|
||||||
|
|
||||||
|
if 'personality' in toml_dict:
|
||||||
|
personality_config=toml_dict['personality']
|
||||||
|
personality=personality_config.get('prompt_personality')
|
||||||
|
if len(personality) >= 2:
|
||||||
|
logger.info(f"载入自定义人格:{personality}")
|
||||||
|
config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY)
|
||||||
|
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}")
|
||||||
|
config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)
|
||||||
|
|
||||||
if "emoji" in toml_dict:
|
if "emoji" in toml_dict:
|
||||||
emoji_config = toml_dict["emoji"]
|
emoji_config = toml_dict["emoji"]
|
||||||
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
||||||
@@ -103,6 +125,7 @@ class BotConfig:
|
|||||||
|
|
||||||
if "llm_normal" in model_config:
|
if "llm_normal" in model_config:
|
||||||
config.llm_normal = model_config["llm_normal"]
|
config.llm_normal = model_config["llm_normal"]
|
||||||
|
config.llm_topic_extract = config.llm_normal
|
||||||
|
|
||||||
if "llm_normal_minor" in model_config:
|
if "llm_normal_minor" in model_config:
|
||||||
config.llm_normal_minor = model_config["llm_normal_minor"]
|
config.llm_normal_minor = model_config["llm_normal_minor"]
|
||||||
@@ -110,16 +133,30 @@ class BotConfig:
|
|||||||
if "vlm" in model_config:
|
if "vlm" in model_config:
|
||||||
config.vlm = model_config["vlm"]
|
config.vlm = model_config["vlm"]
|
||||||
|
|
||||||
|
if "embedding" in model_config:
|
||||||
|
config.embedding = model_config["embedding"]
|
||||||
|
|
||||||
|
if 'topic' in toml_dict:
|
||||||
|
topic_config=toml_dict['topic']
|
||||||
|
if 'topic_extract' in topic_config:
|
||||||
|
config.topic_extract=topic_config.get('topic_extract',config.topic_extract)
|
||||||
|
logger.info(f"载入自定义主题提取为{config.topic_extract}")
|
||||||
|
if config.topic_extract=='llm' and 'llm_topic' in topic_config:
|
||||||
|
config.llm_topic_extract=topic_config['llm_topic']
|
||||||
|
logger.info(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}")
|
||||||
|
|
||||||
# 消息配置
|
# 消息配置
|
||||||
if "message" in toml_dict:
|
if "message" in toml_dict:
|
||||||
msg_config = toml_dict["message"]
|
msg_config = toml_dict["message"]
|
||||||
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
|
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
|
||||||
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
||||||
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
||||||
|
config.ban_words=msg_config.get("ban_words",config.ban_words)
|
||||||
|
|
||||||
if "memory" in toml_dict:
|
if "memory" in toml_dict:
|
||||||
memory_config = toml_dict["memory"]
|
memory_config = toml_dict["memory"]
|
||||||
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
||||||
|
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
|
||||||
|
|
||||||
# 群组配置
|
# 群组配置
|
||||||
if "groups" in toml_dict:
|
if "groups" in toml_dict:
|
||||||
@@ -131,6 +168,7 @@ class BotConfig:
|
|||||||
if "others" in toml_dict:
|
if "others" in toml_dict:
|
||||||
others_config = toml_dict["others"]
|
others_config = toml_dict["others"]
|
||||||
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
|
||||||
|
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
|
||||||
|
|
||||||
logger.success(f"成功加载配置文件: {config_path}")
|
logger.success(f"成功加载配置文件: {config_path}")
|
||||||
|
|
||||||
@@ -144,32 +182,14 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
|||||||
if not os.path.exists(bot_config_path):
|
if not os.path.exists(bot_config_path):
|
||||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||||
logger.info("使用默认配置文件")
|
logger.info("使用bot配置文件")
|
||||||
else:
|
else:
|
||||||
logger.info("已找到开发环境配置文件")
|
logger.info("已找到开发bot配置文件")
|
||||||
|
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMConfig:
|
|
||||||
"""机器人配置类"""
|
|
||||||
# 基础配置
|
|
||||||
SILICONFLOW_API_KEY: str = None
|
|
||||||
SILICONFLOW_BASE_URL: str = None
|
|
||||||
DEEP_SEEK_API_KEY: str = None
|
|
||||||
DEEP_SEEK_BASE_URL: str = None
|
|
||||||
|
|
||||||
llm_config = LLMConfig()
|
|
||||||
config = get_driver().config
|
|
||||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
|
||||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
|
||||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
|
||||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
|
||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# logger.remove()
|
logger.remove()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import base64
|
|||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
@@ -235,32 +237,102 @@ class EmojiManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
|
||||||
|
return "skip"
|
||||||
|
|
||||||
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
||||||
return "skip" # 默认标签
|
return "skip" # 默认标签
|
||||||
|
|
||||||
|
async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]:
|
||||||
|
"""压缩图片并返回base64编码
|
||||||
|
Args:
|
||||||
|
image_path: 图片文件路径
|
||||||
|
target_size: 目标文件大小(字节),默认0.8MB
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 成功返回base64编码的图片数据,失败返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(image_path)
|
||||||
|
if file_size <= target_size:
|
||||||
|
# 如果文件已经小于目标大小,直接读取并返回base64
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
return base64.b64encode(f.read()).decode('utf-8')
|
||||||
|
|
||||||
|
# 打开图片
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
# 获取原始尺寸
|
||||||
|
original_width, original_height = img.size
|
||||||
|
|
||||||
|
# 计算缩放比例
|
||||||
|
scale = min(1.0, (target_size / file_size) ** 0.5)
|
||||||
|
|
||||||
|
# 计算新的尺寸
|
||||||
|
new_width = int(original_width * scale)
|
||||||
|
new_height = int(original_height * scale)
|
||||||
|
|
||||||
|
# 创建内存缓冲区
|
||||||
|
output_buffer = io.BytesIO()
|
||||||
|
|
||||||
|
# 如果是GIF,处理所有帧
|
||||||
|
if getattr(img, "is_animated", False):
|
||||||
|
frames = []
|
||||||
|
for frame_idx in range(img.n_frames):
|
||||||
|
img.seek(frame_idx)
|
||||||
|
new_frame = img.copy()
|
||||||
|
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
frames.append(new_frame)
|
||||||
|
|
||||||
|
# 保存到缓冲区
|
||||||
|
frames[0].save(
|
||||||
|
output_buffer,
|
||||||
|
format='GIF',
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
optimize=True,
|
||||||
|
duration=img.info.get('duration', 100),
|
||||||
|
loop=img.info.get('loop', 0)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 处理静态图片
|
||||||
|
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# 保存到缓冲区,保持原始格式
|
||||||
|
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||||
|
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||||
|
else:
|
||||||
|
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||||
|
|
||||||
|
# 获取压缩后的数据并转换为base64
|
||||||
|
compressed_data = output_buffer.getvalue()
|
||||||
|
print(f"\033[1;32m[成功]\033[0m 压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})")
|
||||||
|
|
||||||
|
return base64.b64encode(compressed_data).decode('utf-8')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def scan_new_emojis(self):
|
async def scan_new_emojis(self):
|
||||||
"""扫描新的表情包"""
|
"""扫描新的表情包"""
|
||||||
try:
|
try:
|
||||||
emoji_dir = "data/emoji"
|
emoji_dir = "data/emoji"
|
||||||
os.makedirs(emoji_dir, exist_ok=True)
|
os.makedirs(emoji_dir, exist_ok=True)
|
||||||
|
|
||||||
# 获取所有jpg文件
|
# 获取所有支持的图片文件
|
||||||
files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')]
|
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
|
||||||
|
|
||||||
for filename in files_to_process:
|
for filename in files_to_process:
|
||||||
|
image_path = os.path.join(emoji_dir, filename)
|
||||||
|
|
||||||
# 检查是否已经注册过
|
# 检查是否已经注册过
|
||||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||||
if existing_emoji:
|
if existing_emoji:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
image_path = os.path.join(emoji_dir, filename)
|
# 压缩图片并获取base64编码
|
||||||
# 读取图片数据
|
image_base64 = await self._compress_image(image_path)
|
||||||
with open(image_path, 'rb') as f:
|
if image_base64 is None:
|
||||||
image_data = f.read()
|
os.remove(image_path)
|
||||||
|
continue
|
||||||
# 将图片转换为base64
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
|
|
||||||
# 获取表情包的情感标签
|
# 获取表情包的情感标签
|
||||||
tag = await self._get_emoji_tag(image_base64)
|
tag = await self._get_emoji_tag(image_base64)
|
||||||
@@ -280,7 +352,6 @@ class EmojiManager:
|
|||||||
else:
|
else:
|
||||||
print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}")
|
print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}")
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ config = driver.config
|
|||||||
|
|
||||||
class ResponseGenerator:
|
class ResponseGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000)
|
||||||
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
|
||||||
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
|
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self.current_model_type = 'r1' # 默认使用 R1
|
self.current_model_type = 'r1' # 默认使用 R1
|
||||||
|
|
||||||
@@ -50,12 +50,19 @@ class ResponseGenerator:
|
|||||||
model_response, emotion = await self._process_response(model_response)
|
model_response, emotion = await self._process_response(model_response)
|
||||||
if model_response:
|
if model_response:
|
||||||
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
||||||
|
valuedict={
|
||||||
|
'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
|
||||||
|
}
|
||||||
|
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
||||||
|
|
||||||
return model_response, emotion
|
return model_response, emotion
|
||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
||||||
"""使用指定的模型生成回复"""
|
"""使用指定的模型生成回复"""
|
||||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||||
|
if message.user_cardname:
|
||||||
|
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
|
||||||
|
|
||||||
# 获取关系值
|
# 获取关系值
|
||||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
||||||
@@ -70,25 +77,29 @@ class ResponseGenerator:
|
|||||||
group_id=message.group_id
|
group_id=message.group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 读空气模块
|
# 读空气模块 简化逻辑,先停用
|
||||||
if global_config.enable_kuuki_read:
|
# if global_config.enable_kuuki_read:
|
||||||
content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
|
# content_check, reasoning_content_check = await self.model_v3.generate_response(prompt_check)
|
||||||
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
# print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
||||||
if 'yes' not in content_check.lower() and random.random() < 0.3:
|
# if 'yes' not in content_check.lower() and random.random() < 0.3:
|
||||||
self._save_to_db(
|
# self._save_to_db(
|
||||||
message=message,
|
# message=message,
|
||||||
sender_name=sender_name,
|
# sender_name=sender_name,
|
||||||
prompt=prompt,
|
# prompt=prompt,
|
||||||
prompt_check=prompt_check,
|
# prompt_check=prompt_check,
|
||||||
content="",
|
# content="",
|
||||||
content_check=content_check,
|
# content_check=content_check,
|
||||||
reasoning_content="",
|
# reasoning_content="",
|
||||||
reasoning_content_check=reasoning_content_check
|
# reasoning_content_check=reasoning_content_check
|
||||||
)
|
# )
|
||||||
return None
|
# return None
|
||||||
|
|
||||||
# 生成回复
|
# 生成回复
|
||||||
|
try:
|
||||||
content, reasoning_content = await model.generate_response(prompt)
|
content, reasoning_content = await model.generate_response(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"生成回复时出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
self._save_to_db(
|
self._save_to_db(
|
||||||
@@ -97,15 +108,17 @@ class ResponseGenerator:
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_check=prompt_check,
|
prompt_check=prompt_check,
|
||||||
content=content,
|
content=content,
|
||||||
content_check=content_check if global_config.enable_kuuki_read else "",
|
# content_check=content_check if global_config.enable_kuuki_read else "",
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
|
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
|
||||||
)
|
)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||||
|
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||||
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||||
content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
content: str, reasoning_content: str,):
|
||||||
"""保存对话记录到数据库"""
|
"""保存对话记录到数据库"""
|
||||||
self.db.db.reasoning_logs.insert_one({
|
self.db.db.reasoning_logs.insert_one({
|
||||||
'time': time.time(),
|
'time': time.time(),
|
||||||
@@ -113,8 +126,8 @@ class ResponseGenerator:
|
|||||||
'user': sender_name,
|
'user': sender_name,
|
||||||
'message': message.processed_plain_text,
|
'message': message.processed_plain_text,
|
||||||
'model': self.current_model_type,
|
'model': self.current_model_type,
|
||||||
'reasoning_check': reasoning_content_check,
|
# 'reasoning_check': reasoning_content_check,
|
||||||
'response_check': content_check,
|
# 'response_check': content_check,
|
||||||
'reasoning': reasoning_content,
|
'reasoning': reasoning_content,
|
||||||
'response': content,
|
'response': content,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
@@ -129,9 +142,12 @@ class ResponseGenerator:
|
|||||||
内容:{content}
|
内容:{content}
|
||||||
输出:
|
输出:
|
||||||
'''
|
'''
|
||||||
|
|
||||||
content, _ = await self.model_v3.generate_response(prompt)
|
content, _ = await self.model_v3.generate_response(prompt)
|
||||||
return [content.strip()] if content else ["neutral"]
|
content=content.strip()
|
||||||
|
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
|
||||||
|
return [content]
|
||||||
|
else:
|
||||||
|
return ["neutral"]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取情感标签时出错: {e}")
|
print(f"获取情感标签时出错: {e}")
|
||||||
@@ -146,3 +162,41 @@ class ResponseGenerator:
|
|||||||
processed_response = process_llm_response(content)
|
processed_response = process_llm_response(content)
|
||||||
|
|
||||||
return processed_response, emotion_tags
|
return processed_response, emotion_tags
|
||||||
|
|
||||||
|
|
||||||
|
class InitiativeMessageGenerate:
|
||||||
|
def __init__(self):
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
||||||
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
||||||
|
self.model_r1_distill = LLM_request(
|
||||||
|
model=global_config.llm_reasoning_minor, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen_response(self, message: Message):
|
||||||
|
topic_select_prompt, dots_for_select, prompt_template = (
|
||||||
|
prompt_builder._build_initiative_prompt_select(message.group_id)
|
||||||
|
)
|
||||||
|
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
|
||||||
|
print(f"[DEBUG] {content_select} {reasoning}")
|
||||||
|
topics_list = [dot[0] for dot in dots_for_select]
|
||||||
|
if content_select:
|
||||||
|
if content_select in topics_list:
|
||||||
|
select_dot = dots_for_select[topics_list.index(content_select)]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
prompt_check, memory = prompt_builder._build_initiative_prompt_check(
|
||||||
|
select_dot[1], prompt_template
|
||||||
|
)
|
||||||
|
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
|
||||||
|
print(f"[DEBUG] {content_check} {reasoning_check}")
|
||||||
|
if "yes" not in content_check.lower():
|
||||||
|
return None
|
||||||
|
prompt = prompt_builder._build_initiative_prompt(
|
||||||
|
select_dot, prompt_template, memory
|
||||||
|
)
|
||||||
|
content, reasoning = self.model_r1.generate_response(prompt)
|
||||||
|
print(f"[DEBUG] {content} {reasoning}")
|
||||||
|
return content
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from ...common.database import Database
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
import urllib3
|
import urllib3
|
||||||
from .utils_user import get_user_nickname
|
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
|
||||||
from .utils_cq import parse_cq_code
|
from .utils_cq import parse_cq_code
|
||||||
from .cq_code import cq_code_tool,CQCode
|
from .cq_code import cq_code_tool,CQCode
|
||||||
|
|
||||||
@@ -21,46 +21,46 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||||||
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message:
|
class Message:
|
||||||
"""消息数据类"""
|
"""消息数据类"""
|
||||||
|
message_id: int = None
|
||||||
|
time: float = None
|
||||||
|
|
||||||
group_id: int = None
|
group_id: int = None
|
||||||
user_id: int = None
|
|
||||||
user_nickname: str = None # 用户昵称
|
|
||||||
group_name: str = None # 群名称
|
group_name: str = None # 群名称
|
||||||
|
|
||||||
message_id: int = None
|
user_id: int = None
|
||||||
raw_message: str = None
|
user_nickname: str = None # 用户昵称
|
||||||
plain_text: str = None
|
user_cardname: str=None # 用户群昵称
|
||||||
|
|
||||||
message_based_id: int = None
|
raw_message: str = None # 原始消息,包含未解析的cq码
|
||||||
reply_message: Dict = None # 存储回复消息
|
plain_text: str = None # 纯文本
|
||||||
|
|
||||||
message_segments: List[Dict] = None # 存储解析后的消息片段
|
message_segments: List[Dict] = None # 存储解析后的消息片段
|
||||||
processed_plain_text: str = None # 用于存储处理后的plain_text
|
processed_plain_text: str = None # 用于存储处理后的plain_text
|
||||||
detailed_plain_text: str = None # 用于存储详细可读文本
|
detailed_plain_text: str = None # 用于存储详细可读文本
|
||||||
|
|
||||||
time: float = None
|
reply_message: Dict = None # 存储 回复的 源消息
|
||||||
|
|
||||||
is_emoji: bool = False # 是否是表情包
|
is_emoji: bool = False # 是否是表情包
|
||||||
has_emoji: bool = False # 是否包含表情包
|
has_emoji: bool = False # 是否包含表情包
|
||||||
|
|
||||||
translate_cq: bool = True # 是否翻译cq码
|
translate_cq: bool = True # 是否翻译cq码
|
||||||
|
|
||||||
|
|
||||||
reply_benefits: float = 0.0
|
|
||||||
|
|
||||||
type: str = 'received' # 消息类型,可以是received或者send
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.time is None:
|
if self.time is None:
|
||||||
self.time = int(time.time())
|
self.time = int(time.time())
|
||||||
|
|
||||||
|
if not self.group_name:
|
||||||
|
self.group_name = get_groupname(self.group_id)
|
||||||
|
|
||||||
if not self.user_nickname:
|
if not self.user_nickname:
|
||||||
self.user_nickname = get_user_nickname(self.user_id)
|
self.user_nickname = get_user_nickname(self.user_id)
|
||||||
|
|
||||||
if not self.group_name:
|
if not self.user_cardname:
|
||||||
self.group_name = self.get_groupname(self.group_id)
|
self.user_cardname=get_user_cardname(self.user_id)
|
||||||
|
|
||||||
if not self.processed_plain_text:
|
if not self.processed_plain_text:
|
||||||
if self.raw_message:
|
if self.raw_message:
|
||||||
@@ -71,25 +71,13 @@ class Message:
|
|||||||
)
|
)
|
||||||
#将详细翻译为详细可读文本
|
#将详细翻译为详细可读文本
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
|
||||||
|
try:
|
||||||
|
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
|
||||||
|
except:
|
||||||
name = self.user_nickname or f"用户{self.user_id}"
|
name = self.user_nickname or f"用户{self.user_id}"
|
||||||
content = self.processed_plain_text
|
content = self.processed_plain_text
|
||||||
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n"
|
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n"
|
||||||
|
|
||||||
|
|
||||||
def get_groupname(self, group_id: int) -> str:
|
|
||||||
if not group_id:
|
|
||||||
return "未知群"
|
|
||||||
group_id = int(group_id)
|
|
||||||
# 使用数据库单例
|
|
||||||
db = Database.get_instance()
|
|
||||||
# 查找用户,打印查询条件和结果
|
|
||||||
query = {'group_id': group_id}
|
|
||||||
group = db.db.group_info.find_one(query)
|
|
||||||
if group:
|
|
||||||
return group.get('group_name')
|
|
||||||
else:
|
|
||||||
return f"群{group_id}"
|
|
||||||
|
|
||||||
def parse_message_segments(self, message: str) -> List[CQCode]:
|
def parse_message_segments(self, message: str) -> List[CQCode]:
|
||||||
"""
|
"""
|
||||||
将消息解析为片段列表,包括纯文本和CQ码
|
将消息解析为片段列表,包括纯文本和CQ码
|
||||||
@@ -159,49 +147,58 @@ class Message_Thinking:
|
|||||||
self.group_id = message.group_id
|
self.group_id = message.group_id
|
||||||
self.user_id = message.user_id
|
self.user_id = message.user_id
|
||||||
self.user_nickname = message.user_nickname
|
self.user_nickname = message.user_nickname
|
||||||
|
self.user_cardname = message.user_cardname
|
||||||
self.group_name = message.group_name
|
self.group_name = message.group_name
|
||||||
|
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
|
|
||||||
# 思考状态相关属性
|
# 思考状态相关属性
|
||||||
self.thinking_text = "正在思考..."
|
self.thinking_start_time = int(time.time())
|
||||||
self.time = int(time.time())
|
|
||||||
self.thinking_time = 0
|
self.thinking_time = 0
|
||||||
|
self.interupt=False
|
||||||
|
|
||||||
def update_thinking_time(self):
|
def update_thinking_time(self):
|
||||||
self.thinking_time = round(time.time(), 2) - self.time
|
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
|
||||||
|
|
||||||
@property
|
|
||||||
def processed_plain_text(self) -> str:
|
|
||||||
"""获取处理后的文本"""
|
|
||||||
return self.thinking_text
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
@dataclass
|
||||||
return f"[思考中] 群:{self.group_id} 用户:{self.user_nickname} 时间:{self.time} 消息ID:{self.message_id}"
|
class Message_Sending(Message):
|
||||||
|
"""发送中的消息类"""
|
||||||
|
thinking_start_time: float = None # 思考开始时间
|
||||||
|
thinking_time: float = None # 思考时间
|
||||||
|
|
||||||
|
reply_message_id: int = None # 存储 回复的 源消息ID
|
||||||
|
|
||||||
|
def update_thinking_time(self):
|
||||||
|
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
|
||||||
|
return self.thinking_time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageSet:
|
class MessageSet:
|
||||||
"""消息集合类,可以存储多个相关的消息"""
|
"""消息集合类,可以存储多个发送消息"""
|
||||||
def __init__(self, group_id: int, user_id: int, message_id: str):
|
def __init__(self, group_id: int, user_id: int, message_id: str):
|
||||||
self.group_id = group_id
|
self.group_id = group_id
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.messages: List[Message] = []
|
self.messages: List[Message_Sending] = [] # 修改类型标注
|
||||||
self.time = round(time.time(), 2)
|
self.time = round(time.time(), 2)
|
||||||
|
|
||||||
def add_message(self, message: Message) -> None:
|
def add_message(self, message: Message_Sending) -> None:
|
||||||
"""添加消息到集合"""
|
"""添加消息到集合,只接受Message_Sending类型"""
|
||||||
|
if not isinstance(message, Message_Sending):
|
||||||
|
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
# 按时间排序
|
# 按时间排序
|
||||||
self.messages.sort(key=lambda x: x.time)
|
self.messages.sort(key=lambda x: x.time)
|
||||||
|
|
||||||
def get_message_by_index(self, index: int) -> Optional[Message]:
|
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
|
||||||
"""通过索引获取消息"""
|
"""通过索引获取消息"""
|
||||||
if 0 <= index < len(self.messages):
|
if 0 <= index < len(self.messages):
|
||||||
return self.messages[index]
|
return self.messages[index]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_message_by_time(self, target_time: float) -> Optional[Message]:
|
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
|
||||||
"""获取最接近指定时间的消息"""
|
"""获取最接近指定时间的消息"""
|
||||||
if not self.messages:
|
if not self.messages:
|
||||||
return None
|
return None
|
||||||
@@ -222,7 +219,7 @@ class MessageSet:
|
|||||||
"""清空所有消息"""
|
"""清空所有消息"""
|
||||||
self.messages.clear()
|
self.messages.clear()
|
||||||
|
|
||||||
def remove_message(self, message: Message) -> bool:
|
def remove_message(self, message: Message_Sending) -> bool:
|
||||||
"""移除指定消息"""
|
"""移除指定消息"""
|
||||||
if message in self.messages:
|
if message in self.messages:
|
||||||
self.messages.remove(message)
|
self.messages.remove(message)
|
||||||
@@ -237,7 +234,3 @@ class MessageSet:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,251 +0,0 @@
|
|||||||
from typing import Union, List, Optional, Deque, Dict
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot, MessageSegment
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import os
|
|
||||||
from .message import Message, Message_Thinking, MessageSet
|
|
||||||
from .cq_code import CQCode
|
|
||||||
from collections import deque
|
|
||||||
import time
|
|
||||||
from .storage import MessageStorage
|
|
||||||
from .config import global_config
|
|
||||||
from .cq_code import cq_code_tool
|
|
||||||
|
|
||||||
if os.name == "nt":
|
|
||||||
from .message_visualizer import message_visualizer
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SendTemp:
|
|
||||||
"""单个群组的临时消息队列管理器"""
|
|
||||||
def __init__(self, group_id: int, max_size: int = 100):
|
|
||||||
self.group_id = group_id
|
|
||||||
self.max_size = max_size
|
|
||||||
self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
|
|
||||||
self.last_send_time = 0
|
|
||||||
|
|
||||||
def add(self, message: Message) -> None:
|
|
||||||
"""按时间顺序添加消息到队列"""
|
|
||||||
if not self.messages:
|
|
||||||
self.messages.append(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 按时间顺序插入
|
|
||||||
if message.time >= self.messages[-1].time:
|
|
||||||
self.messages.append(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 使用二分查找找到合适的插入位置
|
|
||||||
messages_list = list(self.messages)
|
|
||||||
left, right = 0, len(messages_list)
|
|
||||||
|
|
||||||
while left < right:
|
|
||||||
mid = (left + right) // 2
|
|
||||||
if messages_list[mid].time < message.time:
|
|
||||||
left = mid + 1
|
|
||||||
else:
|
|
||||||
right = mid
|
|
||||||
|
|
||||||
# 重建消息队列,保持时间顺序
|
|
||||||
new_messages = deque(maxlen=self.max_size)
|
|
||||||
new_messages.extend(messages_list[:left])
|
|
||||||
new_messages.append(message)
|
|
||||||
new_messages.extend(messages_list[left:])
|
|
||||||
self.messages = new_messages
|
|
||||||
def get_earliest_message(self) -> Optional[Message]:
|
|
||||||
"""获取时间最早的消息"""
|
|
||||||
message = self.messages.popleft() if self.messages else None
|
|
||||||
return message
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""清空队列"""
|
|
||||||
self.messages.clear()
|
|
||||||
|
|
||||||
def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
|
|
||||||
"""获取所有待发送的消息"""
|
|
||||||
if group_id is None:
|
|
||||||
return list(self.messages)
|
|
||||||
return [msg for msg in self.messages if msg.group_id == group_id]
|
|
||||||
|
|
||||||
def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
|
|
||||||
"""查看下一条要发送的消息(不移除)"""
|
|
||||||
return self.messages[0] if self.messages else None
|
|
||||||
|
|
||||||
def has_messages(self) -> bool:
|
|
||||||
"""检查是否有待发送的消息"""
|
|
||||||
return bool(self.messages)
|
|
||||||
|
|
||||||
def count(self, group_id: Optional[int] = None) -> int:
|
|
||||||
"""获取待发送消息数量"""
|
|
||||||
if group_id is None:
|
|
||||||
return len(self.messages)
|
|
||||||
return len([msg for msg in self.messages if msg.group_id == group_id])
|
|
||||||
|
|
||||||
def get_last_send_time(self) -> float:
|
|
||||||
"""获取最后一次发送时间"""
|
|
||||||
return self.last_send_time
|
|
||||||
|
|
||||||
def update_send_time(self):
|
|
||||||
"""更新最后发送时间"""
|
|
||||||
self.last_send_time = time.time()
|
|
||||||
|
|
||||||
class SendTempContainer:
|
|
||||||
"""管理所有群组的消息缓存容器"""
|
|
||||||
def __init__(self):
|
|
||||||
self.temp_queues: Dict[int, SendTemp] = {}
|
|
||||||
|
|
||||||
def get_queue(self, group_id: int) -> SendTemp:
|
|
||||||
"""获取或创建群组的消息队列"""
|
|
||||||
if group_id not in self.temp_queues:
|
|
||||||
self.temp_queues[group_id] = SendTemp(group_id)
|
|
||||||
return self.temp_queues[group_id]
|
|
||||||
|
|
||||||
def add_message(self, message: Message) -> None:
|
|
||||||
"""添加消息到对应群组的队列"""
|
|
||||||
queue = self.get_queue(message.group_id)
|
|
||||||
queue.add(message)
|
|
||||||
|
|
||||||
def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
|
|
||||||
"""获取指定群组的所有待发送消息"""
|
|
||||||
queue = self.get_queue(group_id)
|
|
||||||
return queue.get_all()
|
|
||||||
|
|
||||||
def has_messages(self, group_id: int) -> bool:
|
|
||||||
"""检查指定群组是否有待发送消息"""
|
|
||||||
queue = self.get_queue(group_id)
|
|
||||||
return queue.has_messages()
|
|
||||||
|
|
||||||
def get_all_groups(self) -> List[int]:
|
|
||||||
"""获取所有有待发送消息的群组ID"""
|
|
||||||
return list(self.temp_queues.keys())
|
|
||||||
|
|
||||||
def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
|
|
||||||
queue = self.get_queue(message_obj.group_id)
|
|
||||||
# 使用列表解析找到匹配的消息索引
|
|
||||||
matching_indices = [
|
|
||||||
i for i, msg in enumerate(queue.messages)
|
|
||||||
if msg.message_id == message_obj.message_id
|
|
||||||
]
|
|
||||||
|
|
||||||
if not matching_indices:
|
|
||||||
return False
|
|
||||||
|
|
||||||
index = matching_indices[0] # 获取第一个匹配的索引
|
|
||||||
|
|
||||||
# 将消息转换为列表以便修改
|
|
||||||
messages = list(queue.messages)
|
|
||||||
|
|
||||||
# 根据消息类型处理
|
|
||||||
if isinstance(message_obj, MessageSet):
|
|
||||||
messages.pop(index)
|
|
||||||
# 在原位置插入新消息组
|
|
||||||
for i, single_message in enumerate(message_obj.messages):
|
|
||||||
messages.insert(index + i, single_message)
|
|
||||||
# print(f"\033[1;34m[调试]\033[0m 添加消息组中的第{i+1}条消息: {single_message}")
|
|
||||||
else:
|
|
||||||
# 直接替换原消息
|
|
||||||
messages[index] = message_obj
|
|
||||||
# print(f"\033[1;34m[调试]\033[0m 已更新消息: {message_obj}")
|
|
||||||
|
|
||||||
# 重建队列
|
|
||||||
queue.messages.clear()
|
|
||||||
for msg in messages:
|
|
||||||
queue.messages.append(msg)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class MessageSendControl:
|
|
||||||
"""消息发送控制器"""
|
|
||||||
def __init__(self):
|
|
||||||
self.typing_speed = (0.1, 0.3) # 每个字符的打字时间范围(秒)
|
|
||||||
self.message_interval = (0.5, 1) # 多条消息间的间隔时间范围(秒)
|
|
||||||
self.max_retry = 3 # 最大重试次数
|
|
||||||
self.send_temp_container = SendTempContainer()
|
|
||||||
self._running = True
|
|
||||||
self._paused = False
|
|
||||||
self._current_bot = None
|
|
||||||
self.storage = MessageStorage() # 添加存储实例
|
|
||||||
try:
|
|
||||||
message_visualizer.start()
|
|
||||||
except(NameError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_bot(self, bot: Bot):
|
|
||||||
"""设置当前bot实例"""
|
|
||||||
self._current_bot = bot
|
|
||||||
|
|
||||||
async def process_group_messages(self, group_id: int):
|
|
||||||
queue = self.send_temp_container.get_queue(group_id)
|
|
||||||
if queue.has_messages():
|
|
||||||
message = queue.peek_next()
|
|
||||||
# 处理消息的逻辑
|
|
||||||
if isinstance(message, Message_Thinking):
|
|
||||||
message.update_thinking_time()
|
|
||||||
thinking_time = message.thinking_time
|
|
||||||
if thinking_time < 90: # 最少思考2秒
|
|
||||||
if int(thinking_time) % 15 == 0:
|
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{thinking_time:.1f}秒")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
print(f"\033[1;34m[调试]\033[0m 思考消息超时,移除")
|
|
||||||
queue.get_earliest_message() # 移除超时的思考消息
|
|
||||||
return
|
|
||||||
elif isinstance(message, Message):
|
|
||||||
message = queue.get_earliest_message()
|
|
||||||
if message and message.processed_plain_text:
|
|
||||||
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
|
|
||||||
cost_time = round(time.time(), 2) - message.time
|
|
||||||
if cost_time > 40:
|
|
||||||
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
|
|
||||||
cur_time = time.time()
|
|
||||||
await self._current_bot.send_group_msg(
|
|
||||||
group_id=group_id,
|
|
||||||
message=str(message.processed_plain_text),
|
|
||||||
auto_escape=False
|
|
||||||
)
|
|
||||||
cost_time = round(time.time(), 2) - cur_time
|
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息发送时间: {cost_time}秒")
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
|
||||||
print(f"\033[1;32m群 {group_id} 消息, 用户 {global_config.BOT_NICKNAME}, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
|
|
||||||
|
|
||||||
if message.is_emoji:
|
|
||||||
message.processed_plain_text = "[表情包]"
|
|
||||||
await self.storage.store_message(message, None)
|
|
||||||
else:
|
|
||||||
await self.storage.store_message(message, None)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
queue.update_send_time()
|
|
||||||
if queue.has_messages():
|
|
||||||
await asyncio.sleep(
|
|
||||||
random.uniform(
|
|
||||||
self.message_interval[0],
|
|
||||||
self.message_interval[1]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def start_processor(self, bot: Bot):
|
|
||||||
"""启动消息处理器"""
|
|
||||||
self._current_bot = bot
|
|
||||||
|
|
||||||
while self._running:
|
|
||||||
await asyncio.sleep(1.5)
|
|
||||||
tasks = []
|
|
||||||
for group_id in self.send_temp_container.get_all_groups():
|
|
||||||
tasks.append(self.process_group_messages(group_id))
|
|
||||||
|
|
||||||
# 并行处理所有群组的消息
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
try:
|
|
||||||
message_visualizer.update_content(self.send_temp_container)
|
|
||||||
except(NameError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_typing_speed(self, min_speed: float, max_speed: float):
|
|
||||||
"""设置打字速度范围"""
|
|
||||||
self.typing_speed = (min_speed, max_speed)
|
|
||||||
|
|
||||||
# 创建全局实例
|
|
||||||
message_sender = MessageSendControl()
|
|
||||||
225
src/plugins/chat/message_sender.py
Normal file
225
src/plugins/chat/message_sender.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
from typing import Union, List, Optional, Dict
|
||||||
|
from collections import deque
|
||||||
|
from .message import Message, Message_Thinking, MessageSet, Message_Sending
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
from .config import global_config
|
||||||
|
from .storage import MessageStorage
|
||||||
|
from .cq_code import cq_code_tool
|
||||||
|
import random
|
||||||
|
from .utils import calculate_typing_time
|
||||||
|
|
||||||
|
class Message_Sender:
|
||||||
|
"""发送器"""
|
||||||
|
def __init__(self):
|
||||||
|
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
|
||||||
|
self.last_send_time = 0
|
||||||
|
self._current_bot = None
|
||||||
|
|
||||||
|
def set_bot(self, bot: Bot):
|
||||||
|
"""设置当前bot实例"""
|
||||||
|
self._current_bot = bot
|
||||||
|
|
||||||
|
async def send_group_message(
|
||||||
|
self,
|
||||||
|
group_id: int,
|
||||||
|
send_text: str,
|
||||||
|
auto_escape: bool = False,
|
||||||
|
reply_message_id: int = None,
|
||||||
|
at_user_id: int = None
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
if not self._current_bot:
|
||||||
|
raise RuntimeError("Bot未设置,请先调用set_bot方法设置bot实例")
|
||||||
|
|
||||||
|
message = send_text
|
||||||
|
|
||||||
|
# 如果需要回复
|
||||||
|
if reply_message_id:
|
||||||
|
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
|
||||||
|
message = reply_cq + message
|
||||||
|
|
||||||
|
# 如果需要at
|
||||||
|
# if at_user_id:
|
||||||
|
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
||||||
|
# message = at_cq + " " + message
|
||||||
|
|
||||||
|
|
||||||
|
typing_time = calculate_typing_time(message)
|
||||||
|
if typing_time > 10:
|
||||||
|
typing_time = 10
|
||||||
|
await asyncio.sleep(typing_time)
|
||||||
|
|
||||||
|
# 发送消息
|
||||||
|
try:
|
||||||
|
await self._current_bot.send_group_msg(
|
||||||
|
group_id=group_id,
|
||||||
|
message=message,
|
||||||
|
auto_escape=auto_escape
|
||||||
|
)
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误 {e}")
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageContainer:
|
||||||
|
"""单个群的发送/思考消息容器"""
|
||||||
|
def __init__(self, group_id: int, max_size: int = 100):
|
||||||
|
self.group_id = group_id
|
||||||
|
self.max_size = max_size
|
||||||
|
self.messages = []
|
||||||
|
self.last_send_time = 0
|
||||||
|
self.thinking_timeout = 20 # 思考超时时间(秒)
|
||||||
|
|
||||||
|
def get_timeout_messages(self) -> List[Message_Sending]:
|
||||||
|
"""获取所有超时的Message_Sending对象(思考时间超过30秒),按thinking_start_time排序"""
|
||||||
|
current_time = time.time()
|
||||||
|
timeout_messages = []
|
||||||
|
|
||||||
|
for msg in self.messages:
|
||||||
|
if isinstance(msg, Message_Sending):
|
||||||
|
if current_time - msg.thinking_start_time > self.thinking_timeout:
|
||||||
|
timeout_messages.append(msg)
|
||||||
|
|
||||||
|
# 按thinking_start_time排序,时间早的在前面
|
||||||
|
timeout_messages.sort(key=lambda x: x.thinking_start_time)
|
||||||
|
|
||||||
|
return timeout_messages
|
||||||
|
|
||||||
|
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
|
||||||
|
"""获取thinking_start_time最早的消息对象"""
|
||||||
|
if not self.messages:
|
||||||
|
return None
|
||||||
|
earliest_time = float('inf')
|
||||||
|
earliest_message = None
|
||||||
|
for msg in self.messages:
|
||||||
|
msg_time = msg.thinking_start_time
|
||||||
|
if msg_time < earliest_time:
|
||||||
|
earliest_time = msg_time
|
||||||
|
earliest_message = msg
|
||||||
|
return earliest_message
|
||||||
|
|
||||||
|
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
|
||||||
|
"""添加消息到队列"""
|
||||||
|
print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
|
||||||
|
if isinstance(message, MessageSet):
|
||||||
|
for single_message in message.messages:
|
||||||
|
self.messages.append(single_message)
|
||||||
|
else:
|
||||||
|
self.messages.append(message)
|
||||||
|
|
||||||
|
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
|
||||||
|
"""移除消息,如果消息存在则返回True,否则返回False"""
|
||||||
|
try:
|
||||||
|
if message in self.messages:
|
||||||
|
self.messages.remove(message)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def has_messages(self) -> bool:
|
||||||
|
"""检查是否有待发送的消息"""
|
||||||
|
return bool(self.messages)
|
||||||
|
|
||||||
|
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
|
||||||
|
"""获取所有消息"""
|
||||||
|
return list(self.messages)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageManager:
|
||||||
|
"""管理所有群的消息容器"""
|
||||||
|
def __init__(self):
|
||||||
|
self.containers: Dict[int, MessageContainer] = {}
|
||||||
|
self.storage = MessageStorage()
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
def get_container(self, group_id: int) -> MessageContainer:
|
||||||
|
"""获取或创建群的消息容器"""
|
||||||
|
if group_id not in self.containers:
|
||||||
|
self.containers[group_id] = MessageContainer(group_id)
|
||||||
|
return self.containers[group_id]
|
||||||
|
|
||||||
|
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
|
||||||
|
container = self.get_container(message.group_id)
|
||||||
|
container.add_message(message)
|
||||||
|
|
||||||
|
async def process_group_messages(self, group_id: int):
|
||||||
|
"""处理群消息"""
|
||||||
|
# if int(time.time() / 3) == time.time() / 3:
|
||||||
|
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
||||||
|
container = self.get_container(group_id)
|
||||||
|
if container.has_messages():
|
||||||
|
#最早的对象,可能是思考消息,也可能是发送消息
|
||||||
|
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
|
||||||
|
|
||||||
|
#一个月后删了
|
||||||
|
if not message_earliest:
|
||||||
|
print(f"\033[1;34m[BUG,如果出现这个,说明有BUG,3月4日留]\033[0m ")
|
||||||
|
return
|
||||||
|
|
||||||
|
#如果是思考消息
|
||||||
|
if isinstance(message_earliest, Message_Thinking):
|
||||||
|
#优先等待这条消息
|
||||||
|
message_earliest.update_thinking_time()
|
||||||
|
thinking_time = message_earliest.thinking_time
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒")
|
||||||
|
else:# 如果不是message_thinking就只能是message_sending
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||||
|
#直接发,等什么呢
|
||||||
|
if message_earliest.update_thinking_time() < 30:
|
||||||
|
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
|
||||||
|
else:
|
||||||
|
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
|
||||||
|
|
||||||
|
#移除消息
|
||||||
|
if message_earliest.is_emoji:
|
||||||
|
message_earliest.processed_plain_text = "[表情包]"
|
||||||
|
await self.storage.store_message(message_earliest, None)
|
||||||
|
|
||||||
|
container.remove_message(message_earliest)
|
||||||
|
|
||||||
|
#获取并处理超时消息
|
||||||
|
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
|
||||||
|
if message_timeout:
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
|
||||||
|
for msg in message_timeout:
|
||||||
|
if msg == message_earliest:
|
||||||
|
continue # 跳过已经处理过的消息
|
||||||
|
|
||||||
|
try:
|
||||||
|
#发送
|
||||||
|
if msg.update_thinking_time() < 30:
|
||||||
|
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
|
||||||
|
else:
|
||||||
|
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
||||||
|
|
||||||
|
#如果是表情包,则替换为"[表情包]"
|
||||||
|
if msg.is_emoji:
|
||||||
|
msg.processed_plain_text = "[表情包]"
|
||||||
|
await self.storage.store_message(msg, None)
|
||||||
|
|
||||||
|
# 安全地移除消息
|
||||||
|
if not container.remove_message(msg):
|
||||||
|
print(f"\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def start_processor(self):
|
||||||
|
"""启动消息处理器"""
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
tasks = []
|
||||||
|
for group_id in self.containers.keys():
|
||||||
|
tasks.append(self.process_group_messages(group_id))
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# 创建全局消息管理器实例
|
||||||
|
message_manager = MessageManager()
|
||||||
|
# 创建全局发送器实例
|
||||||
|
message_sender = Message_Sender()
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
from typing import List, Optional, Dict
|
|
||||||
from .message import Message
|
|
||||||
import time
|
|
||||||
from collections import deque
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
class MessageStream:
|
|
||||||
"""单个群组的消息流容器"""
|
|
||||||
def __init__(self, group_id: int, max_size: int = 1000):
|
|
||||||
self.group_id = group_id
|
|
||||||
self.messages = deque(maxlen=max_size)
|
|
||||||
self.max_size = max_size
|
|
||||||
self.last_save_time = time.time()
|
|
||||||
|
|
||||||
# 确保日志目录存在
|
|
||||||
self.log_dir = os.path.join("log", str(self.group_id))
|
|
||||||
os.makedirs(self.log_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 启动自动保存任务
|
|
||||||
asyncio.create_task(self._auto_save())
|
|
||||||
|
|
||||||
async def _auto_save(self):
|
|
||||||
"""每30秒自动保存一次消息记录"""
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(30) # 等待30秒
|
|
||||||
await self.save_to_log()
|
|
||||||
|
|
||||||
async def save_to_log(self):
|
|
||||||
"""将消息保存到日志文件"""
|
|
||||||
try:
|
|
||||||
current_time = time.time()
|
|
||||||
# 只有有新消息时才保存
|
|
||||||
if not self.messages or self.last_save_time == current_time:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 生成日志文件名 (使用当前日期)
|
|
||||||
date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
|
|
||||||
log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")
|
|
||||||
|
|
||||||
# 获取需要保存的新消息
|
|
||||||
new_messages = [
|
|
||||||
msg for msg in self.messages
|
|
||||||
if msg.time > self.last_save_time
|
|
||||||
]
|
|
||||||
|
|
||||||
if not new_messages:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 将消息转换为可序列化的格式
|
|
||||||
message_logs = []
|
|
||||||
for msg in new_messages:
|
|
||||||
message_logs.append({
|
|
||||||
"time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
|
|
||||||
"user_id": msg.user_id,
|
|
||||||
"user_nickname": msg.user_nickname,
|
|
||||||
"message_id": msg.message_id,
|
|
||||||
"raw_message": msg.raw_message,
|
|
||||||
"processed_text": msg.processed_plain_text
|
|
||||||
})
|
|
||||||
|
|
||||||
# 追加写入日志文件
|
|
||||||
with open(log_file, "a", encoding="utf-8") as f:
|
|
||||||
for log in message_logs:
|
|
||||||
f.write(json.dumps(log, ensure_ascii=False) + "\n")
|
|
||||||
|
|
||||||
self.last_save_time = current_time
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}")
|
|
||||||
|
|
||||||
def add_message(self, message: Message) -> None:
|
|
||||||
"""按时间顺序添加新消息到队列
|
|
||||||
|
|
||||||
使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: Message对象,要添加的新消息
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 空队列或消息应该添加到末尾的情况
|
|
||||||
if (not self.messages or
|
|
||||||
message.time >= self.messages[-1].time):
|
|
||||||
self.messages.append(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 消息应该添加到开头的情况
|
|
||||||
if message.time <= self.messages[0].time:
|
|
||||||
self.messages.appendleft(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 使用二分查找在现有队列中找到合适的插入位置
|
|
||||||
left, right = 0, len(self.messages) - 1
|
|
||||||
while left <= right:
|
|
||||||
mid = (left + right) // 2
|
|
||||||
if self.messages[mid].time < message.time:
|
|
||||||
left = mid + 1
|
|
||||||
else:
|
|
||||||
right = mid - 1
|
|
||||||
|
|
||||||
temp = list(self.messages)
|
|
||||||
temp.insert(left, message)
|
|
||||||
|
|
||||||
# 如果超出最大长度,移除多余的消息
|
|
||||||
if len(temp) > self.max_size:
|
|
||||||
temp = temp[-self.max_size:]
|
|
||||||
|
|
||||||
# 重建队列
|
|
||||||
self.messages = deque(temp, maxlen=self.max_size)
|
|
||||||
|
|
||||||
async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
|
|
||||||
"""从数据库中获取最近的消息记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
count: 需要获取的消息数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Message]: 最近的消息列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from ...common.database import Database
|
|
||||||
db = Database.get_instance()
|
|
||||||
|
|
||||||
# 从数据库中查询最近的消息
|
|
||||||
recent_messages = list(db.db.messages.find(
|
|
||||||
{"group_id": self.group_id},
|
|
||||||
{
|
|
||||||
"time": 1,
|
|
||||||
"user_id": 1,
|
|
||||||
"user_nickname": 1,
|
|
||||||
"message_id": 1,
|
|
||||||
"raw_message": 1,
|
|
||||||
"processed_text": 1
|
|
||||||
}
|
|
||||||
).sort("time", -1).limit(count))
|
|
||||||
|
|
||||||
if not recent_messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 转换为 Message 对象
|
|
||||||
from .message import Message
|
|
||||||
messages = []
|
|
||||||
for msg_data in recent_messages:
|
|
||||||
msg = Message(
|
|
||||||
time=msg_data["time"],
|
|
||||||
user_id=msg_data["user_id"],
|
|
||||||
user_nickname=msg_data.get("user_nickname", ""),
|
|
||||||
message_id=msg_data["message_id"],
|
|
||||||
raw_message=msg_data["raw_message"],
|
|
||||||
processed_plain_text=msg_data.get("processed_text", ""),
|
|
||||||
group_id=self.group_id
|
|
||||||
)
|
|
||||||
messages.append(msg)
|
|
||||||
|
|
||||||
return list(reversed(messages)) # 返回按时间正序的消息
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_recent_messages(self, count: int = 10) -> List[Message]:
|
|
||||||
"""获取最近的n条消息(从内存队列)"""
|
|
||||||
print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录")
|
|
||||||
return list(self.messages)[-count:]
|
|
||||||
|
|
||||||
def get_messages_in_timerange(self,
|
|
||||||
start_time: Optional[float] = None,
|
|
||||||
end_time: Optional[float] = None) -> List[Message]:
|
|
||||||
"""获取时间范围内的消息"""
|
|
||||||
if start_time is None:
|
|
||||||
start_time = time.time() - 3600
|
|
||||||
if end_time is None:
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
return [
|
|
||||||
msg for msg in self.messages
|
|
||||||
if start_time <= msg.time <= end_time
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
|
|
||||||
"""获取特定用户的最近消息"""
|
|
||||||
user_messages = [msg for msg in self.messages if msg.user_id == user_id]
|
|
||||||
return user_messages[-count:]
|
|
||||||
|
|
||||||
def clear_old_messages(self, hours: int = 24) -> None:
|
|
||||||
"""清理旧消息"""
|
|
||||||
cutoff_time = time.time() - (hours * 3600)
|
|
||||||
self.messages = deque(
|
|
||||||
[msg for msg in self.messages if msg.time > cutoff_time],
|
|
||||||
maxlen=self.max_size
|
|
||||||
)
|
|
||||||
|
|
||||||
class MessageStreamContainer:
|
|
||||||
"""管理所有群组的消息流容器"""
|
|
||||||
def __init__(self, max_size: int = 1000):
|
|
||||||
self.streams: Dict[int, MessageStream] = {}
|
|
||||||
self.max_size = max_size
|
|
||||||
|
|
||||||
async def save_all_logs(self):
|
|
||||||
"""保存所有群组的消息日志"""
|
|
||||||
for stream in self.streams.values():
|
|
||||||
await stream.save_to_log()
|
|
||||||
|
|
||||||
def add_message(self, message: Message) -> None:
|
|
||||||
"""添加消息到对应群组的消息流"""
|
|
||||||
if not message.group_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.group_id not in self.streams:
|
|
||||||
self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)
|
|
||||||
|
|
||||||
self.streams[message.group_id].add_message(message)
|
|
||||||
|
|
||||||
def get_stream(self, group_id: int) -> Optional[MessageStream]:
|
|
||||||
"""获取特定群组的消息流"""
|
|
||||||
return self.streams.get(group_id)
|
|
||||||
|
|
||||||
def get_all_streams(self) -> Dict[int, MessageStream]:
|
|
||||||
"""获取所有群组的消息流"""
|
|
||||||
return self.streams
|
|
||||||
|
|
||||||
def clear_old_messages(self, hours: int = 24) -> None:
|
|
||||||
"""清理所有群组的旧消息"""
|
|
||||||
for stream in self.streams.values():
|
|
||||||
stream.clear_old_messages(hours)
|
|
||||||
|
|
||||||
def get_group_stats(self, group_id: int) -> Dict:
|
|
||||||
"""获取群组的消息统计信息"""
|
|
||||||
stream = self.streams.get(group_id)
|
|
||||||
if not stream:
|
|
||||||
return {
|
|
||||||
"total_messages": 0,
|
|
||||||
"unique_users": 0,
|
|
||||||
"active_hours": [],
|
|
||||||
"most_active_user": None
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = stream.messages
|
|
||||||
user_counts = {}
|
|
||||||
hour_counts = {}
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
|
|
||||||
hour = datetime.fromtimestamp(msg.time).hour
|
|
||||||
hour_counts[hour] = hour_counts.get(hour, 0) + 1
|
|
||||||
|
|
||||||
most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
|
|
||||||
active_hours = sorted(
|
|
||||||
hour_counts.items(),
|
|
||||||
key=lambda x: x[1],
|
|
||||||
reverse=True
|
|
||||||
)[:5]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_messages": len(messages),
|
|
||||||
"unique_users": len(user_counts),
|
|
||||||
"active_hours": active_hours,
|
|
||||||
"most_active_user": most_active_user
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建全局实例
|
|
||||||
message_stream_container = MessageStreamContainer()
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
import subprocess
|
|
||||||
import threading
|
|
||||||
import queue
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Dict
|
|
||||||
from .message import Message_Thinking
|
|
||||||
|
|
||||||
class MessageVisualizer:
|
|
||||||
def __init__(self):
|
|
||||||
self.process = None
|
|
||||||
self.message_queue = queue.Queue()
|
|
||||||
self.is_running = False
|
|
||||||
self.content_file = "message_queue_content.txt"
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if self.process is None:
|
|
||||||
# 创建用于显示的批处理文件
|
|
||||||
with open("message_queue_window.bat", "w", encoding="utf-8") as f:
|
|
||||||
f.write('@echo off\n')
|
|
||||||
f.write('chcp 65001\n') # 设置UTF-8编码
|
|
||||||
f.write('title Message Queue Visualizer\n')
|
|
||||||
f.write('echo Waiting for message queue updates...\n')
|
|
||||||
f.write(':loop\n')
|
|
||||||
f.write('if exist "queue_update.txt" (\n')
|
|
||||||
f.write(' type "queue_update.txt" > "message_queue_content.txt"\n')
|
|
||||||
f.write(' del "queue_update.txt"\n')
|
|
||||||
f.write(' cls\n')
|
|
||||||
f.write(' type "message_queue_content.txt"\n')
|
|
||||||
f.write(')\n')
|
|
||||||
f.write('timeout /t 1 /nobreak >nul\n')
|
|
||||||
f.write('goto loop\n')
|
|
||||||
|
|
||||||
# 清空内容文件
|
|
||||||
with open(self.content_file, "w", encoding="utf-8") as f:
|
|
||||||
f.write("")
|
|
||||||
|
|
||||||
# 启动新窗口
|
|
||||||
startupinfo = subprocess.STARTUPINFO()
|
|
||||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
|
||||||
self.process = subprocess.Popen(
|
|
||||||
['cmd', '/c', 'start', 'message_queue_window.bat'],
|
|
||||||
shell=True,
|
|
||||||
startupinfo=startupinfo
|
|
||||||
)
|
|
||||||
self.is_running = True
|
|
||||||
|
|
||||||
# 启动处理线程
|
|
||||||
threading.Thread(target=self._process_messages, daemon=True).start()
|
|
||||||
|
|
||||||
def _process_messages(self):
|
|
||||||
while self.is_running:
|
|
||||||
try:
|
|
||||||
# 获取新消息
|
|
||||||
text = self.message_queue.get(timeout=1)
|
|
||||||
# 写入更新文件
|
|
||||||
with open("queue_update.txt", "w", encoding="utf-8") as f:
|
|
||||||
f.write(text)
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理队列可视化内容时出错: {e}")
|
|
||||||
|
|
||||||
def update_content(self, send_temp_container):
|
|
||||||
"""更新显示内容"""
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
display_text = f"Message Queue Status - {current_time}\n"
|
|
||||||
display_text += "=" * 50 + "\n\n"
|
|
||||||
|
|
||||||
# 遍历所有群组的队列
|
|
||||||
for group_id, queue in send_temp_container.temp_queues.items():
|
|
||||||
display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n"
|
|
||||||
display_text += f"消息队列长度: {len(queue.messages)}\n"
|
|
||||||
display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
|
|
||||||
display_text += "\n消息队列内容:\n"
|
|
||||||
|
|
||||||
# 显示队列中的消息
|
|
||||||
if not queue.messages:
|
|
||||||
display_text += " [空队列]\n"
|
|
||||||
else:
|
|
||||||
for i, msg in enumerate(queue.messages):
|
|
||||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
|
|
||||||
display_text += f"\n--- 消息 {i+1} ---\n"
|
|
||||||
|
|
||||||
if isinstance(msg, Message_Thinking):
|
|
||||||
display_text += f"类型: \033[1;33m思考中消息\033[0m\n"
|
|
||||||
display_text += f"时间: {msg_time}\n"
|
|
||||||
display_text += f"消息ID: {msg.message_id}\n"
|
|
||||||
display_text += f"群组: {msg.group_id}\n"
|
|
||||||
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
|
|
||||||
display_text += f"内容: {msg.thinking_text}\n"
|
|
||||||
display_text += f"思考时间: {int(msg.thinking_time)}秒\n"
|
|
||||||
else:
|
|
||||||
display_text += f"类型: 普通消息\n"
|
|
||||||
display_text += f"时间: {msg_time}\n"
|
|
||||||
display_text += f"消息ID: {msg.message_id}\n"
|
|
||||||
display_text += f"群组: {msg.group_id}\n"
|
|
||||||
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
|
|
||||||
if hasattr(msg, 'is_emoji') and msg.is_emoji:
|
|
||||||
display_text += f"内容: [表情包消息]\n"
|
|
||||||
else:
|
|
||||||
# 显示原始消息和处理后的消息
|
|
||||||
display_text += f"原始内容: {msg.raw_message[:50]}...\n"
|
|
||||||
display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n"
|
|
||||||
|
|
||||||
if msg.reply_message:
|
|
||||||
display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n"
|
|
||||||
|
|
||||||
display_text += f"\n{'-' * 50}\n"
|
|
||||||
|
|
||||||
# 添加统计信息
|
|
||||||
display_text += "\n总体统计:\n"
|
|
||||||
display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n"
|
|
||||||
total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
|
|
||||||
display_text += f"总消息数: {total_messages}\n"
|
|
||||||
thinking_messages = sum(
|
|
||||||
sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
|
|
||||||
for q in send_temp_container.temp_queues.values()
|
|
||||||
)
|
|
||||||
display_text += f"思考中消息数: {thinking_messages}\n"
|
|
||||||
|
|
||||||
self.message_queue.put(display_text)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self.is_running = False
|
|
||||||
if self.process:
|
|
||||||
self.process.terminate()
|
|
||||||
self.process = None
|
|
||||||
# 清理文件
|
|
||||||
for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
|
|
||||||
if os.path.exists(file):
|
|
||||||
os.remove(file)
|
|
||||||
|
|
||||||
# 创建全局单例
|
|
||||||
message_visualizer = MessageVisualizer()
|
|
||||||
@@ -36,7 +36,9 @@ class PromptBuilder:
|
|||||||
|
|
||||||
memory_prompt = ''
|
memory_prompt = ''
|
||||||
start_time = time.time() # 记录开始时间
|
start_time = time.time() # 记录开始时间
|
||||||
topic = topic_identifier.identify_topic_jieba(message_txt)
|
# topic = await topic_identifier.identify_topic_llm(message_txt)
|
||||||
|
topic = topic_identifier.identify_topic_snownlp(message_txt)
|
||||||
|
|
||||||
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
|
# print(f"\033[1;32m[pb主题识别]\033[0m 主题: {topic}")
|
||||||
|
|
||||||
all_first_layer_items = [] # 存储所有第一层记忆
|
all_first_layer_items = [] # 存储所有第一层记忆
|
||||||
@@ -65,14 +67,6 @@ class PromptBuilder:
|
|||||||
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
|
# print(f"\033[1;32m[前额叶]\033[0m 发现主题 '{current_topic}' 和 '{other_topic}' 有共同的第二层记忆: {overlap}")
|
||||||
overlapping_second_layer.update(overlap)
|
overlapping_second_layer.update(overlap)
|
||||||
|
|
||||||
# 合并所有需要的记忆
|
|
||||||
# if all_first_layer_items:
|
|
||||||
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
|
||||||
# if overlapping_second_layer:
|
|
||||||
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
|
||||||
|
|
||||||
# 使用集合去重
|
|
||||||
# 从每个来源随机选择2条记忆(如果有的话)
|
|
||||||
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
|
selected_first_layer = random.sample(all_first_layer_items, min(2, len(all_first_layer_items))) if all_first_layer_items else []
|
||||||
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
|
selected_second_layer = random.sample(list(overlapping_second_layer), min(2, len(overlapping_second_layer))) if overlapping_second_layer else []
|
||||||
|
|
||||||
@@ -147,15 +141,15 @@ class PromptBuilder:
|
|||||||
is_bot_prompt = ''
|
is_bot_prompt = ''
|
||||||
|
|
||||||
#人格选择
|
#人格选择
|
||||||
|
personality=global_config.PROMPT_PERSONALITY
|
||||||
prompt_personality = ''
|
prompt_personality = ''
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
if personality_choice < 4/6: # 第一种人格
|
if personality_choice < 4/6: # 第一种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
||||||
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
||||||
elif personality_choice < 1: # 第二种人格
|
elif personality_choice < 1: # 第二种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt},
|
||||||
|
|
||||||
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
||||||
请你表达自己的见解和观点。可以有个性。'''
|
请你表达自己的见解和观点。可以有个性。'''
|
||||||
|
|
||||||
@@ -170,7 +164,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
|
|
||||||
#额外信息要求
|
#额外信息要求
|
||||||
extra_info = '''但是记得回复平淡一些,简短一些,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -195,23 +189,67 @@ class PromptBuilder:
|
|||||||
prompt_personality_check = ''
|
prompt_personality_check = ''
|
||||||
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
||||||
if personality_choice < 4/6: # 第一种人格
|
if personality_choice < 4/6: # 第一种人格
|
||||||
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧,你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
||||||
elif personality_choice < 1: # 第二种人格
|
elif personality_choice < 1: # 第二种人格
|
||||||
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},是一个女大学生,你有黑色头发,你会刷小红书, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
|
||||||
|
|
||||||
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
|
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
|
||||||
|
|
||||||
return prompt,prompt_check_if_response
|
return prompt,prompt_check_if_response
|
||||||
|
|
||||||
|
def _build_initiative_prompt_select(self,group_id):
|
||||||
|
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||||
|
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||||
|
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
|
||||||
|
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
|
||||||
|
|
||||||
|
chat_talking_prompt = ''
|
||||||
|
if group_id:
|
||||||
|
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
||||||
|
|
||||||
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
|
|
||||||
|
# 获取主动发言的话题
|
||||||
|
all_nodes=memory_graph.dots
|
||||||
|
all_nodes=filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes)
|
||||||
|
nodes_for_select=random.sample(all_nodes,5)
|
||||||
|
topics=[info[0] for info in nodes_for_select]
|
||||||
|
infos=[info[1] for info in nodes_for_select]
|
||||||
|
|
||||||
|
#激活prompt构建
|
||||||
|
activate_prompt = ''
|
||||||
|
activate_prompt = f"以上是群里正在进行的聊天。"
|
||||||
|
personality=global_config.PROMPT_PERSONALITY
|
||||||
|
prompt_personality = ''
|
||||||
|
personality_choice = random.random()
|
||||||
|
if personality_choice < 4/6: # 第一种人格
|
||||||
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}'''
|
||||||
|
elif personality_choice < 1: # 第二种人格
|
||||||
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}'''
|
||||||
|
|
||||||
|
topics_str=','.join(f"\"{topics}\"")
|
||||||
|
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
|
||||||
|
|
||||||
|
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
|
||||||
|
prompt_regular=f"{prompt_date}\n{prompt_personality}"
|
||||||
|
|
||||||
|
return prompt_initiative_select,nodes_for_select,prompt_regular
|
||||||
|
|
||||||
|
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
|
||||||
|
memory=random.sample(selected_node['memory_items'],3)
|
||||||
|
memory='\n'.join(memory)
|
||||||
|
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,综合群内的氛围,如果认为应该发言请输出yes,否则输出no,请注意是决定是否需要发言,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
||||||
|
return prompt_for_check,memory
|
||||||
|
|
||||||
|
def _build_initiative_prompt(self,selected_node,prompt_regular,memory):
|
||||||
|
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
|
||||||
|
return prompt_for_initiative
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_info(self,message:str,threshold:float):
|
def get_prompt_info(self,message:str,threshold:float):
|
||||||
related_info = ''
|
related_info = ''
|
||||||
if len(message) > 10:
|
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
message_segments = [message[i:i+10] for i in range(0, len(message), 10)]
|
|
||||||
for segment in message_segments:
|
|
||||||
embedding = get_embedding(segment)
|
|
||||||
related_info += self.get_info_from_db(embedding,threshold=threshold)
|
|
||||||
|
|
||||||
else:
|
|
||||||
embedding = get_embedding(message)
|
embedding = get_embedding(message)
|
||||||
related_info += self.get_info_from_db(embedding,threshold=threshold)
|
related_info += self.get_info_from_db(embedding,threshold=threshold)
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class MessageStorage:
|
|||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
"time": message.time,
|
"time": message.time,
|
||||||
"user_nickname": message.user_nickname,
|
"user_nickname": message.user_nickname,
|
||||||
|
"user_cardname": message.user_cardname,
|
||||||
"group_name": message.group_name,
|
"group_name": message.group_name,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
@@ -37,6 +38,7 @@ class MessageStorage:
|
|||||||
"processed_plain_text": '[表情包]',
|
"processed_plain_text": '[表情包]',
|
||||||
"time": message.time,
|
"time": message.time,
|
||||||
"user_nickname": message.user_nickname,
|
"user_nickname": message.user_nickname,
|
||||||
|
"user_cardname": message.user_cardname,
|
||||||
"group_name": message.group_name,
|
"group_name": message.group_name,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
|
|||||||
14
src/plugins/chat/thinking_idea.py
Normal file
14
src/plugins/chat/thinking_idea.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#Broca's Area
|
||||||
|
# 功能:语言产生、语法处理和言语运动控制。
|
||||||
|
# 损伤后果:布洛卡失语症(表达困难,但理解保留)。
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Thinking_Idea:
|
||||||
|
def __init__(self, message_id: str):
|
||||||
|
self.messages = [] # 消息列表集合
|
||||||
|
self.current_thoughts = [] # 当前思考内容列表
|
||||||
|
self.time = time.time() # 创建时间
|
||||||
|
self.id = str(int(time.time() * 1000)) # 使用时间戳生成唯一标识ID
|
||||||
|
|
||||||
@@ -4,19 +4,20 @@ from .message import Message
|
|||||||
import jieba
|
import jieba
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from snownlp import SnowNLP
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client = OpenAI(
|
self.llm_client = LLM_request(model=global_config.llm_topic_extract)
|
||||||
api_key=config.siliconflow_key,
|
self.select=global_config.topic_extract
|
||||||
base_url=config.siliconflow_base_url
|
|
||||||
)
|
|
||||||
|
|
||||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
|
||||||
"""识别消息主题"""
|
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
|
||||||
|
"""识别消息主题,返回主题列表"""
|
||||||
|
|
||||||
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:
|
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:
|
||||||
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。
|
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。
|
||||||
@@ -24,77 +25,42 @@ class TopicIdentifier:
|
|||||||
|
|
||||||
消息内容:{text}"""
|
消息内容:{text}"""
|
||||||
|
|
||||||
response = self.client.chat.completions.create(
|
# 使用 LLM_request 类进行请求
|
||||||
model=global_config.SILICONFLOW_MODEL_V3,
|
topic, _ = await self.llm_client.generate_response(prompt)
|
||||||
messages=[{"role": "user", "content": prompt}],
|
|
||||||
temperature=0.8,
|
|
||||||
max_tokens=10
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response or not response.choices:
|
if not topic:
|
||||||
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
|
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
|
# 直接在这里处理主题解析
|
||||||
topic = response.choices[0].message.content.strip() if response.choices[0].message.content else None
|
|
||||||
|
|
||||||
if topic == "无主题":
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# print(f"[主题分析结果]{text[:20]}... : {topic}")
|
|
||||||
split_topic = self.parse_topic(topic)
|
|
||||||
return split_topic
|
|
||||||
|
|
||||||
|
|
||||||
def parse_topic(self, topic: str) -> List[str]:
|
|
||||||
"""解析主题,返回主题列表"""
|
|
||||||
if not topic or topic == "无主题":
|
if not topic or topic == "无主题":
|
||||||
return []
|
return None
|
||||||
return [t.strip() for t in topic.split(",") if t.strip()]
|
|
||||||
|
|
||||||
def identify_topic_jieba(self, text: str) -> Optional[str]:
|
# 解析主题字符串为列表
|
||||||
"""使用jieba识别主题"""
|
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
|
||||||
words = jieba.lcut(text)
|
|
||||||
# 去除停用词和标点符号
|
|
||||||
stop_words = {
|
|
||||||
'的', '了', '和', '是', '就', '都', '而', '及', '与', '这', '那', '但', '然', '却',
|
|
||||||
'因为', '所以', '如果', '虽然', '一个', '我', '你', '他', '她', '它', '我们', '你们',
|
|
||||||
'他们', '在', '有', '个', '把', '被', '让', '给', '从', '向', '到', '又', '也', '很',
|
|
||||||
'啊', '吧', '呢', '吗', '呀', '哦', '哈', '么', '嘛', '啦', '哎', '唉', '哇', '嗯',
|
|
||||||
'哼', '哪', '什么', '怎么', '为什么', '怎样', '如何', '什么样', '这样', '那样', '这么',
|
|
||||||
'那么', '多少', '几', '谁', '哪里', '哪儿', '什么时候', '何时', '为何', '怎么办',
|
|
||||||
'怎么样', '这些', '那些', '一些', '一点', '一下', '一直', '一定', '一般', '一样',
|
|
||||||
'一会儿', '一边', '一起',
|
|
||||||
# 添加更多量词
|
|
||||||
'个', '只', '条', '张', '片', '块', '本', '册', '页', '幅', '面', '篇', '份',
|
|
||||||
'朵', '颗', '粒', '座', '幢', '栋', '间', '层', '家', '户', '位', '名', '群',
|
|
||||||
'双', '对', '打', '副', '套', '批', '组', '串', '包', '箱', '袋', '瓶', '罐',
|
|
||||||
# 添加更多介词
|
|
||||||
'按', '按照', '把', '被', '比', '比如', '除', '除了', '当', '对', '对于',
|
|
||||||
'根据', '关于', '跟', '和', '将', '经', '经过', '靠', '连', '论', '通过',
|
|
||||||
'同', '往', '为', '为了', '围绕', '于', '由', '由于', '与', '在', '沿', '沿着',
|
|
||||||
'依', '依照', '以', '因', '因为', '用', '由', '与', '自', '自从'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 过滤掉停用词和标点符号,只保留名词和动词
|
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
|
||||||
filtered_words = []
|
return topic_list if topic_list else None
|
||||||
for word in words:
|
|
||||||
if word not in stop_words and not word.strip() in {
|
|
||||||
'。', ',', '、', ':', ';', '!', '?', '"', '"', ''', ''',
|
|
||||||
'(', ')', '【', '】', '《', '》', '…', '—', '·', '、', '~',
|
|
||||||
'~', '+', '=', '-'
|
|
||||||
}:
|
|
||||||
filtered_words.append(word)
|
|
||||||
|
|
||||||
# 统计词频
|
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
|
||||||
word_freq = {}
|
"""使用 SnowNLP 进行主题识别
|
||||||
for word in filtered_words:
|
|
||||||
word_freq[word] = word_freq.get(word, 0) + 1
|
|
||||||
|
|
||||||
# 按词频排序,取前3个
|
Args:
|
||||||
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
|
text (str): 需要识别主题的文本
|
||||||
top_words = [word for word, freq in sorted_words[:3]]
|
|
||||||
|
|
||||||
return top_words if top_words else None
|
Returns:
|
||||||
|
Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None
|
||||||
|
"""
|
||||||
|
if not text or len(text.strip()) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = SnowNLP(text)
|
||||||
|
# 提取前3个关键词作为主题
|
||||||
|
keywords = s.keywords(5)
|
||||||
|
return keywords if keywords else None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
topic_identifier = TopicIdentifier()
|
topic_identifier = TopicIdentifier()
|
||||||
@@ -10,6 +10,7 @@ from typing import Dict
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
import math
|
import math
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -37,6 +38,9 @@ def combine_messages(messages: List[Message]) -> str:
|
|||||||
def db_message_to_str (message_dict: Dict) -> str:
|
def db_message_to_str (message_dict: Dict) -> str:
|
||||||
print(f"message_dict: {message_dict}")
|
print(f"message_dict: {message_dict}")
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||||
|
try:
|
||||||
|
name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", ""))
|
||||||
|
except:
|
||||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||||
content = message_dict.get("processed_plain_text", "")
|
content = message_dict.get("processed_plain_text", "")
|
||||||
result = f"[{time_str}] {name}: {content}\n"
|
result = f"[{time_str}] {name}: {content}\n"
|
||||||
@@ -61,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_embedding(text):
|
def get_embedding(text):
|
||||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
"""获取文本的embedding向量"""
|
||||||
payload = {
|
llm = LLM_request(model=global_config.embedding)
|
||||||
"model": "BAAI/bge-m3",
|
return llm.get_embedding_sync(text)
|
||||||
"input": text,
|
|
||||||
"encoding_format": "float"
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.request("POST", url, json=payload, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"API请求失败: {response.status_code}")
|
|
||||||
print(f"错误信息: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response.json()['data'][0]['embedding']
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
@@ -89,11 +77,9 @@ def cosine_similarity(v1, v2):
|
|||||||
|
|
||||||
def calculate_information_content(text):
|
def calculate_information_content(text):
|
||||||
"""计算文本的信息量(熵)"""
|
"""计算文本的信息量(熵)"""
|
||||||
# 统计字符频率
|
|
||||||
char_count = Counter(text)
|
char_count = Counter(text)
|
||||||
total_chars = len(text)
|
total_chars = len(text)
|
||||||
|
|
||||||
# 计算熵
|
|
||||||
entropy = 0
|
entropy = 0
|
||||||
for count in char_count.values():
|
for count in char_count.values():
|
||||||
probability = count / total_chars
|
probability = count / total_chars
|
||||||
@@ -102,23 +88,37 @@ def calculate_information_content(text):
|
|||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||||
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
|
||||||
|
|
||||||
if closest_record:
|
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
chat_records = list(db.db.messages.find(
|
||||||
for record in chat_record:
|
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
).sort('time', 1).limit(length))
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
|
||||||
|
# 更新每条消息的memorized属性
|
||||||
|
for record in chat_records:
|
||||||
|
# 检查当前记录的memorized值
|
||||||
|
current_memorized = record.get('memorized', 0)
|
||||||
|
if current_memorized > 3:
|
||||||
|
# print(f"消息已读取3次,跳过")
|
||||||
|
return ''
|
||||||
|
|
||||||
|
# 更新memorized值
|
||||||
|
db.db.messages.update_one(
|
||||||
|
{"_id": record["_id"]},
|
||||||
|
{"$set": {"memorized": current_memorized + 1}}
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_text += record["detailed_plain_text"]
|
||||||
|
|
||||||
return chat_text
|
return chat_text
|
||||||
|
print(f"消息已读取3次,跳过")
|
||||||
return [] # 如果没有找到记录,返回空列表
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||||
"""从数据库获取群组最近的消息记录
|
"""从数据库获取群组最近的消息记录
|
||||||
@@ -135,14 +135,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
# 从数据库获取最近消息
|
# 从数据库获取最近消息
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"group_id": group_id},
|
||||||
{
|
# {
|
||||||
"time": 1,
|
# "time": 1,
|
||||||
"user_id": 1,
|
# "user_id": 1,
|
||||||
"user_nickname": 1,
|
# "user_nickname": 1,
|
||||||
"message_id": 1,
|
# "message_id": 1,
|
||||||
"raw_message": 1,
|
# "raw_message": 1,
|
||||||
"processed_text": 1
|
# "processed_text": 1
|
||||||
}
|
# }
|
||||||
).sort("time", -1).limit(limit))
|
).sort("time", -1).limit(limit))
|
||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
@@ -152,6 +152,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
from .message import Message
|
from .message import Message
|
||||||
message_objects = []
|
message_objects = []
|
||||||
for msg_data in recent_messages:
|
for msg_data in recent_messages:
|
||||||
|
try:
|
||||||
msg = Message(
|
msg = Message(
|
||||||
time=msg_data["time"],
|
time=msg_data["time"],
|
||||||
user_id=msg_data["user_id"],
|
user_id=msg_data["user_id"],
|
||||||
@@ -162,6 +163,9 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
group_id=group_id
|
group_id=group_id
|
||||||
)
|
)
|
||||||
message_objects.append(msg)
|
message_objects.append(msg)
|
||||||
|
except KeyError:
|
||||||
|
print("[WARNING] 数据库中存在无效的消息")
|
||||||
|
continue
|
||||||
|
|
||||||
# 按时间正序排列
|
# 按时间正序排列
|
||||||
message_objects.reverse()
|
message_objects.reverse()
|
||||||
|
|||||||
@@ -6,32 +6,27 @@ import os
|
|||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
import zlib # 用于 CRC32
|
import zlib # 用于 CRC32
|
||||||
import base64
|
import base64
|
||||||
from .config import global_config
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
|
||||||
if type == 'image':
|
|
||||||
return storage_compress_image(image_data, max_size)
|
|
||||||
elif type == 'emoji':
|
|
||||||
return storage_emoji(image_data)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的图片类型: {type}")
|
|
||||||
|
|
||||||
|
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
|
||||||
def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|
||||||
"""
|
"""
|
||||||
压缩图片到指定大小(单位:KB)并在数据库中记录图片信息
|
压缩base64格式的图片到指定大小(单位:KB)并在数据库中记录图片信息
|
||||||
Args:
|
Args:
|
||||||
image_data: 图片字节数据
|
base64_data: base64编码的图片数据
|
||||||
group_id: 群组ID
|
|
||||||
user_id: 用户ID
|
|
||||||
max_size: 最大文件大小(KB)
|
max_size: 最大文件大小(KB)
|
||||||
|
Returns:
|
||||||
|
str: 压缩后的base64图片数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 将base64转换为字节数据
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
# 使用 CRC32 计算哈希值
|
# 使用 CRC32 计算哈希值
|
||||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||||
|
|
||||||
@@ -55,14 +50,14 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|||||||
|
|
||||||
if existing_image:
|
if existing_image:
|
||||||
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
|
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
|
||||||
return image_data
|
return base64_data
|
||||||
|
|
||||||
# 将字节数据转换为图片对象
|
# 将字节数据转换为图片对象
|
||||||
img = Image.open(io.BytesIO(image_data))
|
img = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
# 如果是动图,直接返回原图
|
# 如果是动图,直接返回原图
|
||||||
if getattr(img, 'is_animated', False):
|
if getattr(img, 'is_animated', False):
|
||||||
return image_data
|
return base64_data
|
||||||
|
|
||||||
# 计算当前大小(KB)
|
# 计算当前大小(KB)
|
||||||
current_size = len(image_data) / 1024
|
current_size = len(image_data) / 1024
|
||||||
@@ -128,13 +123,15 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
|||||||
except Exception as db_error:
|
except Exception as db_error:
|
||||||
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
|
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
|
||||||
|
|
||||||
return compressed_data
|
# 将压缩后的数据转换为base64
|
||||||
|
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
|
||||||
|
return compressed_base64
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
return image_data
|
return base64_data
|
||||||
|
|
||||||
def storage_emoji(image_data: bytes) -> bytes:
|
def storage_emoji(image_data: bytes) -> bytes:
|
||||||
"""
|
"""
|
||||||
@@ -216,3 +213,77 @@ def storage_image(image_data: bytes) -> bytes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
|
||||||
return image_data
|
return image_data
|
||||||
|
|
||||||
|
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||||
|
"""压缩base64格式的图片到指定大小
|
||||||
|
Args:
|
||||||
|
base64_data: base64编码的图片数据
|
||||||
|
target_size: 目标文件大小(字节),默认0.8MB
|
||||||
|
Returns:
|
||||||
|
str: 压缩后的base64图片数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 将base64转换为字节数据
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
|
# 如果已经小于目标大小,直接返回原图
|
||||||
|
if len(image_data) <= target_size:
|
||||||
|
return base64_data
|
||||||
|
|
||||||
|
# 将字节数据转换为图片对象
|
||||||
|
img = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# 获取原始尺寸
|
||||||
|
original_width, original_height = img.size
|
||||||
|
|
||||||
|
# 计算缩放比例
|
||||||
|
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||||||
|
|
||||||
|
# 计算新的尺寸
|
||||||
|
new_width = int(original_width * scale)
|
||||||
|
new_height = int(original_height * scale)
|
||||||
|
|
||||||
|
# 创建内存缓冲区
|
||||||
|
output_buffer = io.BytesIO()
|
||||||
|
|
||||||
|
# 如果是GIF,处理所有帧
|
||||||
|
if getattr(img, "is_animated", False):
|
||||||
|
frames = []
|
||||||
|
for frame_idx in range(img.n_frames):
|
||||||
|
img.seek(frame_idx)
|
||||||
|
new_frame = img.copy()
|
||||||
|
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
frames.append(new_frame)
|
||||||
|
|
||||||
|
# 保存到缓冲区
|
||||||
|
frames[0].save(
|
||||||
|
output_buffer,
|
||||||
|
format='GIF',
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
optimize=True,
|
||||||
|
duration=img.info.get('duration', 100),
|
||||||
|
loop=img.info.get('loop', 0)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 处理静态图片
|
||||||
|
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# 保存到缓冲区,保持原始格式
|
||||||
|
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||||
|
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||||
|
else:
|
||||||
|
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||||
|
|
||||||
|
# 获取压缩后的数据并转换为base64
|
||||||
|
compressed_data = output_buffer.getvalue()
|
||||||
|
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||||||
|
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||||||
|
|
||||||
|
return base64.b64encode(compressed_data).decode('utf-8')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"压缩图片失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return base64_data
|
||||||
@@ -6,3 +6,12 @@ def get_user_nickname(user_id: int) -> str:
|
|||||||
return global_config.BOT_NICKNAME
|
return global_config.BOT_NICKNAME
|
||||||
# print(user_id)
|
# print(user_id)
|
||||||
return relationship_manager.get_name(user_id)
|
return relationship_manager.get_name(user_id)
|
||||||
|
|
||||||
|
def get_user_cardname(user_id: int) -> str:
|
||||||
|
if int(user_id) == int(global_config.BOT_QQ):
|
||||||
|
return global_config.BOT_NICKNAME
|
||||||
|
# print(user_id)
|
||||||
|
return ''
|
||||||
|
|
||||||
|
def get_groupname(group_id: int) -> str:
|
||||||
|
return f"群{group_id}"
|
||||||
@@ -9,9 +9,8 @@ class WillingManager:
|
|||||||
async def _decay_reply_willing(self):
|
async def _decay_reply_willing(self):
|
||||||
"""定期衰减回复意愿"""
|
"""定期衰减回复意愿"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(3)
|
await asyncio.sleep(5)
|
||||||
for group_id in self.group_reply_willing:
|
for group_id in self.group_reply_willing:
|
||||||
# 每分钟衰减10%的回复意愿
|
|
||||||
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
||||||
|
|
||||||
def get_willing(self, group_id: int) -> float:
|
def get_willing(self, group_id: int) -> float:
|
||||||
@@ -26,13 +25,7 @@ class WillingManager:
|
|||||||
"""改变指定群组的回复意愿并返回回复概率"""
|
"""改变指定群组的回复意愿并返回回复概率"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||||
|
|
||||||
print(f"初始意愿: {current_willing}")
|
# print(f"初始意愿: {current_willing}")
|
||||||
|
|
||||||
# if topic and current_willing < 1:
|
|
||||||
# current_willing += 0.2
|
|
||||||
# elif topic:
|
|
||||||
# current_willing += 0.05
|
|
||||||
|
|
||||||
if is_mentioned_bot and current_willing < 1.0:
|
if is_mentioned_bot and current_willing < 1.0:
|
||||||
current_willing += 0.9
|
current_willing += 0.9
|
||||||
print(f"被提及, 当前意愿: {current_willing}")
|
print(f"被提及, 当前意愿: {current_willing}")
|
||||||
@@ -44,13 +37,13 @@ class WillingManager:
|
|||||||
current_willing *= 0.15
|
current_willing *= 0.15
|
||||||
print(f"表情包, 当前意愿: {current_willing}")
|
print(f"表情包, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
if interested_rate > 0.6:
|
if interested_rate > 0.65:
|
||||||
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||||
current_willing += interested_rate-0.45
|
current_willing += interested_rate-0.6
|
||||||
|
|
||||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
reply_probability = (current_willing - 0.5) * 2
|
reply_probability = max((current_willing - 0.55) * 1.9, 0)
|
||||||
if group_id not in config.talk_allowed_groups:
|
if group_id not in config.talk_allowed_groups:
|
||||||
current_willing = 0
|
current_willing = 0
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
@@ -58,9 +51,9 @@ class WillingManager:
|
|||||||
if group_id in config.talk_frequency_down_groups:
|
if group_id in config.talk_frequency_down_groups:
|
||||||
reply_probability = reply_probability / 3.5
|
reply_probability = reply_probability / 3.5
|
||||||
|
|
||||||
# if is_mentioned_bot and user_id == int(1026294844):
|
reply_probability = min(reply_probability, 1)
|
||||||
# reply_probability = 1
|
if reply_probability < 0:
|
||||||
|
reply_probability = 0
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
def change_reply_willing_sent(self, group_id: int):
|
def change_reply_willing_sent(self, group_id: int):
|
||||||
@@ -72,7 +65,7 @@ class WillingManager:
|
|||||||
"""发送消息后提高群组的回复意愿"""
|
"""发送消息后提高群组的回复意愿"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||||
if current_willing < 1:
|
if current_willing < 1:
|
||||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.3)
|
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
|
||||||
|
|
||||||
async def ensure_started(self):
|
async def ensure_started(self):
|
||||||
"""确保衰减任务已启动"""
|
"""确保衰减任务已启动"""
|
||||||
|
|||||||
@@ -3,26 +3,28 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
from nonebot import get_driver
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
sys.path.append(root_path)
|
sys.path.append(root_path)
|
||||||
|
|
||||||
from src.common.database import Database
|
# 加载根目录下的env.edv文件
|
||||||
from src.plugins.chat.config import llm_config
|
env_path = os.path.join(root_path, ".env.dev")
|
||||||
|
if not os.path.exists(env_path):
|
||||||
|
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||||
|
load_dotenv(env_path)
|
||||||
|
|
||||||
# 直接配置数据库连接信息
|
from src.common.database import Database
|
||||||
|
|
||||||
|
# 从环境变量获取配置
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host=os.getenv("MONGODB_HOST", "localhost"),
|
||||||
port= int(config.mongodb_port),
|
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||||
db_name= config.database_name,
|
db_name=os.getenv("DATABASE_NAME", "maimai"),
|
||||||
username= config.mongodb_username,
|
username=os.getenv("MONGODB_USERNAME"),
|
||||||
password= config.mongodb_password,
|
password=os.getenv("MONGODB_PASSWORD"),
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
|
||||||
)
|
)
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
class KnowledgeLibrary:
|
||||||
@@ -30,6 +32,9 @@ class KnowledgeLibrary:
|
|||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self.raw_info_dir = "data/raw_info"
|
self.raw_info_dir = "data/raw_info"
|
||||||
self._ensure_dirs()
|
self._ensure_dirs()
|
||||||
|
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||||
|
|
||||||
def _ensure_dirs(self):
|
def _ensure_dirs(self):
|
||||||
"""确保必要的目录存在"""
|
"""确保必要的目录存在"""
|
||||||
@@ -44,7 +49,7 @@ class KnowledgeLibrary:
|
|||||||
"encoding_format": "float"
|
"encoding_format": "float"
|
||||||
}
|
}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +79,7 @@ class KnowledgeLibrary:
|
|||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# 按1024字符分段
|
# 按1024字符分段
|
||||||
segments = [content[i:i+300] for i in range(0, len(content), 300)]
|
segments = [content[i:i+600] for i in range(0, len(content), 600)]
|
||||||
|
|
||||||
# 处理每个分段
|
# 处理每个分段
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import jieba
|
import jieba
|
||||||
from llm_module import LLMModel
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
import math
|
||||||
@@ -10,11 +9,20 @@ from collections import Counter
|
|||||||
import datetime
|
import datetime
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
# from chat.config import global_config
|
from dotenv import load_dotenv
|
||||||
import sys
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
|
|
||||||
|
# 加载.env.dev文件
|
||||||
|
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||||
|
load_dotenv(env_path)
|
||||||
|
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
@@ -112,7 +120,11 @@ class Memory_graph:
|
|||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||||
for record in chat_record:
|
for record in chat_record:
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
try:
|
||||||
|
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
|
||||||
|
except:
|
||||||
|
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
|
||||||
|
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||||
return chat_text
|
return chat_text
|
||||||
|
|
||||||
return [] # 如果没有找到记录,返回空列表
|
return [] # 如果没有找到记录,返回空列表
|
||||||
@@ -154,38 +166,32 @@ class Memory_graph:
|
|||||||
def main():
|
def main():
|
||||||
# 初始化数据库
|
# 初始化数据库
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= os.getenv("MONGODB_HOST"),
|
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
username=os.getenv("MONGODB_USERNAME", ""),
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
password=os.getenv("MONGODB_PASSWORD", ""),
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
# 创建LLM模型实例
|
|
||||||
|
|
||||||
memory_graph.load_graph_from_db()
|
memory_graph.load_graph_from_db()
|
||||||
# 展示两种不同的可视化方式
|
|
||||||
print("\n按连接数量着色的图谱:")
|
|
||||||
# visualize_graph(memory_graph, color_by_memory=False)
|
|
||||||
visualize_graph_lite(memory_graph, color_by_memory=False)
|
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
# 只显示一次优化后的图形
|
||||||
# visualize_graph(memory_graph, color_by_memory=True)
|
visualize_graph_lite(memory_graph)
|
||||||
visualize_graph_lite(memory_graph, color_by_memory=True)
|
|
||||||
|
|
||||||
# memory_graph.save_graph_to_db()
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||||
if query.lower() == '退出':
|
if query.lower() == '退出':
|
||||||
break
|
break
|
||||||
items_list = memory_graph.get_related_item(query)
|
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||||
if items_list:
|
if first_layer_items or second_layer_items:
|
||||||
# print(items_list)
|
print("\n第一层记忆:")
|
||||||
for memory_item in items_list:
|
for item in first_layer_items:
|
||||||
print(memory_item)
|
print(item)
|
||||||
|
print("\n第二层记忆:")
|
||||||
|
for item in second_layer_items:
|
||||||
|
print(item)
|
||||||
else:
|
else:
|
||||||
print("未找到相关记忆。")
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
@@ -255,7 +261,7 @@ def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|||||||
nx.draw(G, pos,
|
nx.draw(G, pos,
|
||||||
with_labels=True,
|
with_labels=True,
|
||||||
node_color=node_colors,
|
node_color=node_colors,
|
||||||
node_size=2000,
|
node_size=200,
|
||||||
font_size=10,
|
font_size=10,
|
||||||
font_family='SimHei',
|
font_family='SimHei',
|
||||||
font_weight='bold')
|
font_weight='bold')
|
||||||
@@ -281,7 +287,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
|||||||
memory_items = H.nodes[node].get('memory_items', [])
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
degree = H.degree(node)
|
degree = H.degree(node)
|
||||||
if memory_count <= 2 or degree <= 2:
|
if memory_count < 5 or degree < 2: # 改为小于2而不是小于等于2
|
||||||
nodes_to_remove.append(node)
|
nodes_to_remove.append(node)
|
||||||
|
|
||||||
H.remove_nodes_from(nodes_to_remove)
|
H.remove_nodes_from(nodes_to_remove)
|
||||||
@@ -294,55 +300,55 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
|||||||
# 保存图到本地
|
# 保存图到本地
|
||||||
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||||
|
|
||||||
# 根据连接条数或记忆数量设置节点颜色
|
# 计算节点大小和颜色
|
||||||
node_colors = []
|
node_colors = []
|
||||||
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
node_sizes = []
|
||||||
|
nodes = list(H.nodes())
|
||||||
|
|
||||||
if color_by_memory:
|
# 获取最大记忆数和最大度数用于归一化
|
||||||
# 计算每个节点的记忆数量
|
max_memories = 1
|
||||||
memory_counts = []
|
max_degree = 1
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
memory_items = H.nodes[node].get('memory_items', [])
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
if isinstance(memory_items, list):
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
count = len(memory_items)
|
|
||||||
else:
|
|
||||||
count = 1 if memory_items else 0
|
|
||||||
memory_counts.append(count)
|
|
||||||
max_memories = max(memory_counts) if memory_counts else 1
|
|
||||||
|
|
||||||
for count in memory_counts:
|
|
||||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
|
||||||
if max_memories > 0:
|
|
||||||
intensity = min(1.0, count / max_memories)
|
|
||||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
|
||||||
else:
|
|
||||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
|
||||||
node_colors.append(color)
|
|
||||||
else:
|
|
||||||
# 使用原来的连接数量着色方案
|
|
||||||
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
|
||||||
for node in nodes:
|
|
||||||
degree = H.degree(node)
|
degree = H.degree(node)
|
||||||
if max_degree > 0:
|
max_memories = max(max_memories, memory_count)
|
||||||
|
max_degree = max(max_degree, degree)
|
||||||
|
|
||||||
|
# 计算每个节点的大小和颜色
|
||||||
|
for node in nodes:
|
||||||
|
# 计算节点大小(基于记忆数量)
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
# 使用指数函数使变化更明显
|
||||||
|
ratio = memory_count / max_memories
|
||||||
|
size = 500 + 5000 * (ratio ** 2) # 使用平方函数使差异更明显
|
||||||
|
node_sizes.append(size)
|
||||||
|
|
||||||
|
# 计算节点颜色(基于连接数)
|
||||||
|
degree = H.degree(node)
|
||||||
|
# 红色分量随着度数增加而增加
|
||||||
red = min(1.0, degree / max_degree)
|
red = min(1.0, degree / max_degree)
|
||||||
|
# 蓝色分量随着度数减少而增加
|
||||||
blue = 1.0 - red
|
blue = 1.0 - red
|
||||||
color = (red, 0, blue)
|
color = (red, 0, blue)
|
||||||
else:
|
|
||||||
color = (0, 0, 1)
|
|
||||||
node_colors.append(color)
|
node_colors.append(color)
|
||||||
|
|
||||||
# 绘制图形
|
# 绘制图形
|
||||||
plt.figure(figsize=(12, 8))
|
plt.figure(figsize=(12, 8))
|
||||||
pos = nx.spring_layout(H, k=1, iterations=50)
|
pos = nx.spring_layout(H, k=1.5, iterations=50) # 增加k值使节点分布更开
|
||||||
nx.draw(H, pos,
|
nx.draw(H, pos,
|
||||||
with_labels=True,
|
with_labels=True,
|
||||||
node_color=node_colors,
|
node_color=node_colors,
|
||||||
node_size=2000,
|
node_size=node_sizes,
|
||||||
font_size=10,
|
font_size=10,
|
||||||
font_family='SimHei',
|
font_family='SimHei',
|
||||||
font_weight='bold')
|
font_weight='bold',
|
||||||
|
edge_color='gray',
|
||||||
|
width=0.5,
|
||||||
|
alpha=0.7)
|
||||||
|
|
||||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
|
||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
from typing import Tuple, Union
|
|
||||||
import time
|
|
||||||
from nonebot import get_driver
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
from src.plugins.chat.config import BotConfig, global_config
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
class LLMModel:
|
|
||||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
|
||||||
def __init__(self, model_name=global_config.SILICONFLOW_MODEL_V3, **kwargs):
|
|
||||||
self.model_name = model_name
|
|
||||||
self.params = kwargs
|
|
||||||
self.api_key = config.siliconflow_key
|
|
||||||
self.base_url = config.siliconflow_base_url
|
|
||||||
|
|
||||||
if not self.api_key or not self.base_url:
|
|
||||||
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
|
||||||
|
|
||||||
print(f"API URL: {self.base_url}") # 打印 base_url 用于调试
|
|
||||||
|
|
||||||
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
|
||||||
"""根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
"temperature": 0.5,
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
|
|
||||||
max_retries = 3
|
|
||||||
base_wait_time = 15 # 基础等待时间(秒)
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
return f"请求失败: {str(e)}", ""
|
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
|
||||||
@@ -9,15 +9,26 @@ import random
|
|||||||
import time
|
import time
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
|
import math
|
||||||
|
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Memory_graph:
|
class Memory_graph:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
def connect_dot(self, concept1, concept2):
|
def connect_dot(self, concept1, concept2):
|
||||||
self.G.add_edge(concept1, concept2)
|
# 如果边已存在,增加 strength
|
||||||
|
if self.G.has_edge(concept1, concept2):
|
||||||
|
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
|
||||||
|
else:
|
||||||
|
# 如果是新边,初始化 strength 为 1
|
||||||
|
self.G.add_edge(concept1, concept2, strength=1)
|
||||||
|
|
||||||
def add_dot(self, concept, memory):
|
def add_dot(self, concept, memory):
|
||||||
if concept in self.G:
|
if concept in self.G:
|
||||||
@@ -38,8 +49,6 @@ class Memory_graph:
|
|||||||
if concept in self.G:
|
if concept in self.G:
|
||||||
# 从图中获取节点数据
|
# 从图中获取节点数据
|
||||||
node_data = self.G.nodes[concept]
|
node_data = self.G.nodes[concept]
|
||||||
# print(node_data)
|
|
||||||
# 创建新的Memory_dot对象
|
|
||||||
return concept, node_data
|
return concept, node_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -52,7 +61,6 @@ class Memory_graph:
|
|||||||
|
|
||||||
# 获取相邻节点
|
# 获取相邻节点
|
||||||
neighbors = list(self.G.neighbors(topic))
|
neighbors = list(self.G.neighbors(topic))
|
||||||
# print(f"第一层: {topic}")
|
|
||||||
|
|
||||||
# 获取当前节点的记忆项
|
# 获取当前节点的记忆项
|
||||||
node_data = self.get_dot(topic)
|
node_data = self.get_dot(topic)
|
||||||
@@ -69,7 +77,6 @@ class Memory_graph:
|
|||||||
if depth >= 2:
|
if depth >= 2:
|
||||||
# 获取相邻节点的记忆项
|
# 获取相邻节点的记忆项
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
# print(f"第二层: {neighbor}")
|
|
||||||
node_data = self.get_dot(neighbor)
|
node_data = self.get_dot(neighbor)
|
||||||
if node_data:
|
if node_data:
|
||||||
concept, data = node_data
|
concept, data = node_data
|
||||||
@@ -87,87 +94,59 @@ class Memory_graph:
|
|||||||
# 返回所有节点对应的 Memory_dot 对象
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
return [self.get_dot(node) for node in self.G.nodes()]
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
def save_graph_to_db(self):
|
def forget_topic(self, topic):
|
||||||
# 保存节点
|
"""随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点"""
|
||||||
for node in self.G.nodes(data=True):
|
if topic not in self.G:
|
||||||
concept = node[0]
|
return None
|
||||||
memory_items = node[1].get('memory_items', [])
|
|
||||||
|
|
||||||
# 查找是否存在同名节点
|
# 获取话题节点数据
|
||||||
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
|
node_data = self.G.nodes[topic]
|
||||||
if existing_node:
|
|
||||||
# 如果存在,合并memory_items并去重
|
|
||||||
existing_items = existing_node.get('memory_items', [])
|
|
||||||
if not isinstance(existing_items, list):
|
|
||||||
existing_items = [existing_items] if existing_items else []
|
|
||||||
|
|
||||||
# 合并并去重
|
# 如果节点存在memory_items
|
||||||
all_items = list(set(existing_items + memory_items))
|
if 'memory_items' in node_data:
|
||||||
|
memory_items = node_data['memory_items']
|
||||||
|
|
||||||
# 更新节点
|
# 确保memory_items是列表
|
||||||
self.db.db.graph_data.nodes.update_one(
|
|
||||||
{'concept': concept},
|
|
||||||
{'$set': {'memory_items': all_items}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果不存在,创建新节点
|
|
||||||
node_data = {
|
|
||||||
'concept': concept,
|
|
||||||
'memory_items': memory_items
|
|
||||||
}
|
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
|
||||||
|
|
||||||
# 保存边
|
|
||||||
for edge in self.G.edges():
|
|
||||||
source, target = edge
|
|
||||||
|
|
||||||
# 查找是否存在同样的边
|
|
||||||
existing_edge = self.db.db.graph_data.edges.find_one({
|
|
||||||
'source': source,
|
|
||||||
'target': target
|
|
||||||
})
|
|
||||||
|
|
||||||
if existing_edge:
|
|
||||||
# 如果存在,增加num属性
|
|
||||||
num = existing_edge.get('num', 1) + 1
|
|
||||||
self.db.db.graph_data.edges.update_one(
|
|
||||||
{'source': source, 'target': target},
|
|
||||||
{'$set': {'num': num}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果不存在,创建新边
|
|
||||||
edge_data = {
|
|
||||||
'source': source,
|
|
||||||
'target': target,
|
|
||||||
'num': 1
|
|
||||||
}
|
|
||||||
self.db.db.graph_data.edges.insert_one(edge_data)
|
|
||||||
|
|
||||||
def load_graph_from_db(self):
|
|
||||||
# 清空当前图
|
|
||||||
self.G.clear()
|
|
||||||
# 加载节点
|
|
||||||
nodes = self.db.db.graph_data.nodes.find()
|
|
||||||
for node in nodes:
|
|
||||||
memory_items = node.get('memory_items', [])
|
|
||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
|
||||||
# 加载边
|
|
||||||
edges = self.db.db.graph_data.edges.find()
|
|
||||||
for edge in edges:
|
|
||||||
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
|
||||||
|
|
||||||
|
# 如果有记忆项可以删除
|
||||||
|
if memory_items:
|
||||||
|
# 随机选择一个记忆项删除
|
||||||
|
removed_item = random.choice(memory_items)
|
||||||
|
memory_items.remove(removed_item)
|
||||||
|
|
||||||
|
# 更新节点的记忆项
|
||||||
|
if memory_items:
|
||||||
|
self.G.nodes[topic]['memory_items'] = memory_items
|
||||||
|
else:
|
||||||
|
# 如果没有记忆项了,删除整个节点
|
||||||
|
self.G.remove_node(topic)
|
||||||
|
|
||||||
|
return removed_item
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# 海马体
|
# 海马体
|
||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self,memory_graph:Memory_graph):
|
def __init__(self,memory_graph:Memory_graph):
|
||||||
self.memory_graph = memory_graph
|
self.memory_graph = memory_graph
|
||||||
self.llm_model = LLM_request(model = global_config.llm_normal,temperature=0.5)
|
self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
|
||||||
self.llm_model_small = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
|
self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5)
|
||||||
|
|
||||||
|
def calculate_node_hash(self, concept, memory_items):
|
||||||
|
"""计算节点的特征值"""
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
sorted_items = sorted(memory_items)
|
||||||
|
content = f"{concept}:{'|'.join(sorted_items)}"
|
||||||
|
return hash(content)
|
||||||
|
|
||||||
|
def calculate_edge_hash(self, source, target):
|
||||||
|
"""计算边的特征值"""
|
||||||
|
nodes = sorted([source, target])
|
||||||
|
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||||
|
|
||||||
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
@@ -175,82 +154,340 @@ class Hippocampus:
|
|||||||
#短期:1h 中期:4h 长期:24h
|
#短期:1h 中期:4h 长期:24h
|
||||||
for _ in range(time_frequency.get('near')): # 循环10次
|
for _ in range(time_frequency.get('near')): # 循环10次
|
||||||
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
||||||
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
chat_text.append(chat_)
|
chat_text.append(chat_)
|
||||||
for _ in range(time_frequency.get('mid')): # 循环10次
|
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||||
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
||||||
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
chat_text.append(chat_)
|
chat_text.append(chat_)
|
||||||
for _ in range(time_frequency.get('far')): # 循环10次
|
for _ in range(time_frequency.get('far')): # 循环10次
|
||||||
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||||
# print(f"获得 最近 随机时间戳对应的时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(random_time))}")
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
chat_text.append(chat_)
|
chat_text.append(chat_)
|
||||||
return chat_text
|
return [text for text in chat_text if text]
|
||||||
|
|
||||||
async def memory_compress(self, input_text, rate=1):
|
async def memory_compress(self, input_text, compress_rate=0.1):
|
||||||
information_content = calculate_information_content(input_text)
|
print(input_text)
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
#获取topics
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
topic_num = self.calculate_topic_num(input_text, compress_rate)
|
||||||
topic_response = await self.llm_model.generate_response(topic_prompt)
|
topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
|
||||||
# 检查 topic_response 是否为元组
|
# 修改话题处理逻辑
|
||||||
if isinstance(topic_response, tuple):
|
print(f"话题: {topics_response[0]}")
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||||
else:
|
print(f"话题: {topics}")
|
||||||
topics = topic_response.split(",")
|
|
||||||
compressed_memory = set()
|
# 创建所有话题的请求任务
|
||||||
|
tasks = []
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
topic_what_prompt = self.topic_what(input_text, topic)
|
||||||
topic_what_response = await self.llm_model_small.generate_response(topic_what_prompt)
|
# 创建异步任务
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
task = self.llm_model_summary.generate_response_async(topic_what_prompt)
|
||||||
|
tasks.append((topic.strip(), task))
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
compressed_memory = set()
|
||||||
|
for topic, task in tasks:
|
||||||
|
response = await task
|
||||||
|
if response:
|
||||||
|
compressed_memory.add((topic, response[0]))
|
||||||
|
|
||||||
return compressed_memory
|
return compressed_memory
|
||||||
|
|
||||||
async def build_memory(self,chat_size=12):
|
def calculate_topic_num(self,text, compress_rate):
|
||||||
|
"""计算文本的话题数量"""
|
||||||
|
information_content = calculate_information_content(text)
|
||||||
|
topic_by_length = text.count('\n')*compress_rate
|
||||||
|
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
|
||||||
|
topic_num = int((topic_by_length + topic_by_information_content)/2)
|
||||||
|
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
|
||||||
|
return topic_num
|
||||||
|
|
||||||
|
async def operation_build_memory(self,chat_size=20):
|
||||||
# 最近消息获取频率
|
# 最近消息获取频率
|
||||||
time_frequency = {'near':1,'mid':2,'far':2}
|
time_frequency = {'near':2,'mid':4,'far':2}
|
||||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
||||||
# print(f"\033[1;32m[记忆构建]\033[0m 获取记忆样本: {memory_sample}")
|
|
||||||
for i, input_text in enumerate(memory_sample, 1):
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
# 加载进度可视化
|
# 加载进度可视化
|
||||||
|
all_topics = []
|
||||||
progress = (i / len(memory_sample)) * 100
|
progress = (i / len(memory_sample)) * 100
|
||||||
bar_length = 30
|
bar_length = 30
|
||||||
filled_length = int(bar_length * i // len(memory_sample))
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
if input_text:
|
|
||||||
# 生成压缩后记忆
|
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
|
||||||
first_memory = set()
|
compressed_memory = set()
|
||||||
first_memory = await self.memory_compress(input_text, 2.5)
|
compress_rate = 0.1
|
||||||
|
compressed_memory = await self.memory_compress(input_text, compress_rate)
|
||||||
|
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
|
||||||
|
|
||||||
# 将记忆加入到图谱中
|
# 将记忆加入到图谱中
|
||||||
for topic, memory in first_memory:
|
for topic, memory in compressed_memory:
|
||||||
topics = segment_text(topic)
|
print(f"\033[1;32m添加节点\033[0m: {topic}")
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
self.memory_graph.add_dot(topic, memory)
|
||||||
for split_topic in topics:
|
all_topics.append(topic) # 收集所有话题
|
||||||
self.memory_graph.add_dot(split_topic,memory)
|
for i in range(len(all_topics)):
|
||||||
for split_topic in topics:
|
for j in range(i + 1, len(all_topics)):
|
||||||
for other_split_topic in topics:
|
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]} 和 {all_topics[j]}")
|
||||||
if split_topic != other_split_topic:
|
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
|
||||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
|
||||||
|
def sync_memory_to_db(self):
|
||||||
|
"""检查并同步内存中的图结构与数据库"""
|
||||||
|
# 获取数据库中所有节点和内存中所有节点
|
||||||
|
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||||
|
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||||
|
|
||||||
|
# 转换数据库节点为字典格式,方便查找
|
||||||
|
db_nodes_dict = {node['concept']: node for node in db_nodes}
|
||||||
|
|
||||||
|
# 检查并更新节点
|
||||||
|
for concept, data in memory_nodes:
|
||||||
|
memory_items = data.get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 计算内存中节点的特征值
|
||||||
|
memory_hash = self.calculate_node_hash(concept, memory_items)
|
||||||
|
|
||||||
|
if concept not in db_nodes_dict:
|
||||||
|
# 数据库中缺少的节点,添加
|
||||||
|
node_data = {
|
||||||
|
'concept': concept,
|
||||||
|
'memory_items': memory_items,
|
||||||
|
'hash': memory_hash
|
||||||
|
}
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
else:
|
else:
|
||||||
print(f"空消息 跳过")
|
# 获取数据库中节点的特征值
|
||||||
self.memory_graph.save_graph_to_db()
|
db_node = db_nodes_dict[concept]
|
||||||
|
db_hash = db_node.get('hash', None)
|
||||||
|
|
||||||
|
# 如果特征值不同,则更新节点
|
||||||
|
if db_hash != memory_hash:
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||||
|
{'concept': concept},
|
||||||
|
{'$set': {
|
||||||
|
'memory_items': memory_items,
|
||||||
|
'hash': memory_hash
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查并删除数据库中多余的节点
|
||||||
|
memory_concepts = set(node[0] for node in memory_nodes)
|
||||||
|
for db_node in db_nodes:
|
||||||
|
if db_node['concept'] not in memory_concepts:
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||||
|
|
||||||
|
# 处理边的信息
|
||||||
|
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||||
|
memory_edges = list(self.memory_graph.G.edges())
|
||||||
|
|
||||||
|
# 创建边的哈希值字典
|
||||||
|
db_edge_dict = {}
|
||||||
|
for edge in db_edges:
|
||||||
|
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
|
||||||
|
db_edge_dict[(edge['source'], edge['target'])] = {
|
||||||
|
'hash': edge_hash,
|
||||||
|
'strength': edge.get('strength', 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查并更新边
|
||||||
|
for source, target in memory_edges:
|
||||||
|
edge_hash = self.calculate_edge_hash(source, target)
|
||||||
|
edge_key = (source, target)
|
||||||
|
strength = self.memory_graph.G[source][target].get('strength', 1)
|
||||||
|
|
||||||
|
if edge_key not in db_edge_dict:
|
||||||
|
# 添加新边
|
||||||
|
edge_data = {
|
||||||
|
'source': source,
|
||||||
|
'target': target,
|
||||||
|
'strength': strength,
|
||||||
|
'hash': edge_hash
|
||||||
|
}
|
||||||
|
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
else:
|
||||||
|
# 检查边的特征值是否变化
|
||||||
|
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||||
|
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||||
|
{'source': source, 'target': target},
|
||||||
|
{'$set': {
|
||||||
|
'hash': edge_hash,
|
||||||
|
'strength': strength
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除多余的边
|
||||||
|
memory_edge_set = set(memory_edges)
|
||||||
|
for edge_key in db_edge_dict:
|
||||||
|
if edge_key not in memory_edge_set:
|
||||||
|
source, target = edge_key
|
||||||
|
self.memory_graph.db.db.graph_data.edges.delete_one({
|
||||||
|
'source': source,
|
||||||
|
'target': target
|
||||||
|
})
|
||||||
|
|
||||||
|
def sync_memory_from_db(self):
|
||||||
|
"""从数据库同步数据到内存中的图结构"""
|
||||||
|
# 清空当前图
|
||||||
|
self.memory_graph.G.clear()
|
||||||
|
|
||||||
|
# 从数据库加载所有节点
|
||||||
|
nodes = self.memory_graph.db.db.graph_data.nodes.find()
|
||||||
|
for node in nodes:
|
||||||
|
concept = node['concept']
|
||||||
|
memory_items = node.get('memory_items', [])
|
||||||
|
# 确保memory_items是列表
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
# 添加节点到图中
|
||||||
|
self.memory_graph.G.add_node(concept, memory_items=memory_items)
|
||||||
|
|
||||||
|
# 从数据库加载所有边
|
||||||
|
edges = self.memory_graph.db.db.graph_data.edges.find()
|
||||||
|
for edge in edges:
|
||||||
|
source = edge['source']
|
||||||
|
target = edge['target']
|
||||||
|
strength = edge.get('strength', 1) # 获取 strength,默认为 1
|
||||||
|
# 只有当源节点和目标节点都存在时才添加边
|
||||||
|
if source in self.memory_graph.G and target in self.memory_graph.G:
|
||||||
|
self.memory_graph.G.add_edge(source, target, strength=strength)
|
||||||
|
|
||||||
|
async def operation_forget_topic(self, percentage=0.1):
|
||||||
|
"""随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
|
||||||
|
# 获取所有节点
|
||||||
|
all_nodes = list(self.memory_graph.G.nodes())
|
||||||
|
# 计算要检查的节点数量
|
||||||
|
check_count = max(1, int(len(all_nodes) * percentage))
|
||||||
|
# 随机选择节点
|
||||||
|
nodes_to_check = random.sample(all_nodes, check_count)
|
||||||
|
|
||||||
|
forgotten_nodes = []
|
||||||
|
for node in nodes_to_check:
|
||||||
|
# 获取节点的连接数
|
||||||
|
connections = self.memory_graph.G.degree(node)
|
||||||
|
|
||||||
|
# 获取节点的内容条数
|
||||||
|
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
content_count = len(memory_items)
|
||||||
|
|
||||||
|
# 检查连接强度
|
||||||
|
weak_connections = True
|
||||||
|
if connections > 1: # 只有当连接数大于1时才检查强度
|
||||||
|
for neighbor in self.memory_graph.G.neighbors(node):
|
||||||
|
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
|
||||||
|
if strength > 2:
|
||||||
|
weak_connections = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果满足遗忘条件
|
||||||
|
if (connections <= 1 and weak_connections) or content_count <= 2:
|
||||||
|
removed_item = self.memory_graph.forget_topic(node)
|
||||||
|
if removed_item:
|
||||||
|
forgotten_nodes.append((node, removed_item))
|
||||||
|
print(f"遗忘节点 {node} 的记忆: {removed_item}")
|
||||||
|
|
||||||
|
# 同步到数据库
|
||||||
|
if forgotten_nodes:
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
|
||||||
|
else:
|
||||||
|
print("本次检查没有节点满足遗忘条件")
|
||||||
|
|
||||||
|
async def merge_memory(self, topic):
|
||||||
|
"""
|
||||||
|
对指定话题的记忆进行合并压缩
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 要合并的话题节点
|
||||||
|
"""
|
||||||
|
# 获取节点的记忆项
|
||||||
|
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 如果记忆项不足,直接返回
|
||||||
|
if len(memory_items) < 10:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 随机选择10条记忆
|
||||||
|
selected_memories = random.sample(memory_items, 10)
|
||||||
|
|
||||||
|
# 拼接成文本
|
||||||
|
merged_text = "\n".join(selected_memories)
|
||||||
|
print(f"\n[合并记忆] 话题: {topic}")
|
||||||
|
print(f"选择的记忆:\n{merged_text}")
|
||||||
|
|
||||||
|
# 使用memory_compress生成新的压缩记忆
|
||||||
|
compressed_memories = await self.memory_compress(merged_text, 0.1)
|
||||||
|
|
||||||
|
# 从原记忆列表中移除被选中的记忆
|
||||||
|
for memory in selected_memories:
|
||||||
|
memory_items.remove(memory)
|
||||||
|
|
||||||
|
# 添加新的压缩记忆
|
||||||
|
for _, compressed_memory in compressed_memories:
|
||||||
|
memory_items.append(compressed_memory)
|
||||||
|
print(f"添加压缩记忆: {compressed_memory}")
|
||||||
|
|
||||||
|
# 更新节点的记忆项
|
||||||
|
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
||||||
|
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
|
||||||
|
|
||||||
|
async def operation_merge_memory(self, percentage=0.1):
|
||||||
|
"""
|
||||||
|
随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
|
||||||
|
|
||||||
|
Args:
|
||||||
|
percentage: 要检查的节点比例,默认为0.1(10%)
|
||||||
|
"""
|
||||||
|
# 获取所有节点
|
||||||
|
all_nodes = list(self.memory_graph.G.nodes())
|
||||||
|
# 计算要检查的节点数量
|
||||||
|
check_count = max(1, int(len(all_nodes) * percentage))
|
||||||
|
# 随机选择节点
|
||||||
|
nodes_to_check = random.sample(all_nodes, check_count)
|
||||||
|
|
||||||
|
merged_nodes = []
|
||||||
|
for node in nodes_to_check:
|
||||||
|
# 获取节点的内容条数
|
||||||
|
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
content_count = len(memory_items)
|
||||||
|
|
||||||
|
# 如果内容数量超过100,进行合并
|
||||||
|
if content_count > 100:
|
||||||
|
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
|
||||||
|
await self.merge_memory(node)
|
||||||
|
merged_nodes.append(node)
|
||||||
|
|
||||||
|
# 同步到数据库
|
||||||
|
if merged_nodes:
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
|
||||||
|
else:
|
||||||
|
print("\n本次检查没有需要合并的节点")
|
||||||
|
|
||||||
|
def find_topic_llm(self,text, topic_num):
|
||||||
|
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def topic_what(self,text, topic):
|
||||||
|
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
return seg_text
|
return seg_text
|
||||||
|
|
||||||
def find_topic(text, topic_num):
|
|
||||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def topic_what(text, topic):
|
|
||||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
@@ -259,19 +496,19 @@ config = driver.config
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= config.MONGODB_PORT,
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
#创建记忆图
|
#创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
#加载数据库中存储的记忆图
|
|
||||||
memory_graph.load_graph_from_db()
|
|
||||||
#创建海马体
|
#创建海马体
|
||||||
hippocampus = Hippocampus(memory_graph)
|
hippocampus = Hippocampus(memory_graph)
|
||||||
|
#从数据库加载记忆图
|
||||||
|
hippocampus.sync_memory_from_db()
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
@@ -1,459 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import sys
|
|
||||||
import jieba
|
|
||||||
import networkx as nx
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import math
|
|
||||||
from collections import Counter
|
|
||||||
import datetime
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
# from chat.config import global_config
|
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
|
||||||
from src.plugins.memory_system.llm_module import LLMModel
|
|
||||||
|
|
||||||
def calculate_information_content(text):
|
|
||||||
"""计算文本的信息量(熵)"""
|
|
||||||
# 统计字符频率
|
|
||||||
char_count = Counter(text)
|
|
||||||
total_chars = len(text)
|
|
||||||
|
|
||||||
# 计算熵
|
|
||||||
entropy = 0
|
|
||||||
for count in char_count.values():
|
|
||||||
probability = count / total_chars
|
|
||||||
entropy -= probability * math.log2(probability)
|
|
||||||
|
|
||||||
return entropy
|
|
||||||
|
|
||||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
|
||||||
"""从数据库中获取最接近指定时间戳的聊天记录"""
|
|
||||||
chat_text = ''
|
|
||||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
|
||||||
|
|
||||||
if closest_record:
|
|
||||||
closest_time = closest_record['time']
|
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
|
||||||
chat_record = list(db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
|
||||||
for record in chat_record:
|
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n'
|
|
||||||
return chat_text
|
|
||||||
|
|
||||||
return ''
|
|
||||||
|
|
||||||
class Memory_graph:
|
|
||||||
def __init__(self):
|
|
||||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
|
||||||
self.db = Database.get_instance()
|
|
||||||
|
|
||||||
def connect_dot(self, concept1, concept2):
|
|
||||||
self.G.add_edge(concept1, concept2)
|
|
||||||
|
|
||||||
def add_dot(self, concept, memory):
|
|
||||||
if concept in self.G:
|
|
||||||
# 如果节点已存在,将新记忆添加到现有列表中
|
|
||||||
if 'memory_items' in self.G.nodes[concept]:
|
|
||||||
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
|
||||||
# 如果当前不是列表,将其转换为列表
|
|
||||||
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
|
||||||
self.G.nodes[concept]['memory_items'].append(memory)
|
|
||||||
else:
|
|
||||||
self.G.nodes[concept]['memory_items'] = [memory]
|
|
||||||
else:
|
|
||||||
# 如果是新节点,创建新的记忆列表
|
|
||||||
self.G.add_node(concept, memory_items=[memory])
|
|
||||||
|
|
||||||
def get_dot(self, concept):
|
|
||||||
# 检查节点是否存在于图中
|
|
||||||
if concept in self.G:
|
|
||||||
# 从图中获取节点数据
|
|
||||||
node_data = self.G.nodes[concept]
|
|
||||||
# print(node_data)
|
|
||||||
# 创建新的Memory_dot对象
|
|
||||||
return concept,node_data
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_related_item(self, topic, depth=1):
|
|
||||||
if topic not in self.G:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
first_layer_items = []
|
|
||||||
second_layer_items = []
|
|
||||||
|
|
||||||
# 获取相邻节点
|
|
||||||
neighbors = list(self.G.neighbors(topic))
|
|
||||||
# print(f"第一层: {topic}")
|
|
||||||
|
|
||||||
# 获取当前节点的记忆项
|
|
||||||
node_data = self.get_dot(topic)
|
|
||||||
if node_data:
|
|
||||||
concept, data = node_data
|
|
||||||
if 'memory_items' in data:
|
|
||||||
memory_items = data['memory_items']
|
|
||||||
if isinstance(memory_items, list):
|
|
||||||
first_layer_items.extend(memory_items)
|
|
||||||
else:
|
|
||||||
first_layer_items.append(memory_items)
|
|
||||||
|
|
||||||
# 只在depth=2时获取第二层记忆
|
|
||||||
if depth >= 2:
|
|
||||||
# 获取相邻节点的记忆项
|
|
||||||
for neighbor in neighbors:
|
|
||||||
# print(f"第二层: {neighbor}")
|
|
||||||
node_data = self.get_dot(neighbor)
|
|
||||||
if node_data:
|
|
||||||
concept, data = node_data
|
|
||||||
if 'memory_items' in data:
|
|
||||||
memory_items = data['memory_items']
|
|
||||||
if isinstance(memory_items, list):
|
|
||||||
second_layer_items.extend(memory_items)
|
|
||||||
else:
|
|
||||||
second_layer_items.append(memory_items)
|
|
||||||
|
|
||||||
return first_layer_items, second_layer_items
|
|
||||||
|
|
||||||
def store_memory(self):
|
|
||||||
for node in self.G.nodes():
|
|
||||||
dot_data = {
|
|
||||||
"concept": node
|
|
||||||
}
|
|
||||||
self.db.db.store_memory_dots.insert_one(dot_data)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dots(self):
|
|
||||||
# 返回所有节点对应的 Memory_dot 对象
|
|
||||||
return [self.get_dot(node) for node in self.G.nodes()]
|
|
||||||
|
|
||||||
|
|
||||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
|
||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
|
||||||
chat_text = ''
|
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
|
||||||
|
|
||||||
# print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
|
||||||
|
|
||||||
if closest_record:
|
|
||||||
closest_time = closest_record['time']
|
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
|
||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
|
||||||
for record in chat_record:
|
|
||||||
if record:
|
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
|
||||||
chat_text += f'[{time_str}] {record["user_nickname"] or "用户" + str(record["user_id"])}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
|
||||||
return chat_text
|
|
||||||
|
|
||||||
return [] # 如果没有找到记录,返回空列表
|
|
||||||
|
|
||||||
def save_graph_to_db(self):
|
|
||||||
# 保存节点
|
|
||||||
for node in self.G.nodes(data=True):
|
|
||||||
concept = node[0]
|
|
||||||
memory_items = node[1].get('memory_items', [])
|
|
||||||
|
|
||||||
# 查找是否存在同名节点
|
|
||||||
existing_node = self.db.db.graph_data.nodes.find_one({'concept': concept})
|
|
||||||
if existing_node:
|
|
||||||
# 如果存在,合并memory_items并去重
|
|
||||||
existing_items = existing_node.get('memory_items', [])
|
|
||||||
if not isinstance(existing_items, list):
|
|
||||||
existing_items = [existing_items] if existing_items else []
|
|
||||||
|
|
||||||
# 合并并去重
|
|
||||||
all_items = list(set(existing_items + memory_items))
|
|
||||||
|
|
||||||
# 更新节点
|
|
||||||
self.db.db.graph_data.nodes.update_one(
|
|
||||||
{'concept': concept},
|
|
||||||
{'$set': {'memory_items': all_items}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果不存在,创建新节点
|
|
||||||
node_data = {
|
|
||||||
'concept': concept,
|
|
||||||
'memory_items': memory_items
|
|
||||||
}
|
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
|
||||||
|
|
||||||
# 保存边
|
|
||||||
for edge in self.G.edges():
|
|
||||||
source, target = edge
|
|
||||||
|
|
||||||
# 查找是否存在同样的边
|
|
||||||
existing_edge = self.db.db.graph_data.edges.find_one({
|
|
||||||
'source': source,
|
|
||||||
'target': target
|
|
||||||
})
|
|
||||||
|
|
||||||
if existing_edge:
|
|
||||||
# 如果存在,增加num属性
|
|
||||||
num = existing_edge.get('num', 1) + 1
|
|
||||||
self.db.db.graph_data.edges.update_one(
|
|
||||||
{'source': source, 'target': target},
|
|
||||||
{'$set': {'num': num}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果不存在,创建新边
|
|
||||||
edge_data = {
|
|
||||||
'source': source,
|
|
||||||
'target': target,
|
|
||||||
'num': 1
|
|
||||||
}
|
|
||||||
self.db.db.graph_data.edges.insert_one(edge_data)
|
|
||||||
|
|
||||||
def load_graph_from_db(self):
|
|
||||||
# 清空当前图
|
|
||||||
self.G.clear()
|
|
||||||
# 加载节点
|
|
||||||
nodes = self.db.db.graph_data.nodes.find()
|
|
||||||
for node in nodes:
|
|
||||||
memory_items = node.get('memory_items', [])
|
|
||||||
if not isinstance(memory_items, list):
|
|
||||||
memory_items = [memory_items] if memory_items else []
|
|
||||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
|
||||||
# 加载边
|
|
||||||
edges = self.db.db.graph_data.edges.find()
|
|
||||||
for edge in edges:
|
|
||||||
self.G.add_edge(edge['source'], edge['target'], num=edge.get('num', 1))
|
|
||||||
|
|
||||||
# 海马体
|
|
||||||
class Hippocampus:
|
|
||||||
def __init__(self,memory_graph:Memory_graph):
|
|
||||||
self.memory_graph = memory_graph
|
|
||||||
self.llm_model = LLMModel()
|
|
||||||
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
|
||||||
|
|
||||||
def get_memory_sample(self,chat_size=20,time_frequency:dict={'near':2,'mid':4,'far':3}):
|
|
||||||
current_timestamp = datetime.datetime.now().timestamp()
|
|
||||||
chat_text = []
|
|
||||||
#短期:1h 中期:4h 长期:24h
|
|
||||||
for _ in range(time_frequency.get('near')): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(1, 3600) # 随机时间
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
|
||||||
chat_text.append(chat_)
|
|
||||||
for _ in range(time_frequency.get('mid')): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(3600, 3600*4) # 随机时间
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
|
||||||
chat_text.append(chat_)
|
|
||||||
for _ in range(time_frequency.get('far')): # 循环10次
|
|
||||||
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
|
||||||
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
|
||||||
chat_text.append(chat_)
|
|
||||||
return chat_text
|
|
||||||
|
|
||||||
def build_memory(self,chat_size=12):
|
|
||||||
#最近消息获取频率
|
|
||||||
time_frequency = {'near':1,'mid':2,'far':2}
|
|
||||||
memory_sample = self.get_memory_sample(chat_size,time_frequency)
|
|
||||||
|
|
||||||
#加载进度可视化
|
|
||||||
for i, input_text in enumerate(memory_sample, 1):
|
|
||||||
progress = (i / len(memory_sample)) * 100
|
|
||||||
bar_length = 30
|
|
||||||
filled_length = int(bar_length * i // len(memory_sample))
|
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
|
||||||
# print(f"第{i}条消息: {input_text}")
|
|
||||||
if input_text:
|
|
||||||
# 生成压缩后记忆
|
|
||||||
first_memory = set()
|
|
||||||
first_memory = self.memory_compress(input_text, 2.5)
|
|
||||||
#将记忆加入到图谱中
|
|
||||||
for topic, memory in first_memory:
|
|
||||||
topics = segment_text(topic)
|
|
||||||
print(f"\033[1;34m话题\033[0m: {topic},节点: {topics}, 记忆: {memory}")
|
|
||||||
for split_topic in topics:
|
|
||||||
self.memory_graph.add_dot(split_topic,memory)
|
|
||||||
for split_topic in topics:
|
|
||||||
for other_split_topic in topics:
|
|
||||||
if split_topic != other_split_topic:
|
|
||||||
self.memory_graph.connect_dot(split_topic, other_split_topic)
|
|
||||||
else:
|
|
||||||
print(f"空消息 跳过")
|
|
||||||
|
|
||||||
self.memory_graph.save_graph_to_db()
|
|
||||||
|
|
||||||
def memory_compress(self, input_text, rate=1):
|
|
||||||
information_content = calculate_information_content(input_text)
|
|
||||||
print(f"文本的信息量(熵): {information_content:.4f} bits")
|
|
||||||
topic_num = max(1, min(5, int(information_content * rate / 4)))
|
|
||||||
topic_prompt = find_topic(input_text, topic_num)
|
|
||||||
topic_response = self.llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
compressed_memory = set()
|
|
||||||
for topic in topics:
|
|
||||||
topic_what_prompt = topic_what(input_text,topic)
|
|
||||||
topic_what_response = self.llm_model_small.generate_response(topic_what_prompt)
|
|
||||||
compressed_memory.add((topic.strip(), topic_what_response[0])) # 将话题和记忆作为元组存储
|
|
||||||
return compressed_memory
|
|
||||||
|
|
||||||
def segment_text(text):
|
|
||||||
seg_text = list(jieba.cut(text))
|
|
||||||
return seg_text
|
|
||||||
|
|
||||||
def find_topic(text, topic_num):
|
|
||||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def topic_what(text, topic):
|
|
||||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def visualize_graph(memory_graph: Memory_graph, color_by_memory: bool = False):
|
|
||||||
# 设置中文字体
|
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
|
||||||
|
|
||||||
G = memory_graph.G
|
|
||||||
|
|
||||||
# 创建一个新图用于可视化
|
|
||||||
H = G.copy()
|
|
||||||
|
|
||||||
# 移除只有一条记忆的节点和连接数少于3的节点
|
|
||||||
nodes_to_remove = []
|
|
||||||
for node in H.nodes():
|
|
||||||
memory_items = H.nodes[node].get('memory_items', [])
|
|
||||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
|
||||||
degree = H.degree(node)
|
|
||||||
if memory_count <= 1 or degree <= 2:
|
|
||||||
nodes_to_remove.append(node)
|
|
||||||
|
|
||||||
H.remove_nodes_from(nodes_to_remove)
|
|
||||||
|
|
||||||
# 如果过滤后没有节点,则返回
|
|
||||||
if len(H.nodes()) == 0:
|
|
||||||
print("过滤后没有符合条件的节点可显示")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 保存图到本地
|
|
||||||
nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
|
||||||
|
|
||||||
# 根据连接条数或记忆数量设置节点颜色
|
|
||||||
node_colors = []
|
|
||||||
nodes = list(H.nodes()) # 获取图中实际的节点列表
|
|
||||||
|
|
||||||
if color_by_memory:
|
|
||||||
# 计算每个节点的记忆数量
|
|
||||||
memory_counts = []
|
|
||||||
for node in nodes:
|
|
||||||
memory_items = H.nodes[node].get('memory_items', [])
|
|
||||||
if isinstance(memory_items, list):
|
|
||||||
count = len(memory_items)
|
|
||||||
else:
|
|
||||||
count = 1 if memory_items else 0
|
|
||||||
memory_counts.append(count)
|
|
||||||
max_memories = max(memory_counts) if memory_counts else 1
|
|
||||||
|
|
||||||
for count in memory_counts:
|
|
||||||
# 使用不同的颜色方案:红色表示记忆多,蓝色表示记忆少
|
|
||||||
if max_memories > 0:
|
|
||||||
intensity = min(1.0, count / max_memories)
|
|
||||||
color = (intensity, 0, 1.0 - intensity) # 从蓝色渐变到红色
|
|
||||||
else:
|
|
||||||
color = (0, 0, 1) # 如果没有记忆,则为蓝色
|
|
||||||
node_colors.append(color)
|
|
||||||
else:
|
|
||||||
# 使用原来的连接数量着色方案
|
|
||||||
max_degree = max(H.degree(), key=lambda x: x[1])[1] if H.degree() else 1
|
|
||||||
for node in nodes:
|
|
||||||
degree = H.degree(node)
|
|
||||||
if max_degree > 0:
|
|
||||||
red = min(1.0, degree / max_degree)
|
|
||||||
blue = 1.0 - red
|
|
||||||
color = (red, 0, blue)
|
|
||||||
else:
|
|
||||||
color = (0, 0, 1)
|
|
||||||
node_colors.append(color)
|
|
||||||
|
|
||||||
# 绘制图形
|
|
||||||
plt.figure(figsize=(12, 8))
|
|
||||||
pos = nx.spring_layout(H, k=1, iterations=50)
|
|
||||||
nx.draw(H, pos,
|
|
||||||
with_labels=True,
|
|
||||||
node_color=node_colors,
|
|
||||||
node_size=2000,
|
|
||||||
font_size=10,
|
|
||||||
font_family='SimHei',
|
|
||||||
font_weight='bold')
|
|
||||||
|
|
||||||
title = '记忆图谱可视化 - ' + ('按记忆数量着色' if color_by_memory else '按连接数量着色')
|
|
||||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 初始化数据库
|
|
||||||
Database.initialize(
|
|
||||||
host= os.getenv("MONGODB_HOST"),
|
|
||||||
port= int(os.getenv("MONGODB_PORT")),
|
|
||||||
db_name= os.getenv("DATABASE_NAME"),
|
|
||||||
username= os.getenv("MONGODB_USERNAME"),
|
|
||||||
password= os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
|
||||||
)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# 创建记忆图
|
|
||||||
memory_graph = Memory_graph()
|
|
||||||
# 加载数据库中存储的记忆图
|
|
||||||
memory_graph.load_graph_from_db()
|
|
||||||
# 创建海马体
|
|
||||||
hippocampus = Hippocampus(memory_graph)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
|
||||||
|
|
||||||
# 构建记忆
|
|
||||||
hippocampus.build_memory(chat_size=25)
|
|
||||||
|
|
||||||
# 展示两种不同的可视化方式
|
|
||||||
print("\n按连接数量着色的图谱:")
|
|
||||||
visualize_graph(memory_graph, color_by_memory=False)
|
|
||||||
|
|
||||||
print("\n按记忆数量着色的图谱:")
|
|
||||||
visualize_graph(memory_graph, color_by_memory=True)
|
|
||||||
|
|
||||||
# 交互式查询
|
|
||||||
while True:
|
|
||||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
items_list = memory_graph.get_related_item(query)
|
|
||||||
if items_list:
|
|
||||||
for memory_item in items_list:
|
|
||||||
print(memory_item)
|
|
||||||
else:
|
|
||||||
print("未找到相关记忆。")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
query = input("请输入问题:")
|
|
||||||
|
|
||||||
if query.lower() == '退出':
|
|
||||||
break
|
|
||||||
|
|
||||||
topic_prompt = find_topic(query, 3)
|
|
||||||
topic_response = hippocampus.llm_model.generate_response(topic_prompt)
|
|
||||||
# 检查 topic_response 是否为元组
|
|
||||||
if isinstance(topic_response, tuple):
|
|
||||||
topics = topic_response[0].split(",") # 假设第一个元素是我们需要的字符串
|
|
||||||
else:
|
|
||||||
topics = topic_response.split(",")
|
|
||||||
print(topics)
|
|
||||||
|
|
||||||
for keyword in topics:
|
|
||||||
items_list = memory_graph.get_related_item(keyword)
|
|
||||||
if items_list:
|
|
||||||
print(items_list)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
|
||||||
786
src/plugins/memory_system/memory_manual_build.py
Normal file
786
src/plugins/memory_system/memory_manual_build.py
Normal file
@@ -0,0 +1,786 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import sys
|
||||||
|
import jieba
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import pymongo
|
||||||
|
from loguru import logger
|
||||||
|
from pathlib import Path
|
||||||
|
from snownlp import SnowNLP
|
||||||
|
# from chat.config import global_config
|
||||||
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
|
from src.common.database import Database
|
||||||
|
from src.plugins.memory_system.offline_llm import LLMModel
|
||||||
|
|
||||||
|
# 获取当前文件的目录
|
||||||
|
current_dir = Path(__file__).resolve().parent
|
||||||
|
# 获取项目根目录(上三层目录)
|
||||||
|
project_root = current_dir.parent.parent.parent
|
||||||
|
# env.dev文件路径
|
||||||
|
env_path = project_root / ".env.dev"
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
if env_path.exists():
|
||||||
|
logger.info(f"从 {env_path} 加载环境变量")
|
||||||
|
load_dotenv(env_path)
|
||||||
|
else:
|
||||||
|
logger.warning(f"未找到环境变量文件: {env_path}")
|
||||||
|
logger.info("将使用默认配置")
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
_instance = None
|
||||||
|
db = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not Database.db:
|
||||||
|
Database.initialize(
|
||||||
|
host=os.getenv("MONGODB_HOST"),
|
||||||
|
port=int(os.getenv("MONGODB_PORT")),
|
||||||
|
db_name=os.getenv("DATABASE_NAME"),
|
||||||
|
username=os.getenv("MONGODB_USERNAME"),
|
||||||
|
password=os.getenv("MONGODB_PASSWORD"),
|
||||||
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
|
||||||
|
try:
|
||||||
|
if username and password:
|
||||||
|
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
|
||||||
|
else:
|
||||||
|
uri = f"mongodb://{host}:{port}"
|
||||||
|
|
||||||
|
client = pymongo.MongoClient(uri)
|
||||||
|
cls.db = client[db_name]
|
||||||
|
# 测试连接
|
||||||
|
client.server_info()
|
||||||
|
logger.success("MongoDB连接成功!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化MongoDB失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def calculate_information_content(text):
|
||||||
|
"""计算文本的信息量(熵)"""
|
||||||
|
char_count = Counter(text)
|
||||||
|
total_chars = len(text)
|
||||||
|
|
||||||
|
entropy = 0
|
||||||
|
for count in char_count.values():
|
||||||
|
probability = count / total_chars
|
||||||
|
entropy -= probability * math.log2(probability)
|
||||||
|
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
|
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
|
||||||
|
chat_text = ''
|
||||||
|
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||||
|
|
||||||
|
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||||
|
closest_time = closest_record['time']
|
||||||
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
|
chat_records = list(db.db.messages.find(
|
||||||
|
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||||
|
).sort('time', 1).limit(length))
|
||||||
|
|
||||||
|
# 更新每条消息的memorized属性
|
||||||
|
for record in chat_records:
|
||||||
|
# 检查当前记录的memorized值
|
||||||
|
current_memorized = record.get('memorized', 0)
|
||||||
|
if current_memorized > 3:
|
||||||
|
print(f"消息已读取3次,跳过")
|
||||||
|
return ''
|
||||||
|
|
||||||
|
# 更新memorized值
|
||||||
|
db.db.messages.update_one(
|
||||||
|
{"_id": record["_id"]},
|
||||||
|
{"$set": {"memorized": current_memorized + 1}}
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_text += record["detailed_plain_text"]
|
||||||
|
|
||||||
|
return chat_text
|
||||||
|
print(f"消息已读取3次,跳过")
|
||||||
|
return ''
|
||||||
|
|
||||||
|
class Memory_graph:
|
||||||
|
def __init__(self):
|
||||||
|
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
|
def connect_dot(self, concept1, concept2):
|
||||||
|
# 如果边已存在,增加 strength
|
||||||
|
if self.G.has_edge(concept1, concept2):
|
||||||
|
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
|
||||||
|
else:
|
||||||
|
# 如果是新边,初始化 strength 为 1
|
||||||
|
self.G.add_edge(concept1, concept2, strength=1)
|
||||||
|
|
||||||
|
def add_dot(self, concept, memory):
|
||||||
|
if concept in self.G:
|
||||||
|
# 如果节点已存在,将新记忆添加到现有列表中
|
||||||
|
if 'memory_items' in self.G.nodes[concept]:
|
||||||
|
if not isinstance(self.G.nodes[concept]['memory_items'], list):
|
||||||
|
# 如果当前不是列表,将其转换为列表
|
||||||
|
self.G.nodes[concept]['memory_items'] = [self.G.nodes[concept]['memory_items']]
|
||||||
|
self.G.nodes[concept]['memory_items'].append(memory)
|
||||||
|
else:
|
||||||
|
self.G.nodes[concept]['memory_items'] = [memory]
|
||||||
|
else:
|
||||||
|
# 如果是新节点,创建新的记忆列表
|
||||||
|
self.G.add_node(concept, memory_items=[memory])
|
||||||
|
|
||||||
|
def get_dot(self, concept):
|
||||||
|
# 检查节点是否存在于图中
|
||||||
|
if concept in self.G:
|
||||||
|
# 从图中获取节点数据
|
||||||
|
node_data = self.G.nodes[concept]
|
||||||
|
return concept, node_data
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_related_item(self, topic, depth=1):
|
||||||
|
if topic not in self.G:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
first_layer_items = []
|
||||||
|
second_layer_items = []
|
||||||
|
|
||||||
|
# 获取相邻节点
|
||||||
|
neighbors = list(self.G.neighbors(topic))
|
||||||
|
|
||||||
|
# 获取当前节点的记忆项
|
||||||
|
node_data = self.get_dot(topic)
|
||||||
|
if node_data:
|
||||||
|
concept, data = node_data
|
||||||
|
if 'memory_items' in data:
|
||||||
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
first_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
first_layer_items.append(memory_items)
|
||||||
|
|
||||||
|
# 只在depth=2时获取第二层记忆
|
||||||
|
if depth >= 2:
|
||||||
|
# 获取相邻节点的记忆项
|
||||||
|
for neighbor in neighbors:
|
||||||
|
node_data = self.get_dot(neighbor)
|
||||||
|
if node_data:
|
||||||
|
concept, data = node_data
|
||||||
|
if 'memory_items' in data:
|
||||||
|
memory_items = data['memory_items']
|
||||||
|
if isinstance(memory_items, list):
|
||||||
|
second_layer_items.extend(memory_items)
|
||||||
|
else:
|
||||||
|
second_layer_items.append(memory_items)
|
||||||
|
|
||||||
|
return first_layer_items, second_layer_items
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dots(self):
|
||||||
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
|
# 海马体
|
||||||
|
class Hippocampus:
|
||||||
|
def __init__(self, memory_graph: Memory_graph):
|
||||||
|
self.memory_graph = memory_graph
|
||||||
|
self.llm_model = LLMModel()
|
||||||
|
self.llm_model_small = LLMModel(model_name="deepseek-ai/DeepSeek-V2.5")
|
||||||
|
self.llm_model_get_topic = LLMModel(model_name="Pro/Qwen/Qwen2.5-7B-Instruct")
|
||||||
|
self.llm_model_summary = LLMModel(model_name="Qwen/Qwen2.5-32B-Instruct")
|
||||||
|
|
||||||
|
def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}):
|
||||||
|
current_timestamp = datetime.datetime.now().timestamp()
|
||||||
|
chat_text = []
|
||||||
|
#短期:1h 中期:4h 长期:24h
|
||||||
|
for _ in range(time_frequency.get('near')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(1, 3600*4) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('mid')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600*4, 3600*24) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
for _ in range(time_frequency.get('far')): # 循环10次
|
||||||
|
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) # 随机时间
|
||||||
|
chat_ = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
|
||||||
|
chat_text.append(chat_)
|
||||||
|
return [chat for chat in chat_text if chat]
|
||||||
|
|
||||||
|
def calculate_topic_num(self,text, compress_rate):
|
||||||
|
"""计算文本的话题数量"""
|
||||||
|
information_content = calculate_information_content(text)
|
||||||
|
topic_by_length = text.count('\n')*compress_rate
|
||||||
|
topic_by_information_content = max(1, min(5, int((information_content-3) * 2)))
|
||||||
|
topic_num = int((topic_by_length + topic_by_information_content)/2)
|
||||||
|
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
|
||||||
|
return topic_num
|
||||||
|
|
||||||
|
async def memory_compress(self, input_text, compress_rate=0.1):
|
||||||
|
print(input_text)
|
||||||
|
|
||||||
|
#获取topics
|
||||||
|
topic_num = self.calculate_topic_num(input_text, compress_rate)
|
||||||
|
topics_response = await self.llm_model_get_topic.generate_response_async(self.find_topic_llm(input_text, topic_num))
|
||||||
|
# 修改话题处理逻辑
|
||||||
|
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||||
|
print(f"话题: {topics}")
|
||||||
|
|
||||||
|
# 创建所有话题的请求任务
|
||||||
|
tasks = []
|
||||||
|
for topic in topics:
|
||||||
|
topic_what_prompt = self.topic_what(input_text, topic)
|
||||||
|
# 创建异步任务
|
||||||
|
task = self.llm_model_small.generate_response_async(topic_what_prompt)
|
||||||
|
tasks.append((topic.strip(), task))
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
compressed_memory = set()
|
||||||
|
for topic, task in tasks:
|
||||||
|
response = await task
|
||||||
|
if response:
|
||||||
|
compressed_memory.add((topic, response[0]))
|
||||||
|
|
||||||
|
return compressed_memory
|
||||||
|
|
||||||
|
async def operation_build_memory(self, chat_size=12):
|
||||||
|
# 最近消息获取频率
|
||||||
|
time_frequency = {'near': 3, 'mid': 8, 'far': 5}
|
||||||
|
memory_sample = self.get_memory_sample(chat_size, time_frequency)
|
||||||
|
|
||||||
|
all_topics = [] # 用于存储所有话题
|
||||||
|
|
||||||
|
for i, input_text in enumerate(memory_sample, 1):
|
||||||
|
# 加载进度可视化
|
||||||
|
all_topics = []
|
||||||
|
progress = (i / len(memory_sample)) * 100
|
||||||
|
bar_length = 30
|
||||||
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
|
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
|
|
||||||
|
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
|
||||||
|
compressed_memory = set()
|
||||||
|
compress_rate = 0.1
|
||||||
|
compressed_memory = await self.memory_compress(input_text, compress_rate)
|
||||||
|
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
|
||||||
|
|
||||||
|
# 将记忆加入到图谱中
|
||||||
|
for topic, memory in compressed_memory:
|
||||||
|
print(f"\033[1;32m添加节点\033[0m: {topic}")
|
||||||
|
self.memory_graph.add_dot(topic, memory)
|
||||||
|
all_topics.append(topic) # 收集所有话题
|
||||||
|
for i in range(len(all_topics)):
|
||||||
|
for j in range(i + 1, len(all_topics)):
|
||||||
|
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]} 和 {all_topics[j]}")
|
||||||
|
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
|
||||||
|
def sync_memory_from_db(self):
|
||||||
|
"""
|
||||||
|
从数据库同步数据到内存中的图结构
|
||||||
|
将清空当前内存中的图,并从数据库重新加载所有节点和边
|
||||||
|
"""
|
||||||
|
# 清空当前图
|
||||||
|
self.memory_graph.G.clear()
|
||||||
|
|
||||||
|
# 从数据库加载所有节点
|
||||||
|
nodes = self.memory_graph.db.db.graph_data.nodes.find()
|
||||||
|
for node in nodes:
|
||||||
|
concept = node['concept']
|
||||||
|
memory_items = node.get('memory_items', [])
|
||||||
|
# 确保memory_items是列表
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
# 添加节点到图中
|
||||||
|
self.memory_graph.G.add_node(concept, memory_items=memory_items)
|
||||||
|
|
||||||
|
# 从数据库加载所有边
|
||||||
|
edges = self.memory_graph.db.db.graph_data.edges.find()
|
||||||
|
for edge in edges:
|
||||||
|
source = edge['source']
|
||||||
|
target = edge['target']
|
||||||
|
strength = edge.get('strength', 1) # 获取 strength,默认为 1
|
||||||
|
# 只有当源节点和目标节点都存在时才添加边
|
||||||
|
if source in self.memory_graph.G and target in self.memory_graph.G:
|
||||||
|
self.memory_graph.G.add_edge(source, target, strength=strength)
|
||||||
|
|
||||||
|
logger.success("从数据库同步记忆图谱完成")
|
||||||
|
|
||||||
|
def calculate_node_hash(self, concept, memory_items):
|
||||||
|
"""
|
||||||
|
计算节点的特征值
|
||||||
|
"""
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
# 将记忆项排序以确保相同内容生成相同的哈希值
|
||||||
|
sorted_items = sorted(memory_items)
|
||||||
|
# 组合概念和记忆项生成特征值
|
||||||
|
content = f"{concept}:{'|'.join(sorted_items)}"
|
||||||
|
return hash(content)
|
||||||
|
|
||||||
|
def calculate_edge_hash(self, source, target):
|
||||||
|
"""
|
||||||
|
计算边的特征值
|
||||||
|
"""
|
||||||
|
# 对源节点和目标节点排序以确保相同的边生成相同的哈希值
|
||||||
|
nodes = sorted([source, target])
|
||||||
|
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||||
|
|
||||||
|
def sync_memory_to_db(self):
|
||||||
|
"""
|
||||||
|
检查并同步内存中的图结构与数据库
|
||||||
|
使用特征值(哈希值)快速判断是否需要更新
|
||||||
|
"""
|
||||||
|
# 获取数据库中所有节点和内存中所有节点
|
||||||
|
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||||
|
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||||
|
|
||||||
|
# 转换数据库节点为字典格式,方便查找
|
||||||
|
db_nodes_dict = {node['concept']: node for node in db_nodes}
|
||||||
|
|
||||||
|
# 检查并更新节点
|
||||||
|
for concept, data in memory_nodes:
|
||||||
|
memory_items = data.get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 计算内存中节点的特征值
|
||||||
|
memory_hash = self.calculate_node_hash(concept, memory_items)
|
||||||
|
|
||||||
|
if concept not in db_nodes_dict:
|
||||||
|
# 数据库中缺少的节点,添加
|
||||||
|
logger.info(f"添加新节点: {concept}")
|
||||||
|
node_data = {
|
||||||
|
'concept': concept,
|
||||||
|
'memory_items': memory_items,
|
||||||
|
'hash': memory_hash
|
||||||
|
}
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
|
||||||
|
else:
|
||||||
|
# 获取数据库中节点的特征值
|
||||||
|
db_node = db_nodes_dict[concept]
|
||||||
|
db_hash = db_node.get('hash', None)
|
||||||
|
|
||||||
|
# 如果特征值不同,则更新节点
|
||||||
|
if db_hash != memory_hash:
|
||||||
|
logger.info(f"更新节点内容: {concept}")
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||||
|
{'concept': concept},
|
||||||
|
{'$set': {
|
||||||
|
'memory_items': memory_items,
|
||||||
|
'hash': memory_hash
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查并删除数据库中多余的节点
|
||||||
|
memory_concepts = set(node[0] for node in memory_nodes)
|
||||||
|
for db_node in db_nodes:
|
||||||
|
if db_node['concept'] not in memory_concepts:
|
||||||
|
logger.info(f"删除多余节点: {db_node['concept']}")
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||||
|
|
||||||
|
# 处理边的信息
|
||||||
|
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||||
|
memory_edges = list(self.memory_graph.G.edges())
|
||||||
|
|
||||||
|
# 创建边的哈希值字典
|
||||||
|
db_edge_dict = {}
|
||||||
|
for edge in db_edges:
|
||||||
|
edge_hash = self.calculate_edge_hash(edge['source'], edge['target'])
|
||||||
|
db_edge_dict[(edge['source'], edge['target'])] = {
|
||||||
|
'hash': edge_hash,
|
||||||
|
'num': edge.get('num', 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查并更新边
|
||||||
|
for source, target in memory_edges:
|
||||||
|
edge_hash = self.calculate_edge_hash(source, target)
|
||||||
|
edge_key = (source, target)
|
||||||
|
|
||||||
|
if edge_key not in db_edge_dict:
|
||||||
|
# 添加新边
|
||||||
|
logger.info(f"添加新边: {source} - {target}")
|
||||||
|
edge_data = {
|
||||||
|
'source': source,
|
||||||
|
'target': target,
|
||||||
|
'num': 1,
|
||||||
|
'hash': edge_hash
|
||||||
|
}
|
||||||
|
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
else:
|
||||||
|
# 检查边的特征值是否变化
|
||||||
|
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||||
|
logger.info(f"更新边: {source} - {target}")
|
||||||
|
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||||
|
{'source': source, 'target': target},
|
||||||
|
{'$set': {'hash': edge_hash}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除多余的边
|
||||||
|
memory_edge_set = set(memory_edges)
|
||||||
|
for edge_key in db_edge_dict:
|
||||||
|
if edge_key not in memory_edge_set:
|
||||||
|
source, target = edge_key
|
||||||
|
logger.info(f"删除多余边: {source} - {target}")
|
||||||
|
self.memory_graph.db.db.graph_data.edges.delete_one({
|
||||||
|
'source': source,
|
||||||
|
'target': target
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.success("完成记忆图谱与数据库的差异同步")
|
||||||
|
|
||||||
|
def find_topic_llm(self,text, topic_num):
|
||||||
|
# prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||||
|
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def topic_what(self,text, topic):
|
||||||
|
# prompt = f'这是一段文字:{text}。我想知道这段文字里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
|
||||||
|
prompt = f'这是一段文字:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def remove_node_from_db(self, topic):
|
||||||
|
"""
|
||||||
|
从数据库中删除指定节点及其相关的边
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 要删除的节点概念
|
||||||
|
"""
|
||||||
|
# 删除节点
|
||||||
|
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
|
||||||
|
# 删除所有涉及该节点的边
|
||||||
|
self.memory_graph.db.db.graph_data.edges.delete_many({
|
||||||
|
'$or': [
|
||||||
|
{'source': topic},
|
||||||
|
{'target': topic}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
def forget_topic(self, topic):
|
||||||
|
"""
|
||||||
|
随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点
|
||||||
|
只在内存中的图上操作,不直接与数据库交互
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 要删除记忆的话题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
removed_item: 被删除的记忆项,如果没有删除任何记忆则返回 None
|
||||||
|
"""
|
||||||
|
if topic not in self.memory_graph.G:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取话题节点数据
|
||||||
|
node_data = self.memory_graph.G.nodes[topic]
|
||||||
|
|
||||||
|
# 如果节点存在memory_items
|
||||||
|
if 'memory_items' in node_data:
|
||||||
|
memory_items = node_data['memory_items']
|
||||||
|
|
||||||
|
# 确保memory_items是列表
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 如果有记忆项可以删除
|
||||||
|
if memory_items:
|
||||||
|
# 随机选择一个记忆项删除
|
||||||
|
removed_item = random.choice(memory_items)
|
||||||
|
memory_items.remove(removed_item)
|
||||||
|
|
||||||
|
# 更新节点的记忆项
|
||||||
|
if memory_items:
|
||||||
|
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
||||||
|
else:
|
||||||
|
# 如果没有记忆项了,删除整个节点
|
||||||
|
self.memory_graph.G.remove_node(topic)
|
||||||
|
|
||||||
|
return removed_item
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def operation_forget_topic(self, percentage=0.1):
|
||||||
|
"""
|
||||||
|
随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘
|
||||||
|
|
||||||
|
Args:
|
||||||
|
percentage: 要检查的节点比例,默认为0.1(10%)
|
||||||
|
"""
|
||||||
|
# 获取所有节点
|
||||||
|
all_nodes = list(self.memory_graph.G.nodes())
|
||||||
|
# 计算要检查的节点数量
|
||||||
|
check_count = max(1, int(len(all_nodes) * percentage))
|
||||||
|
# 随机选择节点
|
||||||
|
nodes_to_check = random.sample(all_nodes, check_count)
|
||||||
|
|
||||||
|
forgotten_nodes = []
|
||||||
|
for node in nodes_to_check:
|
||||||
|
# 获取节点的连接数
|
||||||
|
connections = self.memory_graph.G.degree(node)
|
||||||
|
|
||||||
|
# 获取节点的内容条数
|
||||||
|
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
content_count = len(memory_items)
|
||||||
|
|
||||||
|
# 检查连接强度
|
||||||
|
weak_connections = True
|
||||||
|
if connections > 1: # 只有当连接数大于1时才检查强度
|
||||||
|
for neighbor in self.memory_graph.G.neighbors(node):
|
||||||
|
strength = self.memory_graph.G[node][neighbor].get('strength', 1)
|
||||||
|
if strength > 2:
|
||||||
|
weak_connections = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果满足遗忘条件
|
||||||
|
if (connections <= 1 and weak_connections) or content_count <= 2:
|
||||||
|
removed_item = self.forget_topic(node)
|
||||||
|
if removed_item:
|
||||||
|
forgotten_nodes.append((node, removed_item))
|
||||||
|
logger.info(f"遗忘节点 {node} 的记忆: {removed_item}")
|
||||||
|
|
||||||
|
# 同步到数据库
|
||||||
|
if forgotten_nodes:
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
logger.info(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
|
||||||
|
else:
|
||||||
|
logger.info("本次检查没有节点满足遗忘条件")
|
||||||
|
|
||||||
|
async def merge_memory(self, topic):
|
||||||
|
"""
|
||||||
|
对指定话题的记忆进行合并压缩
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 要合并的话题节点
|
||||||
|
"""
|
||||||
|
# 获取节点的记忆项
|
||||||
|
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
# 如果记忆项不足,直接返回
|
||||||
|
if len(memory_items) < 10:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 随机选择10条记忆
|
||||||
|
selected_memories = random.sample(memory_items, 10)
|
||||||
|
|
||||||
|
# 拼接成文本
|
||||||
|
merged_text = "\n".join(selected_memories)
|
||||||
|
print(f"\n[合并记忆] 话题: {topic}")
|
||||||
|
print(f"选择的记忆:\n{merged_text}")
|
||||||
|
|
||||||
|
# 使用memory_compress生成新的压缩记忆
|
||||||
|
compressed_memories = await self.memory_compress(merged_text, 0.1)
|
||||||
|
|
||||||
|
# 从原记忆列表中移除被选中的记忆
|
||||||
|
for memory in selected_memories:
|
||||||
|
memory_items.remove(memory)
|
||||||
|
|
||||||
|
# 添加新的压缩记忆
|
||||||
|
for _, compressed_memory in compressed_memories:
|
||||||
|
memory_items.append(compressed_memory)
|
||||||
|
print(f"添加压缩记忆: {compressed_memory}")
|
||||||
|
|
||||||
|
# 更新节点的记忆项
|
||||||
|
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
||||||
|
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
|
||||||
|
|
||||||
|
async def operation_merge_memory(self, percentage=0.1):
|
||||||
|
"""
|
||||||
|
随机检查一定比例的节点,对内容数量超过100的节点进行记忆合并
|
||||||
|
|
||||||
|
Args:
|
||||||
|
percentage: 要检查的节点比例,默认为0.1(10%)
|
||||||
|
"""
|
||||||
|
# 获取所有节点
|
||||||
|
all_nodes = list(self.memory_graph.G.nodes())
|
||||||
|
# 计算要检查的节点数量
|
||||||
|
check_count = max(1, int(len(all_nodes) * percentage))
|
||||||
|
# 随机选择节点
|
||||||
|
nodes_to_check = random.sample(all_nodes, check_count)
|
||||||
|
|
||||||
|
merged_nodes = []
|
||||||
|
for node in nodes_to_check:
|
||||||
|
# 获取节点的内容条数
|
||||||
|
memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
|
||||||
|
if not isinstance(memory_items, list):
|
||||||
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
content_count = len(memory_items)
|
||||||
|
|
||||||
|
# 如果内容数量超过100,进行合并
|
||||||
|
if content_count > 100:
|
||||||
|
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
|
||||||
|
await self.merge_memory(node)
|
||||||
|
merged_nodes.append(node)
|
||||||
|
|
||||||
|
# 同步到数据库
|
||||||
|
if merged_nodes:
|
||||||
|
self.sync_memory_to_db()
|
||||||
|
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
|
||||||
|
else:
|
||||||
|
print("\n本次检查没有需要合并的节点")
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
|
# 设置中文字体
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||||
|
|
||||||
|
G = memory_graph.G
|
||||||
|
|
||||||
|
# 创建一个新图用于可视化
|
||||||
|
H = G.copy()
|
||||||
|
|
||||||
|
# 计算节点大小和颜色
|
||||||
|
node_colors = []
|
||||||
|
node_sizes = []
|
||||||
|
nodes = list(H.nodes())
|
||||||
|
|
||||||
|
# 获取最大记忆数用于归一化节点大小
|
||||||
|
max_memories = 1
|
||||||
|
for node in nodes:
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
max_memories = max(max_memories, memory_count)
|
||||||
|
|
||||||
|
# 计算每个节点的大小和颜色
|
||||||
|
for node in nodes:
|
||||||
|
# 计算节点大小(基于记忆数量)
|
||||||
|
memory_items = H.nodes[node].get('memory_items', [])
|
||||||
|
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||||
|
# 使用指数函数使变化更明显
|
||||||
|
ratio = memory_count / max_memories
|
||||||
|
size = 400 + 2000 * (ratio ** 2) # 增大节点大小
|
||||||
|
node_sizes.append(size)
|
||||||
|
|
||||||
|
# 计算节点颜色(基于连接数)
|
||||||
|
degree = H.degree(node)
|
||||||
|
if degree >= 30:
|
||||||
|
node_colors.append((1.0, 0, 0)) # 亮红色 (#FF0000)
|
||||||
|
else:
|
||||||
|
# 将1-10映射到0-1的范围
|
||||||
|
color_ratio = (degree - 1) / 29.0 if degree > 1 else 0
|
||||||
|
# 使用蓝到红的渐变
|
||||||
|
red = min(0.9, color_ratio)
|
||||||
|
blue = max(0.0, 1.0 - color_ratio)
|
||||||
|
node_colors.append((red, 0, blue))
|
||||||
|
|
||||||
|
# 绘制图形
|
||||||
|
plt.figure(figsize=(16, 12)) # 减小图形尺寸
|
||||||
|
pos = nx.spring_layout(H,
|
||||||
|
k=1, # 调整节点间斥力
|
||||||
|
iterations=100, # 增加迭代次数
|
||||||
|
scale=1.5, # 减小布局尺寸
|
||||||
|
weight='strength') # 使用边的strength属性作为权重
|
||||||
|
|
||||||
|
nx.draw(H, pos,
|
||||||
|
with_labels=True,
|
||||||
|
node_color=node_colors,
|
||||||
|
node_size=node_sizes,
|
||||||
|
font_size=12, # 保持增大的字体大小
|
||||||
|
font_family='SimHei',
|
||||||
|
font_weight='bold',
|
||||||
|
edge_color='gray',
|
||||||
|
width=1.5) # 统一的边宽度
|
||||||
|
|
||||||
|
title = '记忆图谱可视化 - 节点大小表示记忆数量\n节点颜色:蓝(弱连接)到红(强连接)渐变,边的透明度表示连接强度\n连接强度越大的节点距离越近'
|
||||||
|
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# 初始化数据库
|
||||||
|
logger.info("正在初始化数据库连接...")
|
||||||
|
db = Database.get_instance()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
|
||||||
|
|
||||||
|
# 创建记忆图
|
||||||
|
memory_graph = Memory_graph()
|
||||||
|
|
||||||
|
# 创建海马体
|
||||||
|
hippocampus = Hippocampus(memory_graph)
|
||||||
|
|
||||||
|
# 从数据库同步数据
|
||||||
|
hippocampus.sync_memory_from_db()
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
|
|
||||||
|
# 构建记忆
|
||||||
|
if test_pare['do_build_memory']:
|
||||||
|
logger.info("开始构建记忆...")
|
||||||
|
chat_size = 20
|
||||||
|
await hippocampus.operation_build_memory(chat_size=chat_size)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
|
||||||
|
|
||||||
|
if test_pare['do_forget_topic']:
|
||||||
|
logger.info("开始遗忘记忆...")
|
||||||
|
await hippocampus.operation_forget_topic(percentage=0.1)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
|
|
||||||
|
if test_pare['do_merge_memory']:
|
||||||
|
logger.info("开始合并记忆...")
|
||||||
|
await hippocampus.operation_merge_memory(percentage=0.1)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
||||||
|
|
||||||
|
if test_pare['do_visualize_graph']:
|
||||||
|
# 展示优化后的图形
|
||||||
|
logger.info("生成记忆图谱可视化...")
|
||||||
|
print("\n生成优化后的记忆图谱:")
|
||||||
|
visualize_graph_lite(memory_graph)
|
||||||
|
|
||||||
|
if test_pare['do_query']:
|
||||||
|
# 交互式查询
|
||||||
|
while True:
|
||||||
|
query = input("\n请输入新的查询概念(输入'退出'以结束):")
|
||||||
|
if query.lower() == '退出':
|
||||||
|
break
|
||||||
|
|
||||||
|
items_list = memory_graph.get_related_item(query)
|
||||||
|
if items_list:
|
||||||
|
first_layer, second_layer = items_list
|
||||||
|
if first_layer:
|
||||||
|
print("\n直接相关的记忆:")
|
||||||
|
for item in first_layer:
|
||||||
|
print(f"- {item}")
|
||||||
|
if second_layer:
|
||||||
|
print("\n间接相关的记忆:")
|
||||||
|
for item in second_layer:
|
||||||
|
print(f"- {item}")
|
||||||
|
else:
|
||||||
|
print("未找到相关记忆。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
125
src/plugins/memory_system/offline_llm.py
Normal file
125
src/plugins/memory_system/offline_llm.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
from typing import Tuple, Union
|
||||||
|
import time
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
class LLMModel:
|
||||||
|
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.params = kwargs
|
||||||
|
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||||
|
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||||
|
|
||||||
|
if not self.api_key or not self.base_url:
|
||||||
|
raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")
|
||||||
|
|
||||||
|
logger.info(f"API URL: {self.base_url}") # 使用 logger 记录 base_url
|
||||||
|
|
||||||
|
def generate_response(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
"""根据输入的提示生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的 chat/completions 端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15 # 基础等待时间(秒)
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.error(f"请求失败: {str(e)}")
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
"""异步方式根据输入的提示生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的 chat/completions 端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.error(f"请求失败: {str(e)}")
|
||||||
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
@@ -2,20 +2,26 @@ import aiohttp
|
|||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
|
import re
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from loguru import logger
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
|
from ..chat.utils_image import compress_base64_image_by_scale
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class LLM_request:
|
class LLM_request:
|
||||||
def __init__(self, model = global_config.llm_normal,**kwargs):
|
def __init__(self, model, **kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = getattr(config, model["key"])
|
self.api_key = getattr(config, model["key"])
|
||||||
self.base_url = getattr(config, model["base_url"])
|
self.base_url = getattr(config, model["base_url"])
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}")
|
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
|
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
|
||||||
self.model_name = model["name"]
|
self.model_name = model["name"]
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
@@ -35,6 +41,7 @@ class LLM_request:
|
|||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
# 发送请求到完整的chat/completions端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
base_wait_time = 15
|
base_wait_time = 15
|
||||||
@@ -45,28 +52,41 @@ class LLM_request:
|
|||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
if response.status == 429:
|
if response.status == 429:
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if response.status in [500, 503]:
|
||||||
|
logger.error(f"服务器错误: {response.status}")
|
||||||
|
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
content = result["choices"][0]["message"]["content"]
|
message = result["choices"][0]["message"]
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
content = message.get("content", "")
|
||||||
|
think_match = None
|
||||||
|
reasoning_content = message.get("reasoning_content", "")
|
||||||
|
if not reasoning_content:
|
||||||
|
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||||
|
if think_match:
|
||||||
|
reasoning_content = think_match.group(1).strip()
|
||||||
|
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||||
return content, reasoning_content
|
return content, reasoning_content
|
||||||
return "没有返回结果", ""
|
return "没有返回结果", ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||||
|
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
|
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||||
|
|
||||||
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
"""根据输入的提示和图片生成模型的异步响应"""
|
"""根据输入的提示和图片生成模型的异步响应"""
|
||||||
@@ -76,7 +96,8 @@ class LLM_request:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
data = {
|
def build_request_data(img_base64: str):
|
||||||
|
return {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
@@ -89,7 +110,7 @@ class LLM_request:
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -98,19 +119,91 @@ class LLM_request:
|
|||||||
**self.params
|
**self.params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
# 发送请求到完整的chat/completions端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
base_wait_time = 15
|
base_wait_time = 15
|
||||||
|
|
||||||
|
current_image_base64 = image_base64
|
||||||
|
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
|
||||||
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
|
data = build_request_data(current_image_base64)
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
if response.status == 429:
|
if response.status == 429:
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif response.status == 413:
|
||||||
|
logger.warning("图片太大(413),尝试压缩...")
|
||||||
|
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status() # 检查其他响应状态
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
message = result["choices"][0]["message"]
|
||||||
|
content = message.get("content", "")
|
||||||
|
think_match = None
|
||||||
|
reasoning_content = message.get("reasoning_content", "")
|
||||||
|
if not reasoning_content:
|
||||||
|
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||||
|
if think_match:
|
||||||
|
reasoning_content = think_match.group(1).strip()
|
||||||
|
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||||
|
return content, reasoning_content
|
||||||
|
return "没有返回结果", ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||||
|
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
|
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||||
|
|
||||||
|
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
"""异步方式根据输入的提示生成模型的响应"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
data = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"temperature": 0.5,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求到完整的 chat/completions 端点
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||||
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -126,13 +219,17 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
|
logger.error(f"请求失败: {str(e)}")
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
"""同步方法:根据输入的提示和图片生成模型的响应"""
|
"""同步方法:根据输入的提示和图片生成模型的响应"""
|
||||||
headers = {
|
headers = {
|
||||||
@@ -140,6 +237,8 @@ class LLM_request:
|
|||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
image_base64=compress_base64_image_by_scale(image_base64)
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
data = {
|
data = {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -165,6 +264,7 @@ class LLM_request:
|
|||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
# 发送请求到完整的chat/completions端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||||
|
|
||||||
max_retries = 2
|
max_retries = 2
|
||||||
base_wait_time = 6
|
base_wait_time = 6
|
||||||
@@ -174,8 +274,8 @@ class LLM_request:
|
|||||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
if response.status_code == 429:
|
if response.status_code == 429:
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -183,17 +283,138 @@ class LLM_request:
|
|||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
content = result["choices"][0]["message"]["content"]
|
message = result["choices"][0]["message"]
|
||||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
content = message.get("content", "")
|
||||||
|
think_match = None
|
||||||
|
reasoning_content = message.get("reasoning_content", "")
|
||||||
|
if not reasoning_content:
|
||||||
|
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||||
|
if think_match:
|
||||||
|
reasoning_content = think_match.group(1).strip()
|
||||||
|
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||||
return content, reasoning_content
|
return content, reasoning_content
|
||||||
return "没有返回结果", ""
|
return "没有返回结果", ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||||
|
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
|
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||||
|
|
||||||
|
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""同步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||||
|
|
||||||
|
max_retries = 2
|
||||||
|
base_wait_time = 6
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""异步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.error("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,22 +1,23 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Union
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from src.plugins.chat.config import global_config
|
from src.plugins.chat.config import global_config
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
|
from loguru import logger
|
||||||
|
import json
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
@@ -42,8 +43,6 @@ class ScheduleGenerator:
|
|||||||
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
||||||
|
|
||||||
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
||||||
if target_date is None:
|
|
||||||
target_date = datetime.datetime.now()
|
|
||||||
|
|
||||||
date_str = target_date.strftime("%Y-%m-%d")
|
date_str = target_date.strftime("%Y-%m-%d")
|
||||||
weekday = target_date.strftime("%A")
|
weekday = target_date.strftime("%A")
|
||||||
@@ -59,15 +58,20 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
elif read_only == False:
|
elif read_only == False:
|
||||||
print(f"{date_str}的日程不存在,准备生成新的日程。")
|
print(f"{date_str}的日程不存在,准备生成新的日程。")
|
||||||
prompt = f"""我是{global_config.BOT_NICKNAME},一个曾经学习地质,现在学习心理学和脑科学的女大学生,喜欢刷qq,贴吧,知乎和小红书,请为我生成{date_str}({weekday})的日程安排,包括:
|
prompt = f"""我是{global_config.BOT_NICKNAME},{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}({weekday})的日程安排,包括:"""+\
|
||||||
|
"""
|
||||||
1. 早上的学习和工作安排
|
1. 早上的学习和工作安排
|
||||||
2. 下午的活动和任务
|
2. 下午的活动和任务
|
||||||
3. 晚上的计划和休息时间
|
3. 晚上的计划和休息时间
|
||||||
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用JSON格式返回日程表,仅返回内容,不要返回注释,时间采用24小时制,格式为{"时间": "活动","时间": "活动",...}。"""
|
||||||
|
|
||||||
|
try:
|
||||||
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
||||||
# print(self.schedule_text)
|
|
||||||
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成日程失败: {str(e)}")
|
||||||
|
schedule_text = "生成日程时出错了"
|
||||||
|
# print(self.schedule_text)
|
||||||
else:
|
else:
|
||||||
print(f"{date_str}的日程不存在。")
|
print(f"{date_str}的日程不存在。")
|
||||||
schedule_text = "忘了"
|
schedule_text = "忘了"
|
||||||
@@ -77,20 +81,15 @@ class ScheduleGenerator:
|
|||||||
schedule_form = self._parse_schedule(schedule_text)
|
schedule_form = self._parse_schedule(schedule_text)
|
||||||
return schedule_text,schedule_form
|
return schedule_text,schedule_form
|
||||||
|
|
||||||
def _parse_schedule(self, schedule_text: str) -> Dict[str, str]:
|
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
|
||||||
"""解析日程文本,转换为时间和活动的字典"""
|
"""解析日程文本,转换为时间和活动的字典"""
|
||||||
schedule_dict = {}
|
try:
|
||||||
# 按行分割日程文本
|
schedule_dict = json.loads(schedule_text)
|
||||||
lines = schedule_text.strip().split('\n')
|
|
||||||
for line in lines:
|
|
||||||
# print(line)
|
|
||||||
if ',' in line:
|
|
||||||
# 假设格式为 "时间: 活动"
|
|
||||||
time_str, activity = line.split(',', 1)
|
|
||||||
# print(time_str)
|
|
||||||
# print(activity)
|
|
||||||
schedule_dict[time_str.strip()] = activity.strip()
|
|
||||||
return schedule_dict
|
return schedule_dict
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(schedule_text)
|
||||||
|
print(f"解析日程失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
def _parse_time(self, time_str: str) -> str:
|
def _parse_time(self, time_str: str) -> str:
|
||||||
"""解析时间字符串,转换为时间"""
|
"""解析时间字符串,转换为时间"""
|
||||||
@@ -105,6 +104,8 @@ class ScheduleGenerator:
|
|||||||
min_diff = float('inf')
|
min_diff = float('inf')
|
||||||
|
|
||||||
# 检查今天的日程
|
# 检查今天的日程
|
||||||
|
if not self.today_schedule.keys():
|
||||||
|
return "摸鱼"
|
||||||
for time_str in self.today_schedule.keys():
|
for time_str in self.today_schedule.keys():
|
||||||
diff = abs(self._time_diff(current_time, time_str))
|
diff = abs(self._time_diff(current_time, time_str))
|
||||||
if closest_time is None or diff < min_diff:
|
if closest_time is None or diff < min_diff:
|
||||||
@@ -128,6 +129,10 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
def _time_diff(self, time1: str, time2: str) -> int:
|
def _time_diff(self, time1: str, time2: str) -> int:
|
||||||
"""计算两个时间字符串之间的分钟差"""
|
"""计算两个时间字符串之间的分钟差"""
|
||||||
|
if time1=="24:00":
|
||||||
|
time1="23:59"
|
||||||
|
if time2=="24:00":
|
||||||
|
time2="23:59"
|
||||||
t1 = datetime.datetime.strptime(time1, "%H:%M")
|
t1 = datetime.datetime.strptime(time1, "%H:%M")
|
||||||
t2 = datetime.datetime.strptime(time2, "%H:%M")
|
t2 = datetime.datetime.strptime(time2, "%H:%M")
|
||||||
diff = int((t2 - t1).total_seconds() / 60)
|
diff = int((t2 - t1).total_seconds() / 60)
|
||||||
@@ -141,7 +146,10 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
def print_schedule(self):
|
def print_schedule(self):
|
||||||
"""打印完整的日程安排"""
|
"""打印完整的日程安排"""
|
||||||
|
if not self._parse_schedule(self.today_schedule_text):
|
||||||
|
print("今日日程有误,将在下次运行时重新生成")
|
||||||
|
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
|
||||||
|
else:
|
||||||
print("\n=== 今日日程安排 ===")
|
print("\n=== 今日日程安排 ===")
|
||||||
for time_str, activity in self.today_schedule.items():
|
for time_str, activity in self.today_schedule.items():
|
||||||
print(f"时间[{time_str}]: 活动[{activity}]")
|
print(f"时间[{time_str}]: 活动[{activity}]")
|
||||||
|
|||||||
@@ -1,70 +0,0 @@
|
|||||||
from textblob import TextBlob
|
|
||||||
import jieba
|
|
||||||
from translate import Translator
|
|
||||||
|
|
||||||
def analyze_emotion(text):
|
|
||||||
"""
|
|
||||||
分析文本的情感,返回情感极性和主观性得分
|
|
||||||
:param text: 输入文本
|
|
||||||
:return: (情感极性, 主观性) 元组
|
|
||||||
情感极性: -1(非常消极) 到 1(非常积极)
|
|
||||||
主观性: 0(客观) 到 1(主观)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 创建翻译器
|
|
||||||
translator = Translator(to_lang="en", from_lang="zh")
|
|
||||||
|
|
||||||
# 如果是中文文本,先翻译成英文
|
|
||||||
# 因为TextBlob的情感分析主要基于英文
|
|
||||||
translated_text = translator.translate(text)
|
|
||||||
|
|
||||||
# 创建TextBlob对象
|
|
||||||
blob = TextBlob(translated_text)
|
|
||||||
|
|
||||||
# 获取情感极性和主观性
|
|
||||||
polarity = blob.sentiment.polarity
|
|
||||||
subjectivity = blob.sentiment.subjectivity
|
|
||||||
|
|
||||||
return polarity, subjectivity
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"分析过程中出现错误: {str(e)}")
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
def get_emotion_description(polarity, subjectivity):
|
|
||||||
"""
|
|
||||||
根据情感极性和主观性生成描述性文字
|
|
||||||
"""
|
|
||||||
if polarity is None or subjectivity is None:
|
|
||||||
return "无法分析情感"
|
|
||||||
|
|
||||||
# 情感极性描述
|
|
||||||
if polarity > 0.5:
|
|
||||||
emotion = "非常积极"
|
|
||||||
elif polarity > 0:
|
|
||||||
emotion = "较为积极"
|
|
||||||
elif polarity == 0:
|
|
||||||
emotion = "中性"
|
|
||||||
elif polarity > -0.5:
|
|
||||||
emotion = "较为消极"
|
|
||||||
else:
|
|
||||||
emotion = "非常消极"
|
|
||||||
|
|
||||||
# 主观性描述
|
|
||||||
if subjectivity > 0.7:
|
|
||||||
subj = "非常主观"
|
|
||||||
elif subjectivity > 0.3:
|
|
||||||
subj = "较为主观"
|
|
||||||
else:
|
|
||||||
subj = "较为客观"
|
|
||||||
|
|
||||||
return f"情感倾向: {emotion}, 表达方式: {subj}"
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试样例
|
|
||||||
test_text = "今天天气真好,我感到非常开心!"
|
|
||||||
polarity, subjectivity = analyze_emotion(test_text)
|
|
||||||
print(f"测试文本: {test_text}")
|
|
||||||
print(f"情感极性: {polarity:.2f}")
|
|
||||||
print(f"主观性得分: {subjectivity:.2f}")
|
|
||||||
print(get_emotion_description(polarity, subjectivity))
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
|
|
||||||
|
|
||||||
def setup_bert_analyzer():
|
|
||||||
"""
|
|
||||||
设置中文BERT情感分析器
|
|
||||||
"""
|
|
||||||
# 使用专门针对中文情感分析的模型
|
|
||||||
model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 加载模型和分词器
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
||||||
|
|
||||||
# 创建情感分析pipeline
|
|
||||||
analyzer = pipeline("sentiment-analysis",
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer)
|
|
||||||
|
|
||||||
return analyzer
|
|
||||||
except Exception as e:
|
|
||||||
print(f"模型加载错误: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def analyze_emotion_bert(text, analyzer):
|
|
||||||
"""
|
|
||||||
使用BERT模型进行中文情感分析
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not analyzer:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 进行情感分析
|
|
||||||
result = analyzer(text)[0]
|
|
||||||
|
|
||||||
return {
|
|
||||||
'label': result['label'],
|
|
||||||
'score': result['score']
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"分析过程中出现错误: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_emotion_description_bert(result):
|
|
||||||
"""
|
|
||||||
将BERT的情感分析结果转换为描述性文字
|
|
||||||
"""
|
|
||||||
if not result:
|
|
||||||
return "无法分析情感"
|
|
||||||
|
|
||||||
label = "积极" if result['label'] == 'positive' else "消极"
|
|
||||||
confidence = result['score']
|
|
||||||
|
|
||||||
if confidence > 0.9:
|
|
||||||
strength = "强烈"
|
|
||||||
elif confidence > 0.7:
|
|
||||||
strength = "明显"
|
|
||||||
else:
|
|
||||||
strength = "轻微"
|
|
||||||
|
|
||||||
return f"{strength}{label}"
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 初始化分析器
|
|
||||||
analyzer = setup_bert_analyzer()
|
|
||||||
|
|
||||||
# 测试样例
|
|
||||||
test_text = "这个产品质量很好,使用起来非常方便,推荐购买!"
|
|
||||||
result = analyze_emotion_bert(test_text, analyzer)
|
|
||||||
|
|
||||||
print(f"测试文本: {test_text}")
|
|
||||||
if result:
|
|
||||||
print(f"情感倾向: {get_emotion_description_bert(result)}")
|
|
||||||
print(f"置信度: {result['score']:.2f}")
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import hanlp
|
|
||||||
|
|
||||||
def analyze_emotion_hanlp(text):
|
|
||||||
"""
|
|
||||||
使用HanLP进行中文情感分析
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 使用更基础的模型
|
|
||||||
tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
|
|
||||||
|
|
||||||
# 分词
|
|
||||||
words = tokenizer(text)
|
|
||||||
|
|
||||||
# 简单的情感词典方法
|
|
||||||
positive_words = {'好', '棒', '优秀', '喜欢', '开心', '快乐', '美味', '推荐', '优质', '满意'}
|
|
||||||
negative_words = {'差', '糟', '烂', '讨厌', '失望', '难受', '恶心', '不满', '差劲', '垃圾'}
|
|
||||||
|
|
||||||
# 计算情感得分
|
|
||||||
score = 0
|
|
||||||
for word in words:
|
|
||||||
if word in positive_words:
|
|
||||||
score += 1
|
|
||||||
elif word in negative_words:
|
|
||||||
score -= 1
|
|
||||||
|
|
||||||
# 归一化得分
|
|
||||||
if score > 0:
|
|
||||||
return 1
|
|
||||||
elif score < 0:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"分析过程中出现错误: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_emotion_description_hanlp(score):
|
|
||||||
"""
|
|
||||||
将HanLP的情感分析结果转换为描述性文字
|
|
||||||
"""
|
|
||||||
if score is None:
|
|
||||||
return "无法分析情感"
|
|
||||||
elif score == 1:
|
|
||||||
return "积极"
|
|
||||||
elif score == 0:
|
|
||||||
return "消极"
|
|
||||||
else:
|
|
||||||
return "中性"
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试样例
|
|
||||||
test_texts = [
|
|
||||||
"这家餐厅的服务态度很好,菜品也很美味!",
|
|
||||||
"这个产品质量太差了,一点都不值这个价",
|
|
||||||
"今天天气不错,但是工作很累"
|
|
||||||
]
|
|
||||||
|
|
||||||
for test_text in test_texts:
|
|
||||||
result = analyze_emotion_hanlp(test_text)
|
|
||||||
print(f"\n测试文本: {test_text}")
|
|
||||||
print(f"情感倾向: {get_emotion_description_hanlp(result)}")
|
|
||||||
488
src/test/typo_creator.py
Normal file
488
src/test/typo_creator.py
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
"""
|
||||||
|
错别字生成器 - 流程说明
|
||||||
|
|
||||||
|
整体替换逻辑:
|
||||||
|
1. 数据准备
|
||||||
|
- 加载字频词典:使用jieba词典计算汉字使用频率
|
||||||
|
- 创建拼音映射:建立拼音到汉字的映射关系
|
||||||
|
- 加载词频信息:从jieba词典获取词语使用频率
|
||||||
|
|
||||||
|
2. 分词处理
|
||||||
|
- 使用jieba将输入句子分词
|
||||||
|
- 区分单字词和多字词
|
||||||
|
- 保留标点符号和空格
|
||||||
|
|
||||||
|
3. 词语级别替换(针对多字词)
|
||||||
|
- 触发条件:词长>1 且 随机概率<0.3
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取词语拼音
|
||||||
|
b. 生成所有可能的同音字组合
|
||||||
|
c. 过滤条件:
|
||||||
|
- 必须是jieba词典中的有效词
|
||||||
|
- 词频必须达到原词频的10%以上
|
||||||
|
- 综合评分(词频70%+字频30%)必须达到阈值
|
||||||
|
d. 按综合评分排序,选择最合适的替换词
|
||||||
|
|
||||||
|
4. 字级别替换(针对单字词或未进行整词替换的多字词)
|
||||||
|
- 单字替换概率:0.3
|
||||||
|
- 多字词中的单字替换概率:0.3 * (0.7 ^ (词长-1))
|
||||||
|
- 替换流程:
|
||||||
|
a. 获取字的拼音
|
||||||
|
b. 声调错误处理(20%概率)
|
||||||
|
c. 获取同音字列表
|
||||||
|
d. 过滤条件:
|
||||||
|
- 字频必须达到最小阈值
|
||||||
|
- 频率差异不能过大(指数衰减计算)
|
||||||
|
e. 按频率排序选择替换字
|
||||||
|
|
||||||
|
5. 频率控制机制
|
||||||
|
- 字频控制:使用归一化的字频(0-1000范围)
|
||||||
|
- 词频控制:使用jieba词典中的词频
|
||||||
|
- 频率差异计算:使用指数衰减函数
|
||||||
|
- 最小频率阈值:确保替换字/词不会太生僻
|
||||||
|
|
||||||
|
6. 输出信息
|
||||||
|
- 原文和错字版本的对照
|
||||||
|
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
|
||||||
|
- 替换类型说明(整词替换/声调错误/同音字替换)
|
||||||
|
- 词语分析和完整拼音
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
1. 所有替换都必须使用有意义的词语
|
||||||
|
2. 替换词的使用频率不能过低
|
||||||
|
3. 多字词优先考虑整词替换
|
||||||
|
4. 考虑声调变化的情况
|
||||||
|
5. 保持标点符号和空格不变
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pypinyin import pinyin, Style
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
import jieba
|
||||||
|
import jieba.posseg as pseg
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
def load_or_create_char_frequency():
|
||||||
|
"""
|
||||||
|
加载或创建汉字频率字典
|
||||||
|
"""
|
||||||
|
cache_file = Path("char_frequency.json")
|
||||||
|
|
||||||
|
# 如果缓存文件存在,直接加载
|
||||||
|
if cache_file.exists():
|
||||||
|
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# 使用内置的词频文件
|
||||||
|
char_freq = defaultdict(int)
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
|
||||||
|
# 读取jieba的词典文件
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
word, freq = line.strip().split()[:2]
|
||||||
|
# 对词中的每个字进行频率累加
|
||||||
|
for char in word:
|
||||||
|
if is_chinese_char(char):
|
||||||
|
char_freq[char] += int(freq)
|
||||||
|
|
||||||
|
# 归一化频率值
|
||||||
|
max_freq = max(char_freq.values())
|
||||||
|
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||||
|
|
||||||
|
# 保存到缓存文件
|
||||||
|
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return normalized_freq
|
||||||
|
|
||||||
|
# 创建拼音到汉字的映射字典
|
||||||
|
def create_pinyin_dict():
|
||||||
|
"""
|
||||||
|
创建拼音到汉字的映射字典
|
||||||
|
"""
|
||||||
|
# 常用汉字范围
|
||||||
|
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||||
|
pinyin_dict = defaultdict(list)
|
||||||
|
|
||||||
|
# 为每个汉字建立拼音映射
|
||||||
|
for char in chars:
|
||||||
|
try:
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
pinyin_dict[py].append(char)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return pinyin_dict
|
||||||
|
|
||||||
|
def is_chinese_char(char):
|
||||||
|
"""
|
||||||
|
判断是否为汉字
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return '\u4e00' <= char <= '\u9fff'
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_pinyin(sentence):
|
||||||
|
"""
|
||||||
|
将中文句子拆分成单个汉字并获取其拼音
|
||||||
|
:param sentence: 输入的中文句子
|
||||||
|
:return: 每个汉字及其拼音的列表
|
||||||
|
"""
|
||||||
|
# 将句子拆分成单个字符
|
||||||
|
characters = list(sentence)
|
||||||
|
|
||||||
|
# 获取每个字符的拼音
|
||||||
|
result = []
|
||||||
|
for char in characters:
|
||||||
|
# 跳过空格和非汉字字符
|
||||||
|
if char.isspace() or not is_chinese_char(char):
|
||||||
|
continue
|
||||||
|
# 获取拼音(数字声调)
|
||||||
|
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||||
|
result.append((char, py))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取同音字,按照使用频率排序
|
||||||
|
"""
|
||||||
|
homophones = pinyin_dict[py]
|
||||||
|
# 移除原字并过滤低频字
|
||||||
|
if char in homophones:
|
||||||
|
homophones.remove(char)
|
||||||
|
|
||||||
|
# 过滤掉低频字
|
||||||
|
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
# 按照字频排序
|
||||||
|
sorted_homophones = sorted(homophones,
|
||||||
|
key=lambda x: char_frequency.get(x, 0),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# 只返回前10个同音字,避免输出过多
|
||||||
|
return sorted_homophones[:10]
|
||||||
|
|
||||||
|
def get_similar_tone_pinyin(py):
|
||||||
|
"""
|
||||||
|
获取相似声调的拼音
|
||||||
|
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
||||||
|
处理特殊情况:
|
||||||
|
1. 轻声(如 'de5' 或 'le')
|
||||||
|
2. 非数字结尾的拼音
|
||||||
|
"""
|
||||||
|
# 检查拼音是否为空或无效
|
||||||
|
if not py or len(py) < 1:
|
||||||
|
return py
|
||||||
|
|
||||||
|
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||||
|
if not py[-1].isdigit():
|
||||||
|
# 为非数字结尾的拼音添加数字声调1
|
||||||
|
return py + '1'
|
||||||
|
|
||||||
|
base = py[:-1] # 去掉声调
|
||||||
|
tone = int(py[-1]) # 获取声调
|
||||||
|
|
||||||
|
# 处理轻声(通常用5表示)或无效声调
|
||||||
|
if tone not in [1, 2, 3, 4]:
|
||||||
|
return base + str(random.choice([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
# 正常处理声调
|
||||||
|
possible_tones = [1, 2, 3, 4]
|
||||||
|
possible_tones.remove(tone) # 移除原声调
|
||||||
|
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||||
|
return base + str(new_tone)
|
||||||
|
|
||||||
|
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
||||||
|
"""
|
||||||
|
根据频率差计算替换概率
|
||||||
|
频率差越大,概率越低
|
||||||
|
:param orig_freq: 原字频率
|
||||||
|
:param target_freq: 目标字频率
|
||||||
|
:param max_freq_diff: 最大允许的频率差
|
||||||
|
:return: 0-1之间的概率值
|
||||||
|
"""
|
||||||
|
if target_freq > orig_freq:
|
||||||
|
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||||
|
|
||||||
|
freq_diff = orig_freq - target_freq
|
||||||
|
if freq_diff > max_freq_diff:
|
||||||
|
return 0.0 # 频率差太大,不替换
|
||||||
|
|
||||||
|
# 使用指数衰减函数计算概率
|
||||||
|
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||||
|
return math.exp(-3 * freq_diff / max_freq_diff)
|
||||||
|
|
||||||
|
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
||||||
|
"""
|
||||||
|
获取与给定字频率相近的同音字,可能包含声调错误
|
||||||
|
"""
|
||||||
|
homophones = []
|
||||||
|
|
||||||
|
# 有20%的概率使用错误声调
|
||||||
|
if random.random() < tone_error_rate:
|
||||||
|
wrong_tone_py = get_similar_tone_pinyin(py)
|
||||||
|
homophones.extend(pinyin_dict[wrong_tone_py])
|
||||||
|
|
||||||
|
# 添加正确声调的同音字
|
||||||
|
homophones.extend(pinyin_dict[py])
|
||||||
|
|
||||||
|
if not homophones:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取原字的频率
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
|
||||||
|
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||||
|
freq_diff = [(h, char_frequency.get(h, 0))
|
||||||
|
for h in homophones
|
||||||
|
if h != char and char_frequency.get(h, 0) >= min_freq]
|
||||||
|
|
||||||
|
if not freq_diff:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算每个候选字的替换概率
|
||||||
|
candidates_with_prob = []
|
||||||
|
for h, freq in freq_diff:
|
||||||
|
prob = calculate_replacement_probability(orig_freq, freq)
|
||||||
|
if prob > 0: # 只保留有效概率的候选字
|
||||||
|
candidates_with_prob.append((h, prob))
|
||||||
|
|
||||||
|
if not candidates_with_prob:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 根据概率排序
|
||||||
|
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 返回概率最高的几个字
|
||||||
|
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||||
|
|
||||||
|
def get_word_pinyin(word):
|
||||||
|
"""
|
||||||
|
获取词语的拼音列表
|
||||||
|
"""
|
||||||
|
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||||
|
|
||||||
|
def segment_sentence(sentence):
|
||||||
|
"""
|
||||||
|
使用jieba分词,返回词语列表
|
||||||
|
"""
|
||||||
|
return list(jieba.cut(sentence))
|
||||||
|
|
||||||
|
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
|
||||||
|
"""
|
||||||
|
获取整个词的同音词,只返回高频的有意义词语
|
||||||
|
:param word: 输入词语
|
||||||
|
:param pinyin_dict: 拼音字典
|
||||||
|
:param char_frequency: 字频字典
|
||||||
|
:param min_freq: 最小频率阈值
|
||||||
|
:return: 同音词列表
|
||||||
|
"""
|
||||||
|
if len(word) == 1:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取词的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
word_pinyin_str = ''.join(word_pinyin)
|
||||||
|
|
||||||
|
# 创建词语频率字典
|
||||||
|
word_freq = defaultdict(float)
|
||||||
|
|
||||||
|
# 遍历所有可能的同音字组合
|
||||||
|
candidates = []
|
||||||
|
for py in word_pinyin:
|
||||||
|
chars = pinyin_dict.get(py, [])
|
||||||
|
if not chars:
|
||||||
|
return []
|
||||||
|
candidates.append(chars)
|
||||||
|
|
||||||
|
# 生成所有可能的组合
|
||||||
|
import itertools
|
||||||
|
all_combinations = itertools.product(*candidates)
|
||||||
|
|
||||||
|
# 获取jieba词典和词频信息
|
||||||
|
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||||
|
valid_words = {} # 改用字典存储词语及其频率
|
||||||
|
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||||
|
for line in f:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
word_text = parts[0]
|
||||||
|
word_freq = float(parts[1]) # 获取词频
|
||||||
|
valid_words[word_text] = word_freq
|
||||||
|
|
||||||
|
# 获取原词的词频作为参考
|
||||||
|
original_word_freq = valid_words.get(word, 0)
|
||||||
|
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||||
|
|
||||||
|
# 过滤和计算频率
|
||||||
|
homophones = []
|
||||||
|
for combo in all_combinations:
|
||||||
|
new_word = ''.join(combo)
|
||||||
|
if new_word != word and new_word in valid_words:
|
||||||
|
new_word_freq = valid_words[new_word]
|
||||||
|
# 只保留词频达到阈值的词
|
||||||
|
if new_word_freq >= min_word_freq:
|
||||||
|
# 计算词的平均字频(考虑字频和词频)
|
||||||
|
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||||
|
# 综合评分:结合词频和字频
|
||||||
|
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
|
||||||
|
if combined_score >= min_freq:
|
||||||
|
homophones.append((new_word, combined_score))
|
||||||
|
|
||||||
|
# 按综合分数排序并限制返回数量
|
||||||
|
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||||
|
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||||
|
|
||||||
|
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
|
||||||
|
"""
|
||||||
|
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||||
|
只使用高频的有意义词语进行替换
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
typo_info = []
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
# 如果是标点符号或空格,直接添加
|
||||||
|
if all(not is_chinese_char(c) for c in word):
|
||||||
|
result.append(word)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取词语的拼音
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
|
||||||
|
# 尝试整词替换
|
||||||
|
if len(word) > 1 and random.random() < word_replace_rate:
|
||||||
|
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
|
||||||
|
if word_homophones:
|
||||||
|
typo_word = random.choice(word_homophones)
|
||||||
|
# 计算词的平均频率
|
||||||
|
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
|
||||||
|
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||||
|
|
||||||
|
# 添加到结果中
|
||||||
|
result.append(typo_word)
|
||||||
|
typo_info.append((word, typo_word,
|
||||||
|
' '.join(word_pinyin),
|
||||||
|
' '.join(get_word_pinyin(typo_word)),
|
||||||
|
orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果不进行整词替换,则进行单字替换
|
||||||
|
if len(word) == 1:
|
||||||
|
char = word
|
||||||
|
py = word_pinyin[0]
|
||||||
|
if random.random() < error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
result.append(char)
|
||||||
|
else:
|
||||||
|
# 处理多字词的单字替换
|
||||||
|
word_result = []
|
||||||
|
for i, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||||
|
# 词中的字替换概率降低
|
||||||
|
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
|
||||||
|
|
||||||
|
if random.random() < word_error_rate:
|
||||||
|
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
||||||
|
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
||||||
|
if similar_chars:
|
||||||
|
typo_char = random.choice(similar_chars)
|
||||||
|
typo_freq = char_frequency.get(typo_char, 0)
|
||||||
|
orig_freq = char_frequency.get(char, 0)
|
||||||
|
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
||||||
|
if random.random() < replace_prob:
|
||||||
|
word_result.append(typo_char)
|
||||||
|
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||||
|
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||||
|
continue
|
||||||
|
word_result.append(char)
|
||||||
|
result.append(''.join(word_result))
|
||||||
|
|
||||||
|
return ''.join(result), typo_info
|
||||||
|
|
||||||
|
def format_frequency(freq):
|
||||||
|
"""
|
||||||
|
格式化频率显示
|
||||||
|
"""
|
||||||
|
return f"{freq:.2f}"
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 记录开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 首先创建拼音字典和加载字频统计
|
||||||
|
print("正在加载汉字数据库,请稍候...")
|
||||||
|
pinyin_dict = create_pinyin_dict()
|
||||||
|
char_frequency = load_or_create_char_frequency()
|
||||||
|
|
||||||
|
# 获取用户输入
|
||||||
|
sentence = input("请输入中文句子:")
|
||||||
|
|
||||||
|
# 创建包含错别字的句子
|
||||||
|
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
||||||
|
error_rate=0.3, min_freq=5,
|
||||||
|
tone_error_rate=0.2, word_replace_rate=0.3)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\n原句:", sentence)
|
||||||
|
print("错字版:", typo_sentence)
|
||||||
|
|
||||||
|
if typo_info:
|
||||||
|
print("\n错别字信息:")
|
||||||
|
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||||
|
# 判断是否为词语替换
|
||||||
|
is_word = ' ' in orig_py
|
||||||
|
if is_word:
|
||||||
|
error_type = "整词替换"
|
||||||
|
else:
|
||||||
|
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||||
|
error_type = "声调错误" if tone_error else "同音字替换"
|
||||||
|
|
||||||
|
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
||||||
|
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
||||||
|
|
||||||
|
# 获取拼音结果
|
||||||
|
result = get_pinyin(sentence)
|
||||||
|
|
||||||
|
# 打印完整拼音
|
||||||
|
print("\n完整拼音:")
|
||||||
|
print(" ".join(py for _, py in result))
|
||||||
|
|
||||||
|
# 打印词语分析
|
||||||
|
print("\n词语分析:")
|
||||||
|
words = segment_sentence(sentence)
|
||||||
|
for word in words:
|
||||||
|
if any(is_chinese_char(c) for c in word):
|
||||||
|
word_pinyin = get_word_pinyin(word)
|
||||||
|
print(f"词语:{word}")
|
||||||
|
print(f"拼音:{' '.join(word_pinyin)}")
|
||||||
|
print("---")
|
||||||
|
|
||||||
|
# 计算并打印总耗时
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
print(f"\n总耗时:{total_time:.2f}秒")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,301 +0,0 @@
|
|||||||
from pypinyin import pinyin, Style
|
|
||||||
from collections import defaultdict
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import unicodedata
|
|
||||||
import jieba
|
|
||||||
import jieba.posseg as pseg
|
|
||||||
from pathlib import Path
|
|
||||||
import random
|
|
||||||
import math
|
|
||||||
|
|
||||||
def load_or_create_char_frequency():
|
|
||||||
"""
|
|
||||||
加载或创建汉字频率字典
|
|
||||||
"""
|
|
||||||
cache_file = Path("char_frequency.json")
|
|
||||||
|
|
||||||
# 如果缓存文件存在,直接加载
|
|
||||||
if cache_file.exists():
|
|
||||||
with open(cache_file, 'r', encoding='utf-8') as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
# 使用内置的词频文件
|
|
||||||
char_freq = defaultdict(int)
|
|
||||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
|
||||||
|
|
||||||
# 读取jieba的词典文件
|
|
||||||
with open(dict_path, 'r', encoding='utf-8') as f:
|
|
||||||
for line in f:
|
|
||||||
word, freq = line.strip().split()[:2]
|
|
||||||
# 对词中的每个字进行频率累加
|
|
||||||
for char in word:
|
|
||||||
if is_chinese_char(char):
|
|
||||||
char_freq[char] += int(freq)
|
|
||||||
|
|
||||||
# 归一化频率值
|
|
||||||
max_freq = max(char_freq.values())
|
|
||||||
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
|
||||||
|
|
||||||
# 保存到缓存文件
|
|
||||||
with open(cache_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
return normalized_freq
|
|
||||||
|
|
||||||
# 创建拼音到汉字的映射字典
|
|
||||||
def create_pinyin_dict():
|
|
||||||
"""
|
|
||||||
创建拼音到汉字的映射字典
|
|
||||||
"""
|
|
||||||
# 常用汉字范围
|
|
||||||
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
|
||||||
pinyin_dict = defaultdict(list)
|
|
||||||
|
|
||||||
# 为每个汉字建立拼音映射
|
|
||||||
for char in chars:
|
|
||||||
try:
|
|
||||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
|
||||||
pinyin_dict[py].append(char)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return pinyin_dict
|
|
||||||
|
|
||||||
def is_chinese_char(char):
|
|
||||||
"""
|
|
||||||
判断是否为汉字
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return '\u4e00' <= char <= '\u9fff'
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_pinyin(sentence):
|
|
||||||
"""
|
|
||||||
将中文句子拆分成单个汉字并获取其拼音
|
|
||||||
:param sentence: 输入的中文句子
|
|
||||||
:return: 每个汉字及其拼音的列表
|
|
||||||
"""
|
|
||||||
# 将句子拆分成单个字符
|
|
||||||
characters = list(sentence)
|
|
||||||
|
|
||||||
# 获取每个字符的拼音
|
|
||||||
result = []
|
|
||||||
for char in characters:
|
|
||||||
# 跳过空格和非汉字字符
|
|
||||||
if char.isspace() or not is_chinese_char(char):
|
|
||||||
continue
|
|
||||||
# 获取拼音(数字声调)
|
|
||||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
|
||||||
result.append((char, py))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
|
|
||||||
"""
|
|
||||||
获取同音字,按照使用频率排序
|
|
||||||
"""
|
|
||||||
homophones = pinyin_dict[py]
|
|
||||||
# 移除原字并过滤低频字
|
|
||||||
if char in homophones:
|
|
||||||
homophones.remove(char)
|
|
||||||
|
|
||||||
# 过滤掉低频字
|
|
||||||
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
|
|
||||||
|
|
||||||
# 按照字频排序
|
|
||||||
sorted_homophones = sorted(homophones,
|
|
||||||
key=lambda x: char_frequency.get(x, 0),
|
|
||||||
reverse=True)
|
|
||||||
|
|
||||||
# 只返回前10个同音字,避免输出过多
|
|
||||||
return sorted_homophones[:10]
|
|
||||||
|
|
||||||
def get_similar_tone_pinyin(py):
|
|
||||||
"""
|
|
||||||
获取相似声调的拼音
|
|
||||||
例如:'ni3' 可能返回 'ni2' 或 'ni4'
|
|
||||||
"""
|
|
||||||
base = py[:-1] # 去掉声调
|
|
||||||
tone = int(py[-1]) # 获取声调
|
|
||||||
possible_tones = [1, 2, 3, 4]
|
|
||||||
possible_tones.remove(tone) # 移除原声调
|
|
||||||
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
|
||||||
return base + str(new_tone)
|
|
||||||
|
|
||||||
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
|
|
||||||
"""
|
|
||||||
根据频率差计算替换概率
|
|
||||||
频率差越大,概率越低
|
|
||||||
:param orig_freq: 原字频率
|
|
||||||
:param target_freq: 目标字频率
|
|
||||||
:param max_freq_diff: 最大允许的频率差
|
|
||||||
:return: 0-1之间的概率值
|
|
||||||
"""
|
|
||||||
if target_freq > orig_freq:
|
|
||||||
return 1.0 # 如果替换字频率更高,保持原有概率
|
|
||||||
|
|
||||||
freq_diff = orig_freq - target_freq
|
|
||||||
if freq_diff > max_freq_diff:
|
|
||||||
return 0.0 # 频率差太大,不替换
|
|
||||||
|
|
||||||
# 使用指数衰减函数计算概率
|
|
||||||
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
|
||||||
return math.exp(-3 * freq_diff / max_freq_diff)
|
|
||||||
|
|
||||||
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
|
|
||||||
"""
|
|
||||||
获取与给定字频率相近的同音字,可能包含声调错误
|
|
||||||
"""
|
|
||||||
homophones = []
|
|
||||||
|
|
||||||
# 有20%的概率使用错误声调
|
|
||||||
if random.random() < tone_error_rate:
|
|
||||||
wrong_tone_py = get_similar_tone_pinyin(py)
|
|
||||||
homophones.extend(pinyin_dict[wrong_tone_py])
|
|
||||||
|
|
||||||
# 添加正确声调的同音字
|
|
||||||
homophones.extend(pinyin_dict[py])
|
|
||||||
|
|
||||||
if not homophones:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取原字的频率
|
|
||||||
orig_freq = char_frequency.get(char, 0)
|
|
||||||
|
|
||||||
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
|
||||||
freq_diff = [(h, char_frequency.get(h, 0))
|
|
||||||
for h in homophones
|
|
||||||
if h != char and char_frequency.get(h, 0) >= min_freq]
|
|
||||||
|
|
||||||
if not freq_diff:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 计算每个候选字的替换概率
|
|
||||||
candidates_with_prob = []
|
|
||||||
for h, freq in freq_diff:
|
|
||||||
prob = calculate_replacement_probability(orig_freq, freq)
|
|
||||||
if prob > 0: # 只保留有效概率的候选字
|
|
||||||
candidates_with_prob.append((h, prob))
|
|
||||||
|
|
||||||
if not candidates_with_prob:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 根据概率排序
|
|
||||||
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
# 返回概率最高的几个字
|
|
||||||
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
|
||||||
|
|
||||||
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2):
|
|
||||||
"""
|
|
||||||
创建包含同音字错误的句子,保留原文标点符号
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
typo_info = []
|
|
||||||
|
|
||||||
# 获取每个字的拼音
|
|
||||||
chars_with_pinyin = get_pinyin(sentence)
|
|
||||||
|
|
||||||
# 创建原字到拼音的映射,用于跟踪已处理的字符
|
|
||||||
processed_chars = {char: py for char, py in chars_with_pinyin}
|
|
||||||
|
|
||||||
# 遍历原句中的每个字符
|
|
||||||
char_index = 0
|
|
||||||
for i, char in enumerate(sentence):
|
|
||||||
if char.isspace():
|
|
||||||
# 保留空格
|
|
||||||
result.append(char)
|
|
||||||
elif char in processed_chars:
|
|
||||||
# 处理汉字
|
|
||||||
py = processed_chars[char]
|
|
||||||
# 基础错误率
|
|
||||||
if random.random() < error_rate:
|
|
||||||
# 获取频率相近的同音字(可能包含声调错误)
|
|
||||||
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
|
|
||||||
min_freq=min_freq, tone_error_rate=tone_error_rate)
|
|
||||||
if similar_chars:
|
|
||||||
# 随机选择一个替换字
|
|
||||||
typo_char = random.choice(similar_chars)
|
|
||||||
# 获取替换字的频率
|
|
||||||
typo_freq = char_frequency.get(typo_char, 0)
|
|
||||||
orig_freq = char_frequency.get(char, 0)
|
|
||||||
|
|
||||||
# 计算实际替换概率
|
|
||||||
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
|
|
||||||
|
|
||||||
# 根据频率差进行概率替换
|
|
||||||
if random.random() < replace_prob:
|
|
||||||
result.append(typo_char)
|
|
||||||
# 获取替换字的实际拼音
|
|
||||||
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
|
||||||
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
|
||||||
else:
|
|
||||||
result.append(char)
|
|
||||||
else:
|
|
||||||
result.append(char)
|
|
||||||
else:
|
|
||||||
result.append(char)
|
|
||||||
char_index += 1
|
|
||||||
else:
|
|
||||||
# 保留非汉字字符(标点符号等)
|
|
||||||
result.append(char)
|
|
||||||
|
|
||||||
return ''.join(result), typo_info
|
|
||||||
|
|
||||||
def format_frequency(freq):
|
|
||||||
"""
|
|
||||||
格式化频率显示
|
|
||||||
"""
|
|
||||||
return f"{freq:.2f}"
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 首先创建拼音字典和加载字频统计
|
|
||||||
print("正在加载汉字数据库,请稍候...")
|
|
||||||
pinyin_dict = create_pinyin_dict()
|
|
||||||
char_frequency = load_or_create_char_frequency()
|
|
||||||
|
|
||||||
# 获取用户输入
|
|
||||||
sentence = input("请输入中文句子:")
|
|
||||||
|
|
||||||
# 创建包含错别字的句子
|
|
||||||
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
|
|
||||||
min_freq=5, tone_error_rate=0.2)
|
|
||||||
|
|
||||||
# 打印结果
|
|
||||||
print("\n原句:", sentence)
|
|
||||||
print("错字版:", typo_sentence)
|
|
||||||
|
|
||||||
if typo_info:
|
|
||||||
print("\n错别字信息:")
|
|
||||||
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
|
||||||
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
|
||||||
error_type = "声调错误" if tone_error else "同音字替换"
|
|
||||||
print(f"原字:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
|
|
||||||
f"错字:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
|
|
||||||
|
|
||||||
# 获取拼音结果
|
|
||||||
result = get_pinyin(sentence)
|
|
||||||
|
|
||||||
# 打印完整拼音
|
|
||||||
print("\n完整拼音:")
|
|
||||||
print(" ".join(py for _, py in result))
|
|
||||||
|
|
||||||
# 打印所有可能的同音字
|
|
||||||
print("\n每个字的所有同音字(按频率排序,仅显示频率>=5的字):")
|
|
||||||
for char, py in result:
|
|
||||||
homophones = get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5)
|
|
||||||
char_freq = char_frequency.get(char, 0)
|
|
||||||
print(f"{char}: {py} [频率:{format_frequency(char_freq)}]")
|
|
||||||
if homophones:
|
|
||||||
homophone_info = []
|
|
||||||
for h in homophones:
|
|
||||||
h_freq = char_frequency.get(h, 0)
|
|
||||||
homophone_info.append(f"{h}[{format_frequency(h_freq)}]")
|
|
||||||
print(f"同音字: {','.join(homophone_info)}")
|
|
||||||
else:
|
|
||||||
print("没有找到频率>=5的同音字")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# 插件配置
|
# 插件配置
|
||||||
PLUGINS=["src2.plugins.chat"]
|
PLUGINS=["src2.plugins.chat"]
|
||||||
|
|
||||||
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值
|
|||||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
#key and url
|
#key and url
|
||||||
|
|
||||||
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
#定义你要用的api的base_url
|
||||||
DEEP_SEEK_KEY=
|
DEEP_SEEK_KEY=
|
||||||
CHAT_ANY_WHERE_KEY=
|
CHAT_ANY_WHERE_KEY=
|
||||||
SILICONFLOW_KEY=
|
SILICONFLOW_KEY=
|
||||||
Reference in New Issue
Block a user