Initialize
4
.github/prompts/chat.prompt.md
vendored
Normal file
@@ -0,0 +1,4 @@
---
mode: agent
---
Remember to activate the virtual environment before running anything; the shell is PowerShell, whose syntax differs from Linux shells.
@@ -1,121 +0,0 @@
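# Contributor Covenant Code of Conduct

## Our Pledge

As members, contributors, and maintainers, we pledge to provide a friendly, safe, and welcoming environment for everyone, regardless of age, body size, physical or mental disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual orientation.

We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility, apologizing to those affected by our mistakes, and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall community
* Using kind and inclusive language
* Discussing technical issues professionally and avoiding personal attacks

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
* Deliberately spreading misinformation or misleading content
* Maliciously disrupting project resources or community discussions
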
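## Enforcement Responsibilities

Community maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

Community maintainers have the right to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned with this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, including but not limited to:

* The GitHub repository and related discussion areas
* Issue and Pull Request discussions
* Project-related online forums, chat rooms, and social media
* Official project events and meetings
* Any other setting in which someone represents the project or community

This Code of Conduct also applies when an individual is representing the project or its community in public spaces. Examples of representation include using an official email address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
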
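## Guidelines Specific to the MaiBot Project

### Technical Discussion Principles
* Keep technical discussions professional and constructive
* Before asking a question, check the existing documentation and existing issues
* Provide clear, detailed bug reports and feature requests
* Respect differing technical choices and implementation approaches

### AI/LLM Content Guidelines
* Discuss AI technology responsibly and ethically
* Do not share or discuss AI applications that could cause harm
* Respect data privacy and user rights
* Comply with applicable laws, regulations, and platform policies

### Multilingual Support
* Chinese is the primary language of communication, but contributors using other languages are welcome
* Be patient and kind with non-native Chinese speakers
* Offer translation help when needed
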
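## Reporting

If you experience or witness behavior that violates the Code of Conduct, please report it through one of the following channels:

1. **GitHub Issues**: for public violations, you can point them out directly in the relevant issue
2. **Private contact**: you can contact the project maintainers via GitHub private message
3. **Email**: [If the project has an email address, provide it here]

All reports will be handled promptly and fairly. We are committed to protecting the privacy and safety of reporters.
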
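## Enforcement Guidelines

Community maintainers will follow these Community Impact Guidelines in determining the consequences for any violation of this Code of Conduct:

### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community maintainers, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

### 2. Warning
**Community Impact**: A violation through a single incident or series of actions.

**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the community.
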
## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

## Contact

If you have any questions or suggestions about this Code of Conduct, please contact us by:

* Opening an issue on GitHub for discussion
* Contacting the project maintainers

---

**Thank you for helping us build a friendly, inclusive open source community!**

*Last updated: June 21, 2025*
21
MaiBot-Napcat-Adapter-dev/.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,21 @@
{
  "name": "MaiBot-Napcat-Adapter-DevContainer",
  "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye",
  "features": {
    "ghcr.io/rocker-org/devcontainer-features/apt-packages:1": {
      "packages": [
        "tmux"
      ]
    },
    "ghcr.io/devcontainers/features/github-cli:1": {}
  },
  "forwardPorts": [
    "8095:8095"
  ],
  "postCreateCommand": "pip3 install --user -r requirements.txt",
  "customizations" : {
    "jetbrains" : {
      "backend" : "PyCharm"
    }
  }
}
54
MaiBot-Napcat-Adapter-dev/.github/workflows/docker-image.yml
vendored
Normal file
@@ -0,0 +1,54 @@
name: Docker Image CI

on:
  push:
    branches: [ "main", "dev" ]
  workflow_dispatch: # allow manually triggering the workflow

jobs:

  build:

    runs-on: ubuntu-latest
    env:
      DOCKERHUB_USER: ${{ secrets.DOCKERHUB_USERNAME }}

    steps:
    - name: Checkout code
      uses: actions/checkout@v4

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKERHUB_USERNAME }}
        password: ${{ secrets.DOCKERHUB_TOKEN }}

    - name: Clone maim_message
      run: git clone https://github.com/MaiM-with-u/maim_message maim_message

    - name: Determine Image Tags
      id: tags
      run: |
        echo "date_tag=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
        if [ "${{ github.ref_name }}" == "main" ]; then
          echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:latest,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:main-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT
        elif [ "${{ github.ref_name }}" == "dev" ]; then
          echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT
        fi

    - name: Build and Push Docker Image
      uses: docker/build-push-action@v5
      with:
        context: .
        file: ./Dockerfile
        platforms: linux/amd64,linux/arm64
        tags: ${{ steps.tags.outputs.tags }}
        push: true
        cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }}
        cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }},mode=max
        labels: |
          org.opencontainers.image.created=${{ steps.tags.outputs.date_tag }}
          org.opencontainers.image.revision=${{ github.sha }}
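The tag selection in the `Determine Image Tags` step can be tried locally before pushing. The following is a minimal sketch, not part of the commit: `compute_tags` and its three arguments are hypothetical names introduced here for illustration. It assumes an unknown branch should fail, whereas the workflow simply leaves `tags` unset in that case.

```shell
#!/bin/sh
# Hypothetical helper mirroring the branch-to-tag mapping of the workflow step.
# Usage: compute_tags <branch> <dockerhub_user> <timestamp>
compute_tags() {
  branch="$1"
  repo="$2/maimbot-adapter"
  stamp="$3"
  case "$branch" in
    main) printf '%s\n' "${repo}:latest,${repo}:main-${stamp}" ;;
    dev)  printf '%s\n' "${repo}:dev,${repo}:dev-${stamp}" ;;
    *)    return 1 ;;  # assumption: reject branches the workflow does not build
  esac
}

# Same timestamp format as the workflow (UTC, YYYYmmddHHMMSS); "exampleuser" is a placeholder.
compute_tags dev exampleuser "$(date -u +'%Y%m%d%H%M%S')"
```

The comma-separated result has the same shape as the string the workflow passes to the `tags` input of `docker/build-push-action`.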
277
MaiBot-Napcat-Adapter-dev/.gitignore
vendored
Normal file
@@ -0,0 +1,277 @@
log/
logs/
out/

.env
.env.*
.cursor

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
llm_statistics.txt
mongodb
napcat
run_dev.bat
elua.confirmed
# C extensions
*.so
/results
config_backup/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc

# jieba
jieba.cache

# .vscode
!.vscode/settings.json

# direnv
/.direnv

# JetBrains
.idea
*.iml
*.ipr

# PyEnv
# If using PyEnv and configured to use a specific Python version locally
# a .local-version file will be created in the root of the project to specify the version.
.python-version

OtherRes.txt

/eula.confirmed
/privacy.confirmed

logs

.ruff_cache

.vscode

/config/*
config/old/bot_config_20250405_212257.toml
temp/

# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

config.toml
config.toml.back
test
data/NapcatAdapter.db
data/NapcatAdapter.db-shm
data/NapcatAdapter.db-wal
20
MaiBot-Napcat-Adapter-dev/Dockerfile
Normal file
@@ -0,0 +1,20 @@
FROM python:3.13.5-slim
LABEL authors="infinitycat233"

# Copy uv and maim_message
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
COPY maim_message /maim_message
COPY requirements.txt /requirements.txt

# Install requirements
RUN uv pip install --system --upgrade pip
RUN uv pip install --system -e /maim_message
RUN uv pip install --system -r /requirements.txt

WORKDIR /adapters

COPY . .

EXPOSE 8095

ENTRYPOINT ["python", "main.py"]
674
MaiBot-Napcat-Adapter-dev/LICENSE
Normal file
@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The GNU General Public License is a free, copyleft license for
software and other kinds of works.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.

Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and
modification follow.

TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.

b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".

c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.

d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.

b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.

c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.

d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.

e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
83
MaiBot-Napcat-Adapter-dev/README.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Adapter between MaiBot and Napcat
|
||||
How to run: standalone, or placed inside MaiBot itself as a plugin
|
||||
|
||||
# Usage
|
||||
See the [official documentation](https://docs.mai-mai.org/manual/adapters/napcat.html).
|
||||
|
||||
# Message flow
|
||||
|
||||
```mermaid
sequenceDiagram
    participant Napcat as Napcat client
    participant Adapter as MaiBot-Napcat adapter
    participant Queue as Message queue
    participant Handler as Message handler
    participant MaiBot as MaiBot service

    Note over Napcat,MaiBot: Initialization
    Napcat->>Adapter: WebSocket connection (ws://localhost:8095)
    Adapter->>MaiBot: WebSocket connection (ws://localhost:8000)

    Note over Napcat,MaiBot: Heartbeat check
    loop Every 30 seconds
        Napcat->>Adapter: Send heartbeat packet
        Adapter->>Napcat: Heartbeat response
    end

    Note over Napcat,MaiBot: Message handling
    Napcat->>Adapter: Send message
    Adapter->>Queue: Enqueue message (message_queue)
    Queue->>Handler: Dequeue message for handling
    Handler->>Handler: Parse message type
    alt Text message
        Handler->>MaiBot: Send text message
    else Image message
        Handler->>MaiBot: Send image message
    else Mixed message
        Handler->>MaiBot: Send mixed message
    else Forwarded message
        Handler->>MaiBot: Send forwarded message
    end
    MaiBot-->>Adapter: Message response
    Adapter-->>Napcat: Message response

    Note over Napcat,MaiBot: Graceful shutdown
    Adapter->>MaiBot: Close connection
    Adapter->>Queue: Clear message queue
    Adapter->>Napcat: Close connection
```
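The queue-based part of the flow above can be sketched in a few lines of Python. This is a minimal illustration, not the adapter's actual code: the three handler functions, the `HANDLERS` table, and the `None` shutdown sentinel are hypothetical stand-ins, and only the `post_type` values (`message`, `meta_event`, `notice`) come from the project.

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple

# Records (kind, event) pairs so the dispatch order can be inspected.
handled: List[Tuple[str, Dict[str, Any]]] = []


# Hypothetical stand-ins for the real handlers; they only record what they receive.
async def handle_message(event: Dict[str, Any]) -> None:
    handled.append(("message", event))


async def handle_meta_event(event: Dict[str, Any]) -> None:
    handled.append(("meta_event", event))


async def handle_notice(event: Dict[str, Any]) -> None:
    handled.append(("notice", event))


HANDLERS = {
    "message": handle_message,
    "meta_event": handle_meta_event,
    "notice": handle_notice,
}


async def consume(queue: "asyncio.Queue[Optional[Dict[str, Any]]]") -> None:
    """Dequeue events and dispatch on post_type until a None sentinel arrives."""
    while True:
        event = await queue.get()
        try:
            if event is None:  # assumed shutdown signal, not part of the protocol
                return
            handler = HANDLERS.get(event.get("post_type"))
            if handler is not None:
                await handler(event)
        finally:
            queue.task_done()


async def demo() -> List[str]:
    handled.clear()
    queue: "asyncio.Queue[Optional[Dict[str, Any]]]" = asyncio.Queue()
    consumer = asyncio.create_task(consume(queue))
    for event in (
        {"post_type": "meta_event"},
        {"post_type": "message", "raw_message": "hi"},
        {"post_type": "notice"},
    ):
        await queue.put(event)
    await queue.put(None)
    await consumer
    return [kind for kind, _ in handled]
```

Calling `asyncio.run(demo())` returns the kinds in the order they were enqueued, which is the ordering guarantee a single consumer on one queue gives.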
|
||||
|
||||
|
||||
# TO DO List
|
||||
- [x] Read automatic heartbeats to test the connection
- [x] Parse received messages
- [x] Text parsing
- [x] Image parsing
- [x] Mixed text-and-message parsing
- [x] Forwarded-message parsing (including dynamic image parsing)
- [ ] Group announcement parsing
- [x] Reply parsing
- [ ] Group temporary messages (may not be implemented)
- [ ] Link parsing
- [x] Poke parsing
- [x] Read the custom content of a poke
- [ ] Voice parsing (?)
- [ ] All notice types
- [x] Recall (related command added)
- [x] Send messages
- [x] Send text
- [x] Send images
- [x] Send stickers
- [x] Quoted reply (done but untested)
- [ ] Poke back (?)
- [x] Send voice
- [x] Use echo and uuid to guarantee message order
- [x] Perform some administrator functions
- [x] Mute a user
- [x] Mute everyone
- [x] Kick a member from the group
|
||||
|
||||
# Special thanks
|
||||
Special thanks to [@Maple127667](https://github.com/Maple127667) for support with this project's code design.

Thanks also to [@墨梓柒](https://github.com/DrSmoothl) for support with some of the code ideas.
|
||||
60
MaiBot-Napcat-Adapter-dev/command_args.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Command Arguments
|
||||
```python
|
||||
Seg.type = "command"
|
||||
```
|
||||
## Group mute
```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_BAN",
    "args": {
        "qq_id": "user's QQ number",
        "duration": "mute duration (seconds)"
    },
}
```
The group ID is obtained automatically from Group_Info.group_id.

**A `duration` of 0 is equivalent to lifting the mute.**
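As a rough illustration of the `GROUP_BAN` payload, the sketch below builds the dictionary and checks the "0 means unmute" rule. The helper names `build_group_ban` and `is_unmute` are hypothetical, and passing `duration` as an integer is an assumption: the document does not state whether the adapter expects strings or numbers.

```python
from typing import Any, Dict


def build_group_ban(qq_id: str, duration: int) -> Dict[str, Any]:
    """Build a GROUP_BAN command payload (hypothetical helper, not part of the adapter)."""
    if duration < 0:
        raise ValueError("duration must be >= 0 seconds")
    return {"name": "GROUP_BAN", "args": {"qq_id": qq_id, "duration": duration}}


def is_unmute(command: Dict[str, Any]) -> bool:
    """A GROUP_BAN with duration 0 lifts the mute instead of applying one."""
    return command["name"] == "GROUP_BAN" and int(command["args"]["duration"]) == 0


mute = build_group_ban("123456", 600)   # mute for ten minutes
unmute = build_group_ban("123456", 0)   # lift the mute
```

The group ID is deliberately absent from the payload, since the document says it is taken from `Group_Info.group_id`.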
|
||||
## Group-wide mute
```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_WHOLE_BAN",
    "args": {
        "enable": "whether to turn on the group-wide mute (True/False)"
    },
}
```
The group ID is obtained automatically from Group_Info.group_id.

The `enable` argument must be a boolean: True turns the group-wide mute on, False turns it off.
|
||||
## Group kick
```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_KICK",
    "args": {
        "qq_id": "user's QQ number",
    },
}
```
The group ID is obtained automatically from Group_Info.group_id.
|
||||
|
||||
## Poke
```python
Seg.data: Dict[str, Any] = {
    "name": "SEND_POKE",
    "args": {
        "qq_id": "target QQ number"
    }
}
```
|
||||
|
||||
## Recall a message
```python
Seg.data: Dict[str, Any] = {
    "name": "DELETE_MSG",
    "args": {
        "message_id": "the message_id of the message"
    }
}
```
Here message_id is the message's actual qq_id; in newer versions of mmc it can be retrieved from the database (if that works properly).
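The five command payloads in this file share one shape, so a single builder can cover them. The sketch below is only an illustration: `REQUIRED_ARGS` and `build_command` are hypothetical names, and the required-argument table is copied from the sections above rather than from the adapter's code.

```python
from typing import Any, Dict, Tuple

# Required argument names per command, as listed in this document.
REQUIRED_ARGS: Dict[str, Tuple[str, ...]] = {
    "GROUP_BAN": ("qq_id", "duration"),
    "GROUP_WHOLE_BAN": ("enable",),
    "GROUP_KICK": ("qq_id",),
    "SEND_POKE": ("qq_id",),
    "DELETE_MSG": ("message_id",),
}


def build_command(name: str, **args: Any) -> Dict[str, Any]:
    """Build the data dict for a Seg of type "command" (hypothetical helper)."""
    if name not in REQUIRED_ARGS:
        raise ValueError(f"unknown command: {name}")
    missing = [key for key in REQUIRED_ARGS[name] if key not in args]
    if missing:
        raise ValueError(f"{name} is missing arguments: {missing}")
    # The document requires a real boolean here, not the strings "True"/"False".
    if name == "GROUP_WHOLE_BAN" and not isinstance(args["enable"], bool):
        raise TypeError("enable must be a bool")
    return {"name": name, "args": dict(args)}


recall = build_command("DELETE_MSG", message_id="42")
whole_ban = build_command("GROUP_WHOLE_BAN", enable=True)
```

Validating before sending keeps malformed commands from reaching the adapter, where a failure would be harder to trace.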
|
||||
90
MaiBot-Napcat-Adapter-dev/main.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import json
|
||||
import websockets as Server
|
||||
from src.logger import logger
|
||||
from src.recv_handler.message_handler import message_handler
|
||||
from src.recv_handler.meta_event_handler import meta_event_handler
|
||||
from src.recv_handler.notice_handler import notice_handler
|
||||
from src.recv_handler.message_sending import message_send_instance
|
||||
from src.send_handler import send_handler
|
||||
from src.config import global_config
|
||||
from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
|
||||
from src.response_pool import put_response, check_timeout_response
|
||||
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
|
||||
async def message_recv(server_connection: Server.ServerConnection):
    await message_handler.set_server_connection(server_connection)
    asyncio.create_task(notice_handler.set_server_connection(server_connection))
    await send_handler.set_server_connection(server_connection)
    async for raw_message in server_connection:
        logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
        decoded_raw_message: dict = json.loads(raw_message)
        post_type = decoded_raw_message.get("post_type")
        if post_type in ["meta_event", "message", "notice"]:
            await message_queue.put(decoded_raw_message)
        elif post_type is None:
            await put_response(decoded_raw_message)
|
||||
|
||||
|
||||
async def message_process():
    while True:
        message = await message_queue.get()
        post_type = message.get("post_type")
        if post_type == "message":
            await message_handler.handle_raw_message(message)
        elif post_type == "meta_event":
            await meta_event_handler.handle_meta_event(message)
        elif post_type == "notice":
            await notice_handler.handle_notice(message)
        else:
            logger.warning(f"Unknown post_type: {post_type}")
        message_queue.task_done()
        await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
async def main():
    message_send_instance.maibot_router = router
    _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response())
|
||||
|
||||
|
||||
async def napcat_server():
    logger.info("Starting adapter...")
    async with Server.serve(message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26) as server:
        logger.info(
            f"Adapter started, listening on: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
        )
        await server.serve_forever()
|
||||
|
||||
|
||||
async def graceful_shutdown():
    try:
        logger.info("Shutting down adapter...")
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15)
        await mmc_stop_com()  # Run last to avoid mysterious exceptions
        logger.info("Adapter shut down successfully")
    except Exception as e:
        logger.error(f"Error while shutting down adapter: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    exit_code = 0
    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        logger.warning("Interrupt received, shutting down gracefully...")
        loop.run_until_complete(graceful_shutdown())
    except Exception as e:
        logger.exception(f"Main program error: {str(e)}")
        exit_code = 1  # Calling sys.exit(0) inside finally would mask this failure
    finally:
        if loop and not loop.is_closed():
            loop.close()
    sys.exit(exit_code)
|
||||
44
MaiBot-Napcat-Adapter-dev/notify_args.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Notify Args
|
||||
```python
|
||||
Seg.type = "notify"
|
||||
```
|
||||
## Group member muted
```python
Seg.data: Dict[str, Any] = {
    "sub_type": "ban",
    "duration": "the mute duration, in seconds",
    "banned_user_info": "info on the muted user, as a dict converted from a standard UserInfo"
}
```
In this case `MessageBase.UserInfo`, i.e. the message's `UserInfo`, holds the operator's information.

**Note: you must call `UserInfo.from_dict()` yourself to convert `banned_user_info` into a standard UserInfo object.**
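A receiver of the `ban` notification might convert the payload as sketched below. The `UserInfo` dataclass here is a hypothetical stand-in so the example runs on its own; the real class comes from the project's message library, and the field names `user_id` and `user_nickname` are assumptions. `parse_ban` is also an illustrative name.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class UserInfo:
    """Hypothetical stand-in for the real UserInfo; field names are assumptions."""

    user_id: Optional[str] = None
    user_nickname: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UserInfo":
        return cls(user_id=data.get("user_id"), user_nickname=data.get("user_nickname"))


def parse_ban(data: Dict[str, Any]) -> Tuple[UserInfo, int]:
    """Return the muted user and the mute duration in seconds from a "ban" notify payload."""
    if data.get("sub_type") != "ban":
        raise ValueError("not a ban notification")
    # banned_user_info arrives as a plain dict and has to be converted explicitly.
    banned_user = UserInfo.from_dict(data["banned_user_info"])
    return banned_user, int(data["duration"])


example = {
    "sub_type": "ban",
    "duration": 600,
    "banned_user_info": {"user_id": "123456", "user_nickname": "example"},
}
banned_user, seconds = parse_ban(example)
```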
|
||||
## Group-wide mute turned on
```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_ban",
    "duration": -1,
    "banned_user_info": None
}
```
In this case `MessageBase.UserInfo`, i.e. the message's `UserInfo`, holds the operator's information.
|
||||
## Group member mute lifted
```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_lift_ban",
    "lifted_user_info": "info on the user whose mute was lifted, as a dict converted from a standard UserInfo"
}
```
**When the mute expires naturally, `MessageBase.UserInfo` is `None`.**

When the mute is lifted manually, `MessageBase.UserInfo`, i.e. the message's `UserInfo`, holds the operator's information.

**Note: you must call `UserInfo.from_dict()` yourself to convert `lifted_user_info` into a standard UserInfo object.**
|
||||
## Group-wide mute turned off
```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_lift_ban",
    "lifted_user_info": None,
}
```
In this case `MessageBase.UserInfo`, i.e. the message's `UserInfo`, holds the operator's information.
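The lift notifications described in this file can be told apart by two fields: whether `lifted_user_info` is `None`, and whether the message carries an operator. The sketch below encodes only those two rules as stated here. `describe_lift` is a hypothetical helper, `operator` stands for the message's `UserInfo` passed as a plain dict, and the `user_id` key is an assumption.

```python
from typing import Any, Dict, Optional


def describe_lift(data: Dict[str, Any], operator: Optional[Dict[str, Any]]) -> str:
    """Classify a lift notification using only the rules stated in this document."""
    if data.get("lifted_user_info") is None:
        # No specific user attached: the group-wide mute was turned off.
        return "group-wide mute turned off"
    if operator is None:
        # A user is attached but nobody acted: the mute ran out on its own.
        return "member mute expired naturally"
    return "member mute lifted by an operator"


whole = describe_lift({"sub_type": "whole_lift_ban", "lifted_user_info": None}, {"user_id": "1"})
expired = describe_lift({"sub_type": "whole_lift_ban", "lifted_user_info": {"user_id": "2"}}, None)
manual = describe_lift({"sub_type": "whole_lift_ban", "lifted_user_info": {"user_id": "2"}}, {"user_id": "1"})
```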
|
||||
44
MaiBot-Napcat-Adapter-dev/pyproject.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[project]
|
||||
name = "MaiBotNapcatAdapter"
|
||||
version = "0.4.7"
|
||||
description = "A MaiBot adapter for Napcat"
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
include = ["*.py"]
|
||||
|
||||
# Line length setting
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
|
||||
# Enabled rules
|
||||
select = [
    "E", # pycodestyle errors
    "F", # pyflakes
    "B", # flake8-bugbear
]
|
||||
|
||||
ignore = ["E711","E501"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
indent-style = "space"
|
||||
|
||||
|
||||
# Use double quotes for strings
|
||||
quote-style = "double"
|
||||
|
||||
# Respect magic trailing commas
# For example:
# items = [
#     "apple",
#     "banana",
#     "cherry",
# ]
|
||||
skip-magic-trailing-comma = false
|
||||
|
||||
# Automatically detect the appropriate line ending
|
||||
line-ending = "auto"
|
||||
10
MaiBot-Napcat-Adapter-dev/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
websockets
|
||||
aiohttp
|
||||
asyncio
|
||||
requests
|
||||
maim_message
|
||||
loguru
|
||||
pillow
|
||||
tomlkit
|
||||
rich
|
||||
sqlmodel
|
||||
24
MaiBot-Napcat-Adapter-dev/src/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from enum import Enum
|
||||
import tomlkit
|
||||
import os
|
||||
from .logger import logger
|
||||
|
||||
|
||||
class CommandType(Enum):
    """Command types"""

    GROUP_BAN = "set_group_ban"  # Mute a user
    GROUP_WHOLE_BAN = "set_group_whole_ban"  # Mute the whole group
    GROUP_KICK = "set_group_kick"  # Kick from the group
    SEND_POKE = "send_poke"  # Poke
    DELETE_MSG = "delete_msg"  # Recall a message
    AI_VOICE_SEND = "send_group_ai_record"  # Send a group AI voice message

    def __str__(self) -> str:
        return self.value
|
||||
|
||||
|
||||
pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml")
|
||||
# Use a context manager so the file handle is closed after reading
with open(pyproject_path, "r", encoding="utf-8") as f:
    toml_data = tomlkit.parse(f.read())
version = toml_data["project"]["version"]
logger.info(f"Version\n\nMaiBot-Napcat-Adapter version: {version}\nIf you like it, give it a star, meow~\n")
|
||||
5
MaiBot-Napcat-Adapter-dev/src/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .config import global_config
|
||||
|
||||
__all__ = [
|
||||
"global_config",
|
||||
]
|
||||
146
MaiBot-Napcat-Adapter-dev/src/config/config.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from ..logger import logger
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
ChatConfig,
|
||||
DebugConfig,
|
||||
MaiBotServerConfig,
|
||||
NapcatServerConfig,
|
||||
NicknameConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TEMPLATE_DIR = "template"
|
||||
|
||||
|
||||
def update_config():
    # File paths
    template_path = f"{TEMPLATE_DIR}/template_config.toml"
    old_config_path = "config.toml"
    new_config_path = "config.toml"

    # Create the config from the template if it does not exist yet
    if not os.path.exists(old_config_path):
        logger.info("Config file not found, creating a new one from the template")
        shutil.copy2(template_path, old_config_path)  # copy the template
        logger.info(f"Created a new config file, fill it in and run again: {old_config_path}")
        # A freshly created config still has to be filled in, so exit here
        quit()

    # Load the existing config and the template
    with open(old_config_path, "r", encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(template_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)

    # Skip the update when both files carry the same version
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")
        new_version = new_config["inner"].get("version")
        if old_version and new_version and old_version == new_version:
            logger.info(f"Config versions match (v{old_version}), skipping update")
            return
        else:
            logger.info(f"Config versions differ: old v{old_version} -> new v{new_version}")
    else:
        logger.info("No version found in the existing config, it may be an old one; updating")

    # Create the backup directory
    backup_dir = "config_backup"
    os.makedirs(backup_dir, exist_ok=True)

    # Timestamped backup file name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = os.path.join(backup_dir, f"config.toml.bak.{timestamp}")

    # Back up the old config file
    shutil.copy2(old_config_path, old_backup_path)
    logger.info(f"Backed up the old config file to: {old_backup_path}")

    # Overwrite the config with a fresh copy of the template
    shutil.copy2(template_path, new_config_path)
    logger.info(f"Created new config file: {new_config_path}")

    def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
        """
        Copy values from source into target for keys that already exist in target.
        """
        for key, value in source.items():
            # Never carry over the version field
            if key == "version":
                continue
            if key in target:
                if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                    update_dict(target[key], value)
                else:
                    try:
                        # Arrays need special handling
                        if isinstance(value, list):
                            # Keep an empty array as an empty array
                            target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                        else:
                            # Wrap other values with tomlkit.item()
                            target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        # Fall back to plain assignment if the conversion fails
                        target[key] = value

    # Carry the old config's values over into the new config
    logger.info("Merging old and new config...")
    update_dict(new_config, old_config)

    # Save the merged config (comments and formatting are preserved)
    with open(new_config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    logger.info("Config update complete; review the new config file to make sure nothing important was lost")
    quit()
|
||||
|
||||
|
||||
@dataclass
class Config(ConfigBase):
    """Top-level configuration."""

    nickname: NicknameConfig
    napcat_server: NapcatServerConfig
    maibot_server: MaiBotServerConfig
    chat: ChatConfig
    voice: VoiceConfig
    debug: DebugConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
    """
    Load the configuration file.
    :param config_path: path to the configuration file
    :return: Config object
    """
    # Read the configuration file
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = tomlkit.load(f)

    # Build the Config object
    try:
        return Config.from_dict(config_data)
    except Exception:
        logger.critical("Failed to parse the configuration file")
        raise  # re-raise with the original traceback
|
||||
|
||||
|
||||
# Bring config.toml up to date before loading it
update_config()

logger.info("Savoring the config file...")
global_config = load_config(config_path="config.toml")
logger.info("Very fresh, very tasty!")
|
||||
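The merge performed by `update_dict` can be sketched with plain dictionaries. This is a simplified illustration, not the adapter's code: it assumes ordinary `dict` objects instead of tomlkit documents, and `merge_old_into_new` and the sample values are hypothetical.

```python
from typing import Any


def merge_old_into_new(new: dict[str, Any], old: dict[str, Any]) -> None:
    """Copy values from `old` into `new` for keys that `new` already has.

    Keys that exist only in `old` are dropped, and "version" is never
    overwritten, so the template's version number wins.
    """
    for key, value in old.items():
        if key == "version" or key not in new:
            continue
        if isinstance(value, dict) and isinstance(new[key], dict):
            merge_old_into_new(new[key], value)  # recurse into sub-tables
        else:
            new[key] = value


# Hypothetical template (new) and user config (old)
template = {
    "inner": {"version": "0.2.0"},
    "napcat_server": {"host": "localhost", "port": 8095, "heartbeat_interval": 30},
}
user_config = {
    "inner": {"version": "0.1.0"},
    "napcat_server": {"host": "192.168.1.5", "port": 9000},
    "removed_section": {"flag": True},
}

merge_old_into_new(template, user_config)
print(template)
```

The user's `host` and `port` survive, the new `heartbeat_interval` keeps its template default, and the section that no longer exists in the template is discarded, which is why a backup of the old file is worth keeping.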
136
MaiBot-Napcat-Adapter-dev/src/config/config_base.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
    int,
    float,
    str,
    bool,
    list,
    dict,
}
|
||||
|
||||
|
||||
@dataclass
class ConfigBase:
    """Base class for configuration dataclasses."""

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        """Build an instance from a dictionary."""
        if not isinstance(data, dict):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")

        init_args: Dict[str, Any] = {}

        for f in fields(cls):
            field_name = f.name
            field_type = f.type
            if field_name.startswith("_"):
                # Skip fields whose name starts with an underscore
                continue

            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # Skip missing fields that have a default or a default factory
                    continue
                else:
                    raise ValueError(f"Missing required field: '{field_name}'")

            value = data[field_name]
            try:
                init_args[field_name] = cls._convert_field(value, field_type)
            except TypeError as e:
                raise TypeError(f"Type error in field '{field_name}': {e}") from e
            except Exception as e:
                raise RuntimeError(f"Cannot convert field '{field_name}' to its target type: {e}") from e

        return cls(**init_args)
|
||||
|
||||
    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """
        Convert a field value to the given type.

        1. For a nested dataclass, recursively call its from_dict method
        2. For generic collection types (list, set, tuple), recursively convert each element
        3. For basic types (int, str, float, bool), convert directly
        4. For any other type, try a direct conversion and raise if it fails
        """
        # Nested dataclass: recurse through from_dict
        if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
            return field_type.from_dict(value)

        # Any accepts every value; handle it before the isinstance check below,
        # because isinstance(value, Any) raises TypeError
        if field_type is Any:
            return value

        field_origin_type = get_origin(field_type)
        field_args_type = get_args(field_type)

        # Generic collection types (list, set, tuple)
        if field_origin_type in {list, set, tuple}:
            # The supplied value must be a list
            if not isinstance(value, list):
                raise TypeError(f"Expected a list for {field_type.__name__}, got {type(value).__name__}")

            if field_origin_type is list:
                return [cls._convert_field(item, field_args_type[0]) for item in value]
            if field_origin_type is set:
                return {cls._convert_field(item, field_args_type[0]) for item in value}
            if field_origin_type is tuple:
                # The value must have as many items as there are type arguments
                if len(value) != len(field_args_type):
                    raise TypeError(
                        f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))

        if field_origin_type is dict:
            # The supplied value must be a dict
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")

            # The annotation must carry both a key type and a value type
            if len(field_args_type) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
            key_type, value_type = field_args_type

            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # Optional types
        if field_origin_type is Union:  # assert get_origin(Optional[Any]) is Union
            if value is None:
                return None
            # A value is present: check its actual type
            if type(value) not in field_args_type:
                raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
            return cls._convert_field(value, field_args_type[0])

        # Basic types such as int, str, float, bool
        if field_origin_type is None:
            if isinstance(value, field_type):
                return field_type(value)
            else:
                raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")

        # Literal types
        if field_origin_type is Literal:
            # Values allowed by the Literal
            allowed_values = get_args(field_type)
            if value in allowed_values:
                return value
            else:
                raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")

        # Anything else: try a direct conversion
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
    def __str__(self):
        """Return a string representation of the configuration class."""
        return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
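The core idea behind `from_dict` (walk the dataclass fields, recurse into nested dataclasses, and let defaults fill the gaps) can be shown in a much smaller form. The sketch below is an illustration under stated assumptions, not the `ConfigBase` implementation: `load_dataclass`, `Server`, and `AppConfig` are hypothetical names, and generic, `Optional`, and `Literal` annotations are deliberately left out.

```python
from dataclasses import MISSING, dataclass, fields, is_dataclass
from typing import Any


def load_dataclass(cls: type, data: dict[str, Any]) -> Any:
    """Build `cls` from `data`, recursing into nested dataclass fields."""
    kwargs: dict[str, Any] = {}
    for f in fields(cls):
        if f.name not in data:
            if f.default is MISSING and f.default_factory is MISSING:
                raise ValueError(f"Missing required field: '{f.name}'")
            continue  # let the dataclass apply its own default
        value = data[f.name]
        if is_dataclass(f.type):
            # Nested section: build it from the nested dictionary
            kwargs[f.name] = load_dataclass(f.type, value)
        elif isinstance(f.type, type) and not isinstance(value, f.type):
            raise TypeError(f"Field '{f.name}': expected {f.type.__name__}, got {type(value).__name__}")
        else:
            kwargs[f.name] = value
    return cls(**kwargs)


@dataclass
class Server:
    host: str = "localhost"
    port: int = 8095


@dataclass
class AppConfig:
    name: str
    server: Server


cfg = load_dataclass(AppConfig, {"name": "demo", "server": {"port": 9000}})
print(cfg)
```

Here `host` is absent from the input, so the dataclass default is used, while a missing `name` would raise `ValueError` because it has no default.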
80
MaiBot-Napcat-Adapter-dev/src/config/official_configs.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
"""
|
||||
Notes:
1. This file declares every configuration item
2. Every new class must inherit from ConfigBase
3. Every new class must be added as a field of the Config class in config.py
4. A new field that is optional should be assigned field() with default_factory or default set
|
||||
"""
|
||||
|
||||
ADAPTER_PLATFORM = "qq"
|
||||
|
||||
|
||||
@dataclass
class NicknameConfig(ConfigBase):
    nickname: str
    """Bot nickname"""
|
||||
|
||||
|
||||
@dataclass
class NapcatServerConfig(ConfigBase):
    host: str = "localhost"
    """Host address of the Napcat server"""

    port: int = 8095
    """Port of the Napcat server"""

    heartbeat_interval: int = 30
    """Napcat heartbeat interval, in seconds"""
|
||||
|
||||
|
||||
@dataclass
class MaiBotServerConfig(ConfigBase):
    platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
    """Platform name, "qq"."""

    host: str = "localhost"
    """Host address of MaiMCore"""

    port: int = 8000
    """Port of MaiMCore"""
|
||||
|
||||
|
||||
@dataclass
class ChatConfig(ConfigBase):
    group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
    """Group chat list type: whitelist or blacklist"""

    # default_factory must be a callable, so pass `list` rather than `[]`
    group_list: list[int] = field(default_factory=list)
    """Group chat list"""

    private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
    """Private chat list type: whitelist or blacklist"""

    private_list: list[int] = field(default_factory=list)
    """Private chat list"""

    ban_user_id: list[int] = field(default_factory=list)
    """IDs of banned users; a banned user cannot interact with the bot"""

    ban_qq_bot: bool = False
    """Whether to block official QQ bots; if True, no official QQ bot can interact with MaiMCore"""

    enable_poke: bool = True
    """Whether the poke feature is enabled"""
|
||||
|
||||
|
||||
@dataclass
class VoiceConfig(ConfigBase):
    use_tts: bool = False
    """Whether TTS is enabled"""
|
||||
|
||||
|
||||
@dataclass
class DebugConfig(ConfigBase):
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    """Log level, INFO by default"""
|
||||
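Note 4 in the module docstring asks for `field()` with `default_factory` on optional fields. `default_factory` expects a zero-argument callable that is invoked once per instance, which is what keeps list defaults from being shared. A minimal sketch, using a hypothetical `ChatLists` class rather than the real `ChatConfig`:

```python
from dataclasses import dataclass, field


@dataclass
class ChatLists:
    # `list` is called for every new instance, so each one gets its own list
    group_list: list[int] = field(default_factory=list)
    private_list: list[int] = field(default_factory=list)


a = ChatLists()
b = ChatLists()
a.group_list.append(123456)

print(a.group_list)  # [123456]
print(b.group_list)  # []
```

Passing a list object instead of the `list` callable would fail as soon as an instance is created without that argument, because the dataclass machinery tries to call the factory.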
162
MaiBot-Napcat-Adapter-dev/src/database.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from src.logger import logger
|
||||
|
||||
"""
|
||||
Table layout:
| group_id | user_id | lift_time |
|----------|---------|-----------|

user_id == 0 denotes a whole-group mute
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class BanUser:
    """
    In-memory instance used by the program.
    """

    user_id: int
    group_id: int
    # Plain default: sqlmodel's Field() is not meant for stdlib dataclasses
    lift_time: Optional[int] = -1
|
||||
|
||||
|
||||
class DB_BanUser(SQLModel, table=True):
    """
    A user's mute record in the database.
    Uses a composite primary key.
    """

    user_id: int = Field(index=True, primary_key=True)  # ID of the muted user
    group_id: int = Field(index=True, primary_key=True)  # ID of the group the user is muted in
    lift_time: Optional[int]  # time the mute is lifted (timestamp)
|
||||
|
||||
|
||||
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
    """
    Check whether two BanUser objects refer to the same user in the same group.
    lift_time is not compared.
    """
    return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
|
||||
|
||||
|
||||
class DatabaseManager:
    """
    Database manager responsible for all interaction with the database.
    """

    def __init__(self):
        os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True)  # make sure the data directory exists
        DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
        self.sqlite_url = f"sqlite:///{DATABASE_FILE}"  # SQLite database URL
        self.engine = create_engine(self.sqlite_url, echo=False)  # create the database engine
        self._ensure_database()  # make sure the database and tables exist

    def _ensure_database(self) -> None:
        """
        Make sure the database and its tables exist.
        """
        logger.info("Making sure the database file and tables exist...")
        SQLModel.metadata.create_all(self.engine)
        logger.success("Database and tables created or already present")
|
||||
|
||||
    def update_ban_record(self, ban_list: List[BanUser]) -> None:
        # sourcery skip: class-extract-method
        """
        Sync the mute list to the database.
        Creates records that do not exist yet and deletes records that are no longer in the list.
        """
        with Session(self.engine) as session:
            all_records = session.exec(select(DB_BanUser)).all()
            for ban_user in ban_list:
                statement = select(DB_BanUser).where(
                    DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
                )
                if existing_record := session.exec(statement).first():
                    if existing_record.lift_time == ban_user.lift_time:
                        logger.debug(f"Mute record unchanged: {existing_record}")
                        continue
                    # Update lift_time on the existing record
                    existing_record.lift_time = ban_user.lift_time
                    session.add(existing_record)
                    logger.debug(f"Updated mute record: {existing_record}")
                else:
                    # Create a new record
                    db_record = DB_BanUser(
                        user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
                    )
                    session.add(db_record)
                    logger.debug(f"Created new mute record: {ban_user}")
            # Delete records that are not in ban_list
            for db_record in all_records:
                record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
                if not any(is_identical(record, ban_user) for ban_user in ban_list):
                    statement = select(DB_BanUser).where(
                        DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
                    )
                    if ban_record := session.exec(statement).first():
                        # The commit after the loop persists the deletion
                        session.delete(ban_record)
                        logger.debug(f"Deleted mute record: {record}")
                    else:
                        # ban_record is None here, so log the lookup key instead
                        logger.info(f"Mute record not found: {record}")

            session.commit()
            logger.info("Mute records updated")
|
||||
|
||||
    def get_ban_records(self) -> List[BanUser]:
        """
        Read all mute records.
        """
        with Session(self.engine) as session:
            statement = select(DB_BanUser)
            records = session.exec(statement).all()
            return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
|
||||
|
||||
    def create_ban_record(self, ban_record: BanUser) -> None:
        """
        Create a mute record for a user in a specific group.
        A simplified insert path that avoids the complexity of update_ban_record.
        It also serves as a simplified update path.
        """
        with Session(self.engine) as session:
            # Check whether the record already exists
            statement = select(DB_BanUser).where(
                DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
            )
            existing_record = session.exec(statement).first()
            if existing_record:
                # The record exists: update lift_time
                existing_record.lift_time = ban_record.lift_time
                session.add(existing_record)
                logger.debug(f"Updated mute record: {ban_record}")
            else:
                # The record does not exist: create it
                db_record = DB_BanUser(
                    user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
                )
                session.add(db_record)
                logger.debug(f"Created new mute record: {ban_record}")
            session.commit()
|
||||
|
||||
    def delete_ban_record(self, ban_record: BanUser):
        """
        Delete the mute record of a specific user in a specific group.
        A simplified delete path that avoids the complexity of update_ban_record.
        """
        user_id = ban_record.user_id
        group_id = ban_record.group_id
        with Session(self.engine) as session:
            statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
            # Use a separate name so the parameter is not shadowed
            if db_record := session.exec(statement).first():
                session.delete(db_record)
                session.commit()
                logger.debug(f"Deleted mute record: user_id: {user_id}, group_id: {group_id}")
            else:
                logger.info(f"Mute record not found: user_id: {user_id}, group_id: {group_id}")
|
||||
|
||||
|
||||
db_manager = DatabaseManager()
|
||||
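The synchronisation rule in `update_ban_record` (insert or refresh every wanted record, then delete whatever is no longer wanted) does not depend on the database layer. The sketch below shows the same rule on an in-memory dictionary keyed by `(user_id, group_id)`; `Ban`, `sync_bans`, and the sample IDs are hypothetical and stand in for the SQLModel session.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Ban:
    user_id: int
    group_id: int
    lift_time: Optional[int] = -1


def sync_bans(stored: dict[tuple[int, int], Optional[int]], wanted: list[Ban]) -> None:
    """Make `stored` mirror `wanted`, keyed by (user_id, group_id)."""
    wanted_keys = {(b.user_id, b.group_id) for b in wanted}
    # Insert new records and refresh lift_time on existing ones
    for ban in wanted:
        stored[(ban.user_id, ban.group_id)] = ban.lift_time
    # Drop records that are no longer in the wanted list
    for key in list(stored):
        if key not in wanted_keys:
            del stored[key]


table = {(1001, 42): 1700000000, (1002, 42): 1700000500}
# user_id == 0 stands for a whole-group mute, as in the table description
sync_bans(table, [Ban(1001, 42, 1700009999), Ban(0, 42, -1)])
print(table)
```

Building the set of wanted keys once keeps the stale-record check cheap, whereas scanning the whole list for every stored record grows quadratically with the number of mutes.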
21
MaiBot-Napcat-Adapter-dev/src/logger.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from loguru import logger
|
||||
from .config import global_config
|
||||
import sys
|
||||
|
||||
# Default logger
logger.remove()
logger.add(
    sys.stderr,
    level=global_config.debug.level,
    format="<blue>{time:YYYY-MM-DD HH:mm:ss}</blue> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    filter=lambda record: "name" not in record["extra"] or record["extra"].get("name") != "maim_message",
)
logger.add(
    sys.stderr,
    level="INFO",
    format="<red>{time:YYYY-MM-DD HH:mm:ss}</red> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    filter=lambda record: record["extra"].get("name") == "maim_message",
)
# Create loggers with different styles
custom_logger = logger.bind(name="maim_message")
logger = logger.bind(name="MaiBot-Napcat-Adapter")
|
||||
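The two sinks above split records by the `name` value bound into `extra`. The routing rule itself can be checked without any logging library; the sketch below models a record as a plain dictionary with an `extra` mapping, which is an assumption made for illustration, and `sink_for` is a hypothetical helper.

```python
def is_maim_message_record(record: dict) -> bool:
    """True for records bound with name="maim_message"."""
    return record["extra"].get("name") == "maim_message"


def is_adapter_record(record: dict) -> bool:
    """True for every other record, including records with no bound name."""
    return record["extra"].get("name") != "maim_message"


def sink_for(record: dict) -> str:
    return "red sink" if is_maim_message_record(record) else "blue sink"


records = [
    {"message": "connected", "extra": {"name": "MaiBot-Napcat-Adapter"}},
    {"message": "route ready", "extra": {"name": "maim_message"}},
    {"message": "untagged", "extra": {}},
]

for record in records:
    # Exactly one of the two predicates holds, so no record is printed twice
    assert is_adapter_record(record) != is_maim_message_record(record)
    print(f"{sink_for(record)}: {record['message']}")
```

Because `dict.get` returns `None` for a missing key, the explicit `"name" not in record["extra"]` test is already covered by the inequality check.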
24
MaiBot-Napcat-Adapter-dev/src/mmc_com_layer.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from .config import global_config
|
||||
from .logger import logger, custom_logger
|
||||
from .send_handler import send_handler
|
||||
|
||||
route_config = RouteConfig(
    route_config={
        global_config.maibot_server.platform_name: TargetConfig(
            url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
            token=None,
        )
    }
)
router = Router(route_config, custom_logger)
|
||||
|
||||
|
||||
async def mmc_start_com():
    logger.info("Connecting to MaiBot")
    router.register_class_handler(send_handler.handle_message)
    await router.run()


async def mmc_stop_com():
    await router.stop()
|
||||
87
MaiBot-Napcat-Adapter-dev/src/recv_handler/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetaEventType:
    lifecycle = "lifecycle"  # lifecycle

    class Lifecycle:
        connect = "connect"  # lifecycle - WebSocket connected

    heartbeat = "heartbeat"  # heartbeat
|
||||
|
||||
|
||||
class MessageType:  # top-level categories of received messages
    private = "private"  # private message

    class Private:
        friend = "friend"  # private message - friend
        group = "group"  # private message - group temporary session
        group_self = "group_self"  # private message - sent by self in a group
        other = "other"  # private message - other

    group = "group"  # group message

    class Group:
        normal = "normal"  # group message - normal
        anonymous = "anonymous"  # group message - anonymous
        notice = "notice"  # group message - system notice
|
||||
|
||||
|
||||
class NoticeType:  # notice events
    friend_recall = "friend_recall"  # private message recalled
    group_recall = "group_recall"  # group message recalled
    notify = "notify"
    group_ban = "group_ban"  # group mute

    class Notify:
        poke = "poke"  # poke

    class GroupBan:
        ban = "ban"  # mute
        lift_ban = "lift_ban"  # unmute
|
||||
|
||||
|
||||
class RealMessageType:  # concrete message segment types
    text = "text"  # plain text
    face = "face"  # QQ emoji
    image = "image"  # image
    record = "record"  # voice
    video = "video"  # video
    at = "at"  # @ mention
    rps = "rps"  # rock-paper-scissors magic emoji
    dice = "dice"  # dice
    shake = "shake"  # private chat window shake (receive only)
    poke = "poke"  # group poke
    share = "share"  # link share (JSON form)
    reply = "reply"  # reply
    forward = "forward"  # forwarded message
    node = "node"  # forwarded message node
|
||||
|
||||
|
||||
class MessageSentType:
    private = "private"

    class Private:
        friend = "friend"
        group = "group"

    group = "group"

    class Group:
        normal = "normal"
|
||||
|
||||
|
||||
class CommandType(Enum):
    """Command types."""

    GROUP_BAN = "set_group_ban"  # mute a user
    GROUP_WHOLE_BAN = "set_group_whole_ban"  # mute the whole group
    GROUP_KICK = "set_group_kick"  # kick a member from the group
    SEND_POKE = "send_poke"  # poke
    DELETE_MSG = "delete_msg"  # recall a message

    def __str__(self) -> str:
        return self.value
|
||||
|
||||
|
||||
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]
|
||||
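`CommandType` overrides `__str__` so that a member can be dropped straight into a payload as its action string. A small sketch of that behaviour, using a hypothetical two-member `Command` enum rather than the adapter's class:

```python
from enum import Enum


class Command(Enum):
    GROUP_BAN = "set_group_ban"
    DELETE_MSG = "delete_msg"

    def __str__(self) -> str:
        # Without this override, str() would give "Command.GROUP_BAN"
        return self.value


print(str(Command.GROUP_BAN))    # set_group_ban
print(Command("delete_msg"))     # delete_msg (looked up by value, printed via __str__)
```

Lookup by value also works in the other direction, so an incoming action string can be turned back into a member with `Command("set_group_ban")`.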
669
MaiBot-Napcat-Adapter-dev/src/recv_handler/message_handler.py
Normal file
@@ -0,0 +1,669 @@
|
||||
from src.logger import logger
|
||||
from src.config import global_config
|
||||
from src.utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_image_base64,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
get_message_detail,
|
||||
)
|
||||
from .qq_emoji_list import qq_face
|
||||
from .message_sending import message_send_instance
|
||||
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||
|
||||
import time
|
||||
import json
|
||||
import websockets as Server
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import uuid
|
||||
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
TemplateInfo,
|
||||
FormatInfo,
|
||||
)
|
||||
|
||||
|
||||
from src.response_pool import get_response
|
||||
|
||||
|
||||
class MessageHandler:
    def __init__(self):
        self.server_connection: Server.ServerConnection = None
        self.bot_id_list: Dict[int, bool] = {}

    async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
        """Set the Napcat connection."""
        self.server_connection = server_connection

    async def check_allow_to_chat(
        self,
        user_id: int,
        group_id: Optional[int] = None,
        ignore_bot: Optional[bool] = False,
        ignore_global_list: Optional[bool] = False,
    ) -> bool:
        # sourcery skip: hoist-statement-from-if, merge-else-if-into-elif
        """
        Check whether chatting is allowed.
        Parameters:
            user_id: int: user ID
            group_id: int: group ID
            ignore_bot: bool: whether to skip the bot check
            ignore_global_list: bool: whether to skip the global blacklist check
        Returns:
            bool: whether chatting is allowed
        """
        logger.debug(f"Group id: {group_id}, user id: {user_id}")
        logger.debug("Checking the chat whitelist/blacklist")
        if group_id:
            if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list:
                logger.warning("Group is not in the chat whitelist, message dropped")
                return False
            elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list:
                logger.warning("Group is in the chat blacklist, message dropped")
                return False
        else:
            if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list:
                logger.warning("Private chat is not in the chat whitelist, message dropped")
                return False
            elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list:
                logger.warning("Private chat is in the chat blacklist, message dropped")
                return False
        if user_id in global_config.chat.ban_user_id and not ignore_global_list:
            logger.warning("User is in the global blacklist, message dropped")
            return False

        if global_config.chat.ban_qq_bot and group_id and not ignore_bot:
            logger.debug("Checking whether the sender is a bot")
            member_info = await get_member_info(self.server_connection, group_id, user_id)
            if member_info:
                is_bot = member_info.get("is_robot")
                if is_bot is None:
                    logger.warning("Cannot tell whether the user is a bot; treating as not a bot without updating the cache")
                else:
                    if is_bot:
                        logger.warning("Official QQ bot blocking is enabled, message dropped, bot added to the block list")
                        self.bot_id_list[user_id] = True
                        return False
                    else:
                        self.bot_id_list[user_id] = False

        return True
|
||||
|
||||
async def handle_raw_message(self, raw_message: dict) -> None:
|
||||
# sourcery skip: low-code-quality, remove-unreachable-code
|
||||
"""
|
||||
从Napcat接受的原始消息处理
|
||||
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time() # 应可乐要求,现在是float了
|
||||
|
||||
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
|
||||
format_info: FormatInfo = FormatInfo(
|
||||
content_format=["text", "image", "emoji", "voice"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
) # 格式化信息
|
||||
if message_type == MessageType.private:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Private.friend:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 不存在群信息
|
||||
group_info: GroupInfo = None
|
||||
elif sub_type == MessageType.Private.group:
|
||||
"""
|
||||
本部分暂时不做支持,先放着
|
||||
"""
|
||||
logger.warning("群临时消息类型不支持")
|
||||
return None
|
||||
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
# 由于临时会话中,Napcat默认不发送成员昵称,所以需要单独获取
|
||||
fetched_member_info: dict = await get_member_info(
|
||||
self.server_connection,
|
||||
raw_message.get("group_id"),
|
||||
sender_info.get("user_id"),
|
||||
)
|
||||
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=nickname,
|
||||
user_cardname=None,
|
||||
)
|
||||
|
||||
# -------------------这里需要群信息吗?-------------------
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info: dict = await get_group_info(self.server_connection, raw_message.get("group_id"))
|
||||
group_name = ""
|
||||
if fetched_group_info.get("group_name"):
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"私聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
elif message_type == MessageType.group:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
if sub_type == MessageType.Group.normal:
|
||||
sender_info: dict = raw_message.get("sender")
|
||||
|
||||
if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
|
||||
return None
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
)
|
||||
|
||||
# 获取群聊相关信息,在此单独处理group_name,因为默认发送的消息中没有
|
||||
fetched_group_info = await get_group_info(self.server_connection, raw_message.get("group_id"))
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(f"群聊消息类型 {sub_type} 不支持")
|
||||
return None
|
||||
|
||||
additional_config: dict = {}
|
||||
if global_config.voice.use_tts:
|
||||
additional_config["allow_tts"] = True
|
||||
|
||||
# 消息信息
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id=message_id,
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=template_info,
|
||||
format_info=format_info,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
# 处理实际信息
|
||||
if not raw_message.get("message"):
|
||||
logger.warning("原始消息内容为空")
|
||||
return None
|
||||
|
||||
# 获取Seg列表
|
||||
seg_message: List[Seg] = await self.handle_real_message(raw_message)
|
||||
if not seg_message:
|
||||
logger.warning("处理后消息内容为空")
|
||||
return None
|
||||
submit_seg: Seg = Seg(
|
||||
type="seglist",
|
||||
data=seg_message,
|
||||
)
|
||||
# MessageBase创建
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message"),
|
||||
)
|
||||
|
||||
logger.info("发送到Maibot处理信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
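Process the actual message content.
Parameters:
raw_message: dict: raw message whose "message" field holds the segment list
in_reply: bool: True when called for a quoted message, so nested replies are skipped
Returns:
seg_message: list[Seg]: list of processed message segments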
|
||||
"""
|
||||
real_message: list = raw_message.get("message")
|
||||
if not real_message:
|
||||
return None
|
||||
seg_message: List[Seg] = []
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("text处理失败")
|
||||
case RealMessageType.face:
|
||||
ret_seg = await self.handle_face_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("face处理失败或不支持")
|
||||
case RealMessageType.reply:
|
||||
if not in_reply:
|
||||
ret_seg = await self.handle_reply_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message += ret_seg
|
||||
else:
|
||||
logger.warning("reply处理失败")
|
||||
case RealMessageType.image:
|
||||
ret_seg = await self.handle_image_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("image处理失败")
|
||||
case RealMessageType.record:
|
||||
ret_seg = await self.handle_record_message(sub_message)
|
||||
if ret_seg:
|
||||
seg_message.clear()
|
||||
seg_message.append(ret_seg)
|
||||
break  # ensure the message contains only the record segment
|
||||
else:
|
||||
logger.warning("record处理失败或不支持")
|
||||
case RealMessageType.video:
|
||||
logger.warning("不支持视频解析")
|
||||
case RealMessageType.at:
|
||||
ret_seg = await self.handle_at_message(
|
||||
sub_message,
|
||||
raw_message.get("self_id"),
|
||||
raw_message.get("group_id"),
|
||||
)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("at处理失败")
|
||||
case RealMessageType.rps:
|
||||
logger.warning("暂时不支持猜拳魔法表情解析")
|
||||
case RealMessageType.dice:
|
||||
logger.warning("暂时不支持骰子表情解析")
|
||||
case RealMessageType.shake:
|
||||
# Presumably equivalent to a poke
|
||||
logger.warning("暂时不支持窗口抖动解析")
|
||||
case RealMessageType.share:
|
||||
logger.warning("暂时不支持链接解析")
|
||||
case RealMessageType.forward:
|
||||
messages = await self._get_forward_message(sub_message)
|
||||
if not messages:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
ret_seg = await self.handle_forward_message(messages)
|
||||
if ret_seg:
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("转发消息处理失败")
|
||||
case RealMessageType.node:
|
||||
logger.warning("不支持转发消息节点解析")
|
||||
case _:
|
||||
logger.warning(f"未知消息类型: {sub_message_type}")
|
||||
return seg_message
|
||||
|
||||
async def handle_text_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
Process a plain-text segment
Parameters:
raw_message: dict: raw message segment
Returns:
seg_data: Seg: processed message segment
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
plain_text: str = message_data.get("text")
|
||||
return Seg(type="text", data=plain_text)
|
||||
|
||||
async def handle_face_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
Process a QQ face (built-in emoticon) segment
Parameters:
raw_message: dict: raw message segment
Returns:
seg_data: Seg: processed message segment, or None if the face id is not supported
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
face_raw_id: str = str(message_data.get("id"))
|
||||
if face_raw_id in qq_face:
|
||||
face_content: str = qq_face.get(face_raw_id)
|
||||
return Seg(type="text", data=face_content)
|
||||
else:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
|
||||
async def handle_image_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
Process an image or sticker segment
Parameters:
raw_message: dict: raw message segment
Returns:
seg_data: Seg: processed message segment, or None on failure or for an unsupported sub_type
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
image_sub_type = message_data.get("sub_type")
|
||||
try:
|
||||
image_base64 = await get_image_base64(message_data.get("url"))
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
if image_sub_type == 0:
|
||||
# sub_type 0 is treated as a regular image
|
||||
return Seg(type="image", data=image_base64)
|
||||
elif image_sub_type not in [4, 9]:
|
||||
# any other sub_type except 4 and 9 is treated as a sticker
|
||||
return Seg(type="emoji", data=image_base64)
|
||||
else:
|
||||
logger.warning(f"不支持的图片子类型:{image_sub_type}")
|
||||
return None
|
||||
|
||||
async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None:
|
||||
# sourcery skip: use-named-expression
|
||||
"""
|
||||
Process an at (mention) segment
Parameters:
raw_message: dict: raw message segment
self_id: int: the bot's QQ ID
group_id: int: group ID
Returns:
seg_data: Seg: processed message segment, or None if the user info cannot be fetched
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
if message_data:
|
||||
qq_id = message_data.get("qq")
|
||||
if str(self_id) == str(qq_id):
|
||||
logger.debug("机器人被at")
|
||||
self_info: dict = await get_self_info(self.server_connection)
|
||||
if self_info:
|
||||
return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id)
|
||||
if member_info:
|
||||
return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>")
|
||||
else:
|
||||
return None
|
||||
|
||||
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
Process a voice (record) segment
Parameters:
raw_message: dict: raw message segment
Returns:
seg_data: Seg: processed message segment, or None if the audio cannot be fetched
|
||||
"""
|
||||
message_data: dict = raw_message.get("data")
|
||||
file: str = message_data.get("file")
|
||||
if not file:
|
||||
logger.warning("语音消息缺少文件信息")
|
||||
return None
|
||||
try:
|
||||
record_detail = await get_record_detail(self.server_connection, file)
|
||||
if not record_detail:
|
||||
logger.warning("获取语音消息详情失败")
|
||||
return None
|
||||
audio_base64: str = record_detail.get("base64")
|
||||
except Exception as e:
|
||||
logger.error(f"语音消息处理失败: {str(e)}")
|
||||
return None
|
||||
if not audio_base64:
|
||||
logger.error("语音消息处理失败,未获取到音频数据")
|
||||
return None
|
||||
return Seg(type="voice", data=audio_base64)
|
||||
|
||||
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
|
||||
# sourcery skip: move-assign-in-block, use-named-expression
|
||||
"""
|
||||
Process a reply (quoted) message
|
||||
|
||||
"""
|
||||
raw_message_data: dict = raw_message.get("data")
|
||||
message_id: int = None
|
||||
if raw_message_data:
|
||||
message_id = raw_message_data.get("id")
|
||||
else:
|
||||
return None
|
||||
message_detail: dict = await get_message_detail(self.server_connection, message_id)
|
||||
if not message_detail:
|
||||
logger.warning("获取被引用的消息详情失败")
|
||||
return None
|
||||
reply_message = await self.handle_real_message(message_detail, in_reply=True)
|
||||
if reply_message is None:
|
||||
reply_message = [Seg(type="text", data="(获取发言内容失败)")]
|
||||
sender_info: dict = message_detail.get("sender") or {}
|
||||
sender_nickname: str = sender_info.get("nickname")
|
||||
sender_id: str = sender_info.get("user_id")
|
||||
seg_message: List[Seg] = []
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
seg_message.append(Seg(type="text", data="[回复 未知用户:"))
|
||||
else:
|
||||
seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>:"))
|
||||
seg_message += reply_message
|
||||
seg_message.append(Seg(type="text", data="],说:"))
|
||||
return seg_message
|
||||
|
||||
async def handle_forward_message(self, message_list: list) -> Seg | None:
|
||||
"""
|
||||
Recursively process a forwarded message and choose how images are handled based on how many it contains
Parameters:
message_list: list: list of forwarded messages
|
||||
"""
|
||||
handled_message, image_count = await self._handle_forward_message(message_list, 0)
|
||||
handled_message: Seg
|
||||
image_count: int
|
||||
if not handled_message:
|
||||
return None
|
||||
if 0 < image_count < 5:
# Fewer than 5 images: fetch each one and encode it as base64
|
||||
logger.trace("图片数量小于5,开始解析图片为base64")
|
||||
return await self._recursive_parse_image_seg(handled_message, True)
|
||||
elif image_count > 0:
|
||||
logger.trace("图片数量大于等于5,开始解析图片为占位符")
|
||||
# 5 or more images: replace each image with a text placeholder
|
||||
return await self._recursive_parse_image_seg(handled_message, False)
|
||||
else:
|
||||
# No images: return the result unchanged
|
||||
logger.trace("没有图片,直接返回")
|
||||
return handled_message
|
||||
|
||||
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
|
||||
# sourcery skip: merge-else-if-into-elif
|
||||
if to_image:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[图片]")
|
||||
return Seg(type="image", data=encoded_image)
|
||||
elif seg_data.type == "emoji":
|
||||
image_url = seg_data.data
|
||||
try:
|
||||
encoded_image = await get_image_base64(image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"图片处理失败: {str(e)}")
|
||||
return Seg(type="text", data="[表情包]")
|
||||
return Seg(type="emoji", data=encoded_image)
|
||||
else:
|
||||
logger.trace(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
else:
|
||||
if seg_data.type == "seglist":
|
||||
new_seg_list = []
|
||||
for i_seg in seg_data.data:
|
||||
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
|
||||
new_seg_list.append(parsed_seg)
|
||||
return Seg(type="seglist", data=new_seg_list)
|
||||
elif seg_data.type == "image":
|
||||
return Seg(type="text", data="[图片]")
|
||||
elif seg_data.type == "emoji":
|
||||
return Seg(type="text", data="[动画表情]")
|
||||
else:
|
||||
logger.trace(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
|
||||
# sourcery skip: low-code-quality
|
||||
"""
|
||||
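Recursively process the content of a forwarded message
Parameters:
message_list: list: forwarded messages; the top level comes from the messages field, nested levels from the content field
layer: int: current nesting level
Returns:
seg_data: Seg: processed message segment
image_count: int: number of images found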
|
||||
"""
|
||||
seg_list: List[Seg] = []
|
||||
image_count = 0
|
||||
if message_list is None:
|
||||
return None, 0
|
||||
for sub_message in message_list:
|
||||
sub_message: dict
|
||||
sender_info: dict = sub_message.get("sender") or {}
|
||||
user_nickname: str = sender_info.get("nickname", "QQ用户")
|
||||
user_nickname_str = f"【{user_nickname}】:"
|
||||
break_seg = Seg(type="text", data="\n")
|
||||
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
|
||||
if not message_of_sub_message_list:
|
||||
logger.warning("转发消息内容为空")
|
||||
continue
|
||||
message_of_sub_message = message_of_sub_message_list[0]
|
||||
if message_of_sub_message.get("type") == RealMessageType.forward:
|
||||
if layer >= 3:
|
||||
full_seg_data = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】:【转发消息】\n",
|
||||
)
|
||||
else:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
contents = sub_message_data.get("content")
|
||||
seg_data, count = await self._handle_forward_message(contents, layer + 1)
|
||||
image_count += count
|
||||
head_tip = Seg(
|
||||
type="text",
|
||||
data=("--" * layer) + f"【{user_nickname}】: 合并转发消息内容:\n",
|
||||
)
|
||||
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
|
||||
seg_list.append(full_seg_data)
|
||||
elif message_of_sub_message.get("type") == RealMessageType.text:
|
||||
sub_message_data = message_of_sub_message.get("data")
|
||||
if not sub_message_data:
|
||||
continue
|
||||
text_message = sub_message_data.get("text")
|
||||
seg_data = Seg(type="text", data=text_message)
|
||||
data_list: List[Any] = []
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
seg_list.append(Seg(type="seglist", data=data_list))
|
||||
elif message_of_sub_message.get("type") == RealMessageType.image:
|
||||
image_count += 1
|
||||
image_data = message_of_sub_message.get("data")
|
||||
sub_type = image_data.get("sub_type")
|
||||
image_url = image_data.get("url")
|
||||
data_list: List[Any] = []
|
||||
if sub_type == 0:
|
||||
seg_data = Seg(type="image", data=image_url)
|
||||
else:
|
||||
seg_data = Seg(type="emoji", data=image_url)
|
||||
if layer > 0:
|
||||
data_list = [
|
||||
Seg(type="text", data=("--" * layer) + user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
else:
|
||||
data_list = [
|
||||
Seg(type="text", data=user_nickname_str),
|
||||
seg_data,
|
||||
break_seg,
|
||||
]
|
||||
full_seg_data = Seg(type="seglist", data=data_list)
|
||||
seg_list.append(full_seg_data)
|
||||
return Seg(type="seglist", data=seg_list), image_count
|
||||
|
||||
async def _get_forward_message(self, raw_message: dict) -> List[Dict[str, Any]] | None:
|
||||
forward_message_data: Dict = raw_message.get("data")
|
||||
if not forward_message_data:
|
||||
logger.warning("转发消息内容为空")
|
||||
return None
|
||||
forward_message_id = forward_message_data.get("id")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_forward_msg",
|
||||
"params": {"message_id": forward_message_id},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await self.server_connection.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取转发消息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取转发消息失败: {str(e)}")
|
||||
return None
|
||||
logger.debug(
|
||||
f"转发消息原始格式:{json.dumps(response)[:80]}..."
|
||||
if len(json.dumps(response)) > 80
|
||||
else json.dumps(response)
|
||||
)
|
||||
response_data: Dict = response.get("data")
|
||||
if not response_data:
|
||||
logger.warning("转发消息内容为空或获取失败")
|
||||
return None
|
||||
return response_data.get("messages")
|
||||
|
||||
|
||||
message_handler = MessageHandler()
|
||||
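The image handling for forwarded messages (walk the segment tree, then either fetch each image or swap it for a placeholder) can be sketched in isolation. This is a minimal sketch, not the adapter's code: `Seg` here is a hypothetical stand-in for `maim_message.Seg`, `fetch` is an injected coroutine standing in for `get_image_base64`, and the sketch uses a single placeholder per segment type, which is a simplification of the handler above.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class Seg:
    """Hypothetical stand-in for maim_message.Seg (type and data only)."""
    type: str
    data: Any


# Placeholder text per segment type (assumption: one placeholder per type).
PLACEHOLDERS = {"image": "[图片]", "emoji": "[动画表情]"}


def images_as_base64(image_count: int) -> bool:
    """Threshold rule: encode images only when there are between 1 and 4 of them."""
    return 0 < image_count < 5


async def parse_image_seg(
    seg: Seg, to_image: bool, fetch: Callable[[str], Awaitable[str]]
) -> Seg:
    """Return a copy of the segment tree with image/emoji segments resolved."""
    if seg.type == "seglist":
        children = [await parse_image_seg(child, to_image, fetch) for child in seg.data]
        return Seg("seglist", children)
    if seg.type not in PLACEHOLDERS:
        return seg  # text and other segment types pass through untouched
    if not to_image:
        return Seg("text", PLACEHOLDERS[seg.type])
    try:
        return Seg(seg.type, await fetch(seg.data))
    except Exception:
        # Fall back to the placeholder when the download fails
        return Seg("text", PLACEHOLDERS[seg.type])
```

Handling the `seglist` case once, before the `to_image` check, avoids repeating the recursion in both branches.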
@@ -0,0 +1,32 @@
from src.logger import logger
from maim_message import MessageBase, Router


class MessageSending:
    """
    Sends messages to MaiBot.
    """

    maibot_router: Router = None

    def __init__(self):
        pass

    async def message_send(self, message_base: MessageBase) -> bool:
        """
        Send a message
        Parameters:
            message_base: MessageBase: message object carrying the target and the content
        """
        try:
            send_status = await self.maibot_router.send_message(message_base)
            if not send_status:
                raise RuntimeError("可能是路由未正确配置或连接异常")
            return send_status
        except Exception as e:
            logger.error(f"发送消息失败: {str(e)}")
            logger.error("请检查与MaiBot之间的连接")
            return False


message_send_instance = MessageSending()
|
||||
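Callers such as `send_notice` branch on the value returned by `message_send`, so it helps if every path yields a real `bool`. The sketch below shows that contract in isolation. `FakeRouter` and `safe_send` are hypothetical names used only for illustration; they are not part of `maim_message`.

```python
import asyncio


class FakeRouter:
    """Hypothetical stand-in for a router; returns a preset result or raises."""

    def __init__(self, result: bool, fail: bool = False):
        self.result = result
        self.fail = fail

    async def send_message(self, message) -> bool:
        if self.fail:
            raise ConnectionError("router not connected")
        return self.result


async def safe_send(router, message) -> bool:
    """Send a message and always report the outcome as a bool."""
    try:
        return bool(await router.send_message(message))
    except Exception as exc:
        print(f"send failed: {exc}")
        return False
```

With this shape, a retry loop can treat a falsy status and a raised exception the same way.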
@@ -0,0 +1,53 @@
from src.logger import logger
from src.config import global_config
import time
import asyncio

from . import MetaEventType


class MetaEventHandler:
    """
    Handles meta events (lifecycle and heartbeat).
    """

    def __init__(self):
        self.interval = global_config.napcat_server.heartbeat_interval
        self.last_heart_beat = time.time()
        self._interval_checking = False

    async def handle_meta_event(self, message: dict) -> None:
        event_type = message.get("meta_event_type")
        if event_type == MetaEventType.lifecycle:
            sub_type = message.get("sub_type")
            if sub_type == MetaEventType.Lifecycle.connect:
                self_id = message.get("self_id")
                self.last_heart_beat = time.time()
                logger.success(f"Bot {self_id} 连接成功")
                asyncio.create_task(self.check_heartbeat(self_id))
        elif event_type == MetaEventType.heartbeat:
            self_id = message.get("self_id")
            if message["status"].get("online") and message["status"].get("good"):
                if not self._interval_checking:
                    # check_heartbeat requires the bot id
                    asyncio.create_task(self.check_heartbeat(self_id))
                self.last_heart_beat = time.time()
                interval_ms = message.get("interval")
                if interval_ms:
                    self.interval = interval_ms / 1000  # milliseconds to seconds
            else:
                logger.warning(f"Bot {self_id} Napcat 端异常!")

    async def check_heartbeat(self, id: int) -> None:
        self._interval_checking = True
        while True:
            now_time = time.time()
            if now_time - self.last_heart_beat > self.interval * 2:
                logger.error(f"Bot {id} 可能发生了连接断开,被下线,或者Napcat卡死!")
                break
            else:
                logger.debug("心跳正常")
            await asyncio.sleep(self.interval)


meta_event_handler = MetaEventHandler()
|
||||
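The watchdog rule in `check_heartbeat` (treat the connection as lost once more than two intervals pass without a heartbeat) can be isolated so that it is testable without a live connection. This is a minimal sketch under that assumption. `HeartbeatWatchdog` and its method names are hypothetical, and it uses `time.monotonic()` rather than `time.time()` so that wall-clock adjustments do not affect the check.

```python
import asyncio
import time
from typing import Callable, Optional


class HeartbeatWatchdog:
    """Hypothetical helper: tracks the last heartbeat and detects staleness."""

    def __init__(self, interval: float):
        self.interval = interval  # seconds between expected heartbeats
        self.last_beat = time.monotonic()

    def beat(self, interval_ms: Optional[int] = None) -> None:
        """Record a heartbeat; optionally update the interval (given in ms)."""
        self.last_beat = time.monotonic()
        if interval_ms:
            self.interval = interval_ms / 1000

    def is_stale(self, now: Optional[float] = None) -> bool:
        """True once more than two intervals have passed since the last beat."""
        now = time.monotonic() if now is None else now
        return now - self.last_beat > self.interval * 2

    async def watch(self, on_stale: Callable[[], None]) -> None:
        """Poll once per interval and call on_stale when the heartbeat is lost."""
        while not self.is_stale():
            await asyncio.sleep(self.interval)
        on_stale()
```

Passing `now` explicitly keeps the staleness rule deterministic in tests, while `watch` uses the real clock.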
516
MaiBot-Napcat-Adapter-dev/src/recv_handler/notice_handler.py
Normal file
@@ -0,0 +1,516 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import websockets as Server
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.logger import logger
|
||||
from src.config import global_config
|
||||
from src.database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
|
||||
from src.utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_self_info,
|
||||
get_stranger_info,
|
||||
read_ban_list,
|
||||
)
|
||||
|
||||
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
|
||||
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
|
||||
|
||||
|
||||
class NoticeHandler:
|
||||
banned_list: list[BanUser] = []  # users whose ban is still in effect
|
||||
lifted_list: list[BanUser] = []  # bans that have already expired naturally
|
||||
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""Set the Napcat connection"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
while self.server_connection.state != Server.State.OPEN:
|
||||
await asyncio.sleep(0.5)
|
||||
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
|
||||
|
||||
asyncio.create_task(self.auto_lift_detect())
|
||||
asyncio.create_task(self.send_notice())
|
||||
asyncio.create_task(self.handle_natural_lift())
|
||||
|
||||
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||
"""
|
||||
Add a ban record to self.banned_list
For a whole-group ban, user_id is 0
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0  # 0 denotes a whole-group ban
|
||||
lift_time = -1
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
||||
for record in self.banned_list:
|
||||
if is_identical(record, ban_record):
|
||||
self.banned_list.remove(record)
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record)  # acts as an update
|
||||
return
|
||||
self.banned_list.append(ban_record)
|
||||
db_manager.create_ban_record(ban_record)  # insert into the database
|
||||
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
Append the lifted ban to self.lifted_list and delete its database record (user_id 0 denotes a whole-group ban)
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = 0  # 0 denotes a whole-group ban
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
||||
self.lifted_list.append(ban_record)
|
||||
db_manager.delete_ban_record(ban_record)  # delete the record from the database
|
||||
|
||||
async def handle_notice(self, raw_message: dict) -> None:
|
||||
notice_type = raw_message.get("notice_type")
|
||||
# message_time: int = raw_message.get("time")
|
||||
message_time: float = time.time()  # now a float, at 可乐's request
|
||||
|
||||
group_id = raw_message.get("group_id")
|
||||
user_id = raw_message.get("user_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
|
||||
handled_message: Seg = None
|
||||
user_info: UserInfo = None
|
||||
system_notice: bool = False
|
||||
|
||||
match notice_type:
|
||||
case NoticeType.friend_recall:
|
||||
logger.info("好友撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.group_recall:
|
||||
logger.info("群内用户撤回一条消息")
|
||||
logger.info(f"撤回消息ID:{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
|
||||
logger.warning("暂时不支持撤回消息处理")
|
||||
case NoticeType.notify:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if global_config.chat.enable_poke and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
logger.info("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_ban:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.GroupBan.ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理群禁言")
|
||||
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case NoticeType.GroupBan.lift_ban:
|
||||
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
|
||||
return None
|
||||
logger.info("处理解除群禁言")
|
||||
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
|
||||
system_notice = True
|
||||
case _:
|
||||
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
|
||||
case _:
|
||||
logger.warning(f"不支持的notice类型: {notice_type}")
|
||||
return None
|
||||
if not handled_message or not user_info:
|
||||
logger.warning("notice处理失败或不支持")
|
||||
return None
|
||||
|
||||
group_info: GroupInfo = None
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(self.server_connection, group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=FormatInfo(
|
||||
content_format=["text", "notify"],
|
||||
accept_format=ACCEPT_FORMAT,
|
||||
),
|
||||
additional_config={"target_id": target_id},  # pass target_id along so that mmc can tell who was poked
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=handled_message,
|
||||
raw_message=json.dumps(raw_message),
|
||||
)
|
||||
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
else:
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_poke_notify(
|
||||
self, raw_message: dict, group_id: int, user_id: int
|
||||
) -> Tuple[Seg | None, UserInfo | None]:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
|
||||
self_info: dict = await get_self_info(self.server_connection)
|
||||
|
||||
if not self_info:
|
||||
logger.error("自身信息获取失败")
|
||||
return None, None
|
||||
|
||||
self_id = raw_message.get("self_id")
|
||||
target_id = raw_message.get("target_id")
|
||||
target_name: str = None
|
||||
raw_info: list = raw_message.get("raw_info")
|
||||
|
||||
if group_id:
|
||||
user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id)
|
||||
else:
|
||||
user_qq_info: dict = await get_stranger_info(self.server_connection, user_id)
|
||||
if user_qq_info:
|
||||
user_name = user_qq_info.get("nickname")
|
||||
user_cardname = user_qq_info.get("card")
|
||||
else:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.info("无法获取戳一戳对方的用户昵称")
|
||||
|
||||
# Build the Seg
|
||||
if self_id == target_id:
|
||||
display_name = ""
|
||||
target_name = self_info.get("nickname")
|
||||
|
||||
elif self_id == user_id:
|
||||
# The adapter does not forward pokes sent by the bot itself
|
||||
return None, None
|
||||
|
||||
else:
|
||||
# Strictly speaking this check is redundant, since private chats have no pokes between other users, but it enforces the group-chat context
|
||||
if group_id:
|
||||
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, target_id)
|
||||
if fetched_member_info:
|
||||
target_name = fetched_member_info.get("nickname")
|
||||
else:
|
||||
target_name = "QQ用户"
|
||||
logger.info("无法获取被戳一戳方的用户昵称")
|
||||
display_name = user_name
|
||||
else:
|
||||
return None, None
|
||||
|
||||
first_txt: str = "戳了戳"
|
||||
second_txt: str = ""
|
||||
try:
|
||||
first_txt = raw_info[2].get("txt", "戳了戳")
|
||||
second_txt = raw_info[4].get("txt", "")
|
||||
except Exception as e:
|
||||
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="text",
|
||||
data=f"{display_name}{first_txt}{target_name}{second_txt}(这是QQ的一个功能,用于提及某人,但没那么明显)",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
# Build user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.server_connection, group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# Build the Seg
|
||||
user_id = raw_message.get("user_id")
|
||||
banned_user_info: UserInfo = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
sub_type: str = None
|
||||
|
||||
duration = raw_message.get("duration")
|
||||
if duration is None:
|
||||
logger.error("禁言时长不能为空,无法处理禁言通知")
|
||||
return None, None
|
||||
|
||||
if user_id == 0:  # whole-group ban
|
||||
sub_type: str = "whole_ban"
|
||||
self._ban_operation(group_id)
|
||||
else:  # single-user ban
|
||||
# Fetch the banned user's info
|
||||
sub_type: str = "ban"
|
||||
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
banned_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"duration": duration,
|
||||
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
|
||||
},
|
||||
)
|
||||
|
||||
return seg_data, operator_info
|
||||
|
||||
async def handle_lift_ban_notify(
|
||||
self, raw_message: dict, group_id: int
|
||||
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None, None
|
||||
|
||||
# Build user_info
|
||||
operator_id = raw_message.get("operator_id")
|
||||
operator_nickname: str = None
|
||||
operator_cardname: str = None
|
||||
|
||||
member_info: dict = await get_member_info(self.server_connection, group_id, operator_id)
|
||||
if member_info:
|
||||
operator_nickname = member_info.get("nickname")
|
||||
operator_cardname = member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
)
|
||||
|
||||
# Build the Seg
|
||||
sub_type: str = None
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
lifted_user_info: UserInfo = None
|
||||
|
||||
user_id = raw_message.get("user_id")
|
||||
if user_id == 0:  # whole-group ban lifted
|
||||
sub_type = "whole_lift_ban"
|
||||
self._lift_operation(group_id)
|
||||
else:  # single-user ban lifted
|
||||
sub_type = "lift_ban"
|
||||
# Fetch the info of the user whose ban was lifted
|
||||
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
else:
|
||||
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
self._lift_operation(group_id, user_id)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": sub_type,
|
||||
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
|
||||
},
|
||||
)
|
||||
return seg_data, operator_info
|
||||
|
||||
async def put_notice(self, message_base: MessageBase) -> None:
|
||||
"""
|
||||
Put a processed notice message into the notice queue
|
||||
"""
|
||||
if notice_queue.full() or unsuccessful_notice_queue.full():
|
||||
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
|
||||
else:
|
||||
await notice_queue.put(message_base)
|
||||
|
||||
async def handle_natural_lift(self) -> None:
|
||||
while True:
|
||||
if len(self.lifted_list) != 0:
|
||||
lift_record = self.lifted_list.pop()
|
||||
group_id = lift_record.group_id
|
||||
user_id = lift_record.user_id
|
||||
|
||||
db_manager.delete_ban_record(lift_record)  # delete the ban record from the database
|
||||
|
||||
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
||||
|
||||
fetched_group_info = await get_group_info(self.server_connection, group_id)
|
||||
group_name: str = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
message_id="notice",
|
||||
time=time.time(),
|
||||
user_info=None,  # a ban that expires naturally has no operator
|
||||
group_info=group_info,
|
||||
template_info=None,
|
||||
format_info=None,
|
||||
)
|
||||
|
||||
message_base: MessageBase = MessageBase(
|
||||
message_info=message_info,
|
||||
message_segment=seg_message,
|
||||
raw_message=json.dumps(
|
||||
{
|
||||
"post_type": "notice",
|
||||
"notice_type": "group_ban",
|
||||
"sub_type": "lift_ban",
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"operator_id": None,  # a ban that expires naturally has no operator
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
await self.put_notice(message_base)
|
||||
await asyncio.sleep(0.5)  # leave a gap between queue operations
|
||||
else:
|
||||
await asyncio.sleep(5)  # check every 5 seconds
|
||||
|
||||
async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理解除禁言通知")
|
||||
return None
|
||||
|
||||
if user_id == 0:  # should never be reached in theory
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "whole_lift_ban",
|
||||
"lifted_user_info": None,
|
||||
},
|
||||
)
|
||||
|
||||
user_nickname: str = "QQ用户"
|
||||
user_cardname: str = None
|
||||
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
|
||||
if fetched_member_info:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
return Seg(
|
||||
type="notify",
|
||||
data={
|
||||
"sub_type": "lift_ban",
|
||||
"lifted_user_info": lifted_user_info.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
    async def auto_lift_detect(self) -> None:
        while True:
            if len(self.banned_list) == 0:
                await asyncio.sleep(5)
                continue
            # Iterate over a copy: expired records are removed from self.banned_list inside the loop.
            for ban_record in list(self.banned_list):
                if ban_record.user_id == 0 or ban_record.lift_time == -1:
                    continue
                if ban_record.lift_time <= int(time.time()):
                    # The ban has expired naturally
                    logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
                    self.lifted_list.append(ban_record)
                    self.banned_list.remove(ban_record)
            await asyncio.sleep(5)
|
||||
|
||||
async def send_notice(self) -> None:
"""Send notice messages to Napcat."""
|
||||
while True:
|
||||
if not unsuccessful_notice_queue.empty():
|
||||
to_be_send: MessageBase = await unsuccessful_notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
unsuccessful_notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
to_be_send: MessageBase = await notice_queue.get()
|
||||
try:
|
||||
send_status = await message_send_instance.message_send(to_be_send)
|
||||
if send_status:
|
||||
notice_queue.task_done()
|
||||
else:
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
except Exception as e:
|
||||
logger.error(f"发送通知消息失败: {str(e)}")
|
||||
await unsuccessful_notice_queue.put(to_be_send)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
notice_handler = NoticeHandler()
|
||||
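The expiry scan in `auto_lift_detect` moves records whose `lift_time` has passed from the banned list to the lifted list. The following is a minimal sketch of that partitioning step only, not the adapter's implementation. `BanRecord` and `split_expired` are hypothetical names introduced for illustration, and the record fields are assumed from how the loop above reads them. The sketch builds two new lists instead of removing entries from a list while it is being iterated.

```python
import time
from dataclasses import dataclass


@dataclass
class BanRecord:
    """Hypothetical stand-in for the adapter's ban record type."""

    group_id: int
    user_id: int
    lift_time: int  # Unix timestamp; -1 means no scheduled lift


def split_expired(records, now=None):
    """Return (still_banned, lifted) without mutating `records`."""
    now = int(time.time()) if now is None else now
    still_banned, lifted = [], []
    for record in records:
        # Whole-group bans (user_id == 0) and open-ended bans are never lifted here.
        if record.user_id == 0 or record.lift_time == -1:
            still_banned.append(record)
        elif record.lift_time <= now:
            lifted.append(record)
        else:
            still_banned.append(record)
    return still_banned, lifted
```

Returning fresh lists keeps the scan free of the skipped-element problem that comes with calling `list.remove` during iteration.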
250
MaiBot-Napcat-Adapter-dev/src/recv_handler/qq_emoji_list.py
Normal file
@@ -0,0 +1,250 @@
|
||||
qq_face: dict = {
|
||||
"0": "[表情:惊讶]",
|
||||
"1": "[表情:撇嘴]",
|
||||
"2": "[表情:色]",
|
||||
"3": "[表情:发呆]",
|
||||
"4": "[表情:得意]",
|
||||
"5": "[表情:流泪]",
|
||||
"6": "[表情:害羞]",
|
||||
"7": "[表情:闭嘴]",
|
||||
"8": "[表情:睡]",
|
||||
"9": "[表情:大哭]",
|
||||
"10": "[表情:尴尬]",
|
||||
"11": "[表情:发怒]",
|
||||
"12": "[表情:调皮]",
|
||||
"13": "[表情:呲牙]",
|
||||
"14": "[表情:微笑]",
|
||||
"15": "[表情:难过]",
|
||||
"16": "[表情:酷]",
|
||||
"18": "[表情:抓狂]",
|
||||
"19": "[表情:吐]",
|
||||
"20": "[表情:偷笑]",
|
||||
"21": "[表情:可爱]",
|
||||
"22": "[表情:白眼]",
|
||||
"23": "[表情:傲慢]",
|
||||
"24": "[表情:饥饿]",
|
||||
"25": "[表情:困]",
|
||||
"26": "[表情:惊恐]",
|
||||
"27": "[表情:流汗]",
|
||||
"28": "[表情:憨笑]",
|
||||
"29": "[表情:悠闲]",
|
||||
"30": "[表情:奋斗]",
|
||||
"31": "[表情:咒骂]",
|
||||
"32": "[表情:疑问]",
"33": "[表情:嘘]",
|
||||
"34": "[表情:晕]",
|
||||
"35": "[表情:折磨]",
|
||||
"36": "[表情:衰]",
|
||||
"37": "[表情:骷髅]",
|
||||
"38": "[表情:敲打]",
|
||||
"39": "[表情:再见]",
|
||||
"41": "[表情:发抖]",
|
||||
"42": "[表情:爱情]",
|
||||
"43": "[表情:跳跳]",
|
||||
"46": "[表情:猪头]",
|
||||
"49": "[表情:拥抱]",
|
||||
"53": "[表情:蛋糕]",
|
||||
"56": "[表情:刀]",
|
||||
"59": "[表情:便便]",
|
||||
"60": "[表情:咖啡]",
|
||||
"63": "[表情:玫瑰]",
|
||||
"64": "[表情:凋谢]",
|
||||
"66": "[表情:爱心]",
|
||||
"67": "[表情:心碎]",
|
||||
"74": "[表情:太阳]",
|
||||
"75": "[表情:月亮]",
|
||||
"76": "[表情:赞]",
|
||||
"77": "[表情:踩]",
|
||||
"78": "[表情:握手]",
|
||||
"79": "[表情:胜利]",
|
||||
"85": "[表情:飞吻]",
|
||||
"86": "[表情:怄火]",
|
||||
"89": "[表情:西瓜]",
|
||||
"96": "[表情:冷汗]",
|
||||
"97": "[表情:擦汗]",
|
||||
"98": "[表情:抠鼻]",
|
||||
"99": "[表情:鼓掌]",
|
||||
"100": "[表情:糗大了]",
|
||||
"101": "[表情:坏笑]",
|
||||
"102": "[表情:左哼哼]",
|
||||
"103": "[表情:右哼哼]",
|
||||
"104": "[表情:哈欠]",
|
||||
"105": "[表情:鄙视]",
|
||||
"106": "[表情:委屈]",
|
||||
"107": "[表情:快哭了]",
|
||||
"108": "[表情:阴险]",
|
||||
"109": "[表情:左亲亲]",
|
||||
"110": "[表情:吓]",
|
||||
"111": "[表情:可怜]",
|
||||
"112": "[表情:菜刀]",
|
||||
"114": "[表情:篮球]",
|
||||
"116": "[表情:示爱]",
|
||||
"118": "[表情:抱拳]",
|
||||
"119": "[表情:勾引]",
|
||||
"120": "[表情:拳头]",
|
||||
"121": "[表情:差劲]",
|
||||
"123": "[表情:NO]",
|
||||
"124": "[表情:OK]",
|
||||
"125": "[表情:转圈]",
|
||||
"129": "[表情:挥手]",
|
||||
"137": "[表情:鞭炮]",
|
||||
"144": "[表情:喝彩]",
|
||||
"146": "[表情:爆筋]",
|
||||
"147": "[表情:棒棒糖]",
|
||||
"169": "[表情:手枪]",
|
||||
"171": "[表情:茶]",
|
||||
"172": "[表情:眨眼睛]",
|
||||
"173": "[表情:泪奔]",
|
||||
"174": "[表情:无奈]",
|
||||
"175": "[表情:卖萌]",
|
||||
"176": "[表情:小纠结]",
|
||||
"177": "[表情:喷血]",
|
||||
"178": "[表情:斜眼笑]",
|
||||
"179": "[表情:doge]",
|
||||
"181": "[表情:戳一戳]",
|
||||
"182": "[表情:笑哭]",
|
||||
"183": "[表情:我最美]",
|
||||
"185": "[表情:羊驼]",
|
||||
"187": "[表情:幽灵]",
|
||||
"201": "[表情:点赞]",
|
||||
"212": "[表情:托腮]",
|
||||
"262": "[表情:脑阔疼]",
|
||||
"263": "[表情:沧桑]",
|
||||
"264": "[表情:捂脸]",
|
||||
"265": "[表情:辣眼睛]",
|
||||
"266": "[表情:哦哟]",
|
||||
"267": "[表情:头秃]",
|
||||
"268": "[表情:问号脸]",
|
||||
"269": "[表情:暗中观察]",
|
||||
"270": "[表情:emm]",
"271": "[表情:吃瓜]",
|
||||
"272": "[表情:呵呵哒]",
|
||||
"273": "[表情:我酸了]",
|
||||
"277": "[表情:汪汪]",
|
||||
"281": "[表情:无眼笑]",
|
||||
"282": "[表情:敬礼]",
|
||||
"283": "[表情:狂笑]",
|
||||
"284": "[表情:面无表情]",
|
||||
"285": "[表情:摸鱼]",
|
||||
"286": "[表情:魔鬼笑]",
|
||||
"287": "[表情:哦]",
|
||||
"289": "[表情:睁眼]",
|
||||
"293": "[表情:摸锦鲤]",
|
||||
"294": "[表情:期待]",
|
||||
"295": "[表情:拿到红包]",
|
||||
"297": "[表情:拜谢]",
|
||||
"298": "[表情:元宝]",
|
||||
"299": "[表情:牛啊]",
|
||||
"300": "[表情:胖三斤]",
|
||||
"302": "[表情:左拜年]",
|
||||
"303": "[表情:右拜年]",
|
||||
"305": "[表情:右亲亲]",
|
||||
"306": "[表情:牛气冲天]",
|
||||
"307": "[表情:喵喵]",
|
||||
"311": "[表情:打call]",
|
||||
"312": "[表情:变形]",
|
||||
"314": "[表情:仔细分析]",
|
||||
"317": "[表情:菜汪]",
|
||||
"318": "[表情:崇拜]",
"319": "[表情:比心]",
|
||||
"320": "[表情:庆祝]",
|
||||
"323": "[表情:嫌弃]",
|
||||
"324": "[表情:吃糖]",
|
||||
"325": "[表情:惊吓]",
|
||||
"326": "[表情:生气]",
|
||||
"332": "[表情:举牌牌]",
|
||||
"333": "[表情:烟花]",
|
||||
"334": "[表情:虎虎生威]",
|
||||
"336": "[表情:豹富]",
|
||||
"337": "[表情:花朵脸]",
|
||||
"338": "[表情:我想开了]",
|
||||
"339": "[表情:舔屏]",
|
||||
"341": "[表情:打招呼]",
|
||||
"342": "[表情:酸Q]",
|
||||
"343": "[表情:我方了]",
|
||||
"344": "[表情:大怨种]",
|
||||
"345": "[表情:红包多多]",
|
||||
"346": "[表情:你真棒棒]",
|
||||
"347": "[表情:大展宏兔]",
|
||||
"349": "[表情:坚强]",
|
||||
"350": "[表情:贴贴]",
|
||||
"351": "[表情:敲敲]",
|
||||
"352": "[表情:咦]",
|
||||
"353": "[表情:拜托]",
|
||||
"354": "[表情:尊嘟假嘟]",
|
||||
"355": "[表情:耶]",
|
||||
"356": "[表情:666]",
|
||||
"357": "[表情:裂开]",
"392": "[表情:龙年快乐]",
|
||||
"393": "[表情:新年中龙]",
|
||||
"394": "[表情:新年大龙]",
|
||||
"395": "[表情:略略略]",
|
||||
"😊": "[表情:嘿嘿]",
|
||||
"😌": "[表情:羞涩]",
"😚": "[表情:亲亲]",
|
||||
"😓": "[表情:汗]",
|
||||
"😰": "[表情:紧张]",
|
||||
"😝": "[表情:吐舌]",
|
||||
"😁": "[表情:呲牙]",
|
||||
"😜": "[表情:淘气]",
|
||||
"☺": "[表情:可爱]",
|
||||
"😍": "[表情:花痴]",
|
||||
"😔": "[表情:失落]",
|
||||
"😄": "[表情:高兴]",
|
||||
"😏": "[表情:哼哼]",
|
||||
"😒": "[表情:不屑]",
|
||||
"😳": "[表情:瞪眼]",
|
||||
"😘": "[表情:飞吻]",
|
||||
"😭": "[表情:大哭]",
|
||||
"😱": "[表情:害怕]",
|
||||
"😂": "[表情:激动]",
|
||||
"💪": "[表情:肌肉]",
|
||||
"👊": "[表情:拳头]",
"👍": "[表情:厉害]",
|
||||
"👏": "[表情:鼓掌]",
|
||||
"👎": "[表情:鄙视]",
|
||||
"🙏": "[表情:合十]",
|
||||
"👌": "[表情:好的]",
|
||||
"👆": "[表情:向上]",
|
||||
"👀": "[表情:眼睛]",
|
||||
"🍜": "[表情:拉面]",
|
||||
"🍧": "[表情:刨冰]",
|
||||
"🍞": "[表情:面包]",
|
||||
"🍺": "[表情:啤酒]",
|
||||
"🍻": "[表情:干杯]",
|
||||
"☕": "[表情:咖啡]",
|
||||
"🍎": "[表情:苹果]",
|
||||
"🍓": "[表情:草莓]",
|
||||
"🍉": "[表情:西瓜]",
|
||||
"🚬": "[表情:吸烟]",
|
||||
"🌹": "[表情:玫瑰]",
|
||||
"🎉": "[表情:庆祝]",
|
||||
"💝": "[表情:礼物]",
|
||||
"💣": "[表情:炸弹]",
|
||||
"✨": "[表情:闪光]",
|
||||
"💨": "[表情:吹气]",
|
||||
"💦": "[表情:水]",
|
||||
"🔥": "[表情:火]",
|
||||
"💤": "[表情:睡觉]",
|
||||
"💩": "[表情:便便]",
|
||||
"💉": "[表情:打针]",
|
||||
"📫": "[表情:邮箱]",
|
||||
"🐎": "[表情:骑马]",
|
||||
"👧": "[表情:女孩]",
|
||||
"👦": "[表情:男孩]",
|
||||
"🐵": "[表情:猴]",
|
||||
"🐷": "[表情:猪]",
|
||||
"🐮": "[表情:牛]",
|
||||
"🐔": "[表情:公鸡]",
|
||||
"🐸": "[表情:青蛙]",
|
||||
"👻": "[表情:幽灵]",
|
||||
"🐛": "[表情:虫]",
|
||||
"🐶": "[表情:狗]",
|
||||
"🐳": "[表情:鲸鱼]",
|
||||
"👢": "[表情:靴子]",
|
||||
"☀": "[表情:晴天]",
|
||||
"❔": "[表情:问号]",
|
||||
"🔫": "[表情:手枪]",
"💓": "[表情:爱心]",
|
||||
"🏪": "[表情:便利店]",
|
||||
}
|
||||
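The `qq_face` table maps a face id (stored as a string key) to a text placeholder. The following is a minimal sketch of a lookup with a fallback for ids missing from the table. `face_to_text`, `QQ_FACE_SAMPLE` and the fallback text are hypothetical names and values introduced for illustration; the adapter's own lookup code is not shown in this file.

```python
# Tiny excerpt for illustration; the real table is the `qq_face` dict above.
QQ_FACE_SAMPLE = {"14": "[表情:微笑]", "179": "[表情:doge]"}


def face_to_text(face_id, table=QQ_FACE_SAMPLE, fallback="[表情:未知]"):
    """Map a QQ face id (int or str) to its text placeholder."""
    # Keys in the table are strings, so normalise the id before the lookup.
    return table.get(str(face_id), fallback)
```

Normalising with `str()` means the caller can pass the id either as an integer or as the string found in the incoming message.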
44
MaiBot-Napcat-Adapter-dev/src/response_pool.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
from .config import global_config
|
||||
from .logger import logger
|
||||
|
||||
response_dict: Dict = {}
|
||||
response_time_dict: Dict = {}
|
||||
|
||||
|
||||
async def get_response(request_id: str, timeout: int = 10) -> dict:
    response = await asyncio.wait_for(_get_response(request_id), timeout)
    # Use a default so a concurrent cleanup that already removed the timestamp does not raise KeyError.
    response_time_dict.pop(request_id, None)
    logger.trace(f"响应信息id: {request_id} 已从响应字典中取出")
    return response
|
||||
|
||||
async def _get_response(request_id: str) -> dict:
    """
    Internal helper: poll until the response for request_id arrives, then remove and return it.
    """
    while request_id not in response_dict:
        await asyncio.sleep(0.2)
    return response_dict.pop(request_id)
|
||||
|
||||
async def put_response(response: dict):
|
||||
echo_id = response.get("echo")
|
||||
now_time = time.time()
|
||||
response_dict[echo_id] = response
|
||||
response_time_dict[echo_id] = now_time
|
||||
logger.trace(f"响应信息id: {echo_id} 已存入响应字典")
|
||||
|
||||
|
||||
async def check_timeout_response() -> None:
    while True:
        cleaned_message_count: int = 0
        now_time = time.time()
        for echo_id, response_time in list(response_time_dict.items()):
            if now_time - response_time > global_config.napcat_server.heartbeat_interval:
                cleaned_message_count += 1
                # The response may already have been consumed by get_response, so tolerate missing keys.
                response_dict.pop(echo_id, None)
                response_time_dict.pop(echo_id, None)
                logger.warning(f"响应消息 {echo_id} 超时,已删除")
        logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
        await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
|
||||
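The pool above matches responses to requests by `echo` id and polls the dictionary every 0.2 seconds. The following is a minimal sketch of a polling-free alternative, under the assumption that one `asyncio.Future` per echo id is acceptable. It is not the adapter's implementation, and `FutureResponsePool` is a hypothetical name introduced for illustration.

```python
import asyncio


class FutureResponsePool:
    """Hypothetical alternative to the polling pool: one Future per echo id."""

    def __init__(self):
        self._pending = {}

    async def get_response(self, request_id, timeout=10):
        loop = asyncio.get_running_loop()
        # setdefault also covers a response that arrived before the caller started waiting.
        future = self._pending.setdefault(request_id, loop.create_future())
        try:
            return await asyncio.wait_for(future, timeout)
        finally:
            # Always drop the entry so timed-out requests do not accumulate.
            self._pending.pop(request_id, None)

    def put_response(self, response):
        # Must be called from code running inside the event loop.
        echo_id = response.get("echo")
        loop = asyncio.get_running_loop()
        future = self._pending.setdefault(echo_id, loop.create_future())
        if not future.done():
            future.set_result(response)


async def demo():
    pool = FutureResponsePool()
    waiter = asyncio.create_task(pool.get_response("abc", timeout=1))
    await asyncio.sleep(0)  # let the waiter register its future
    pool.put_response({"echo": "abc", "status": "ok"})
    return await waiter
```

Running `asyncio.run(demo())` returns the response dictionary as soon as it is stored. Responses that nobody ever waits for would still stay in `_pending`, so a periodic cleanup like `check_timeout_response` would remain necessary in this design.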
461
MaiBot-Napcat-Adapter-dev/src/send_handler.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import json
|
||||
import websockets as Server
|
||||
import uuid
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
)
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from . import CommandType
|
||||
from .config import global_config
|
||||
from .response_pool import get_response
|
||||
from .logger import logger
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
|
||||
|
||||
class SendHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
self.server_connection = server_connection
|
||||
|
||||
async def handle_message(self, raw_message_base_dict: dict) -> None:
|
||||
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
logger.info("接收到来自MaiBot的消息,处理中")
|
||||
if message_segment.type == "command":
|
||||
return await self.send_command(raw_message_base)
|
||||
else:
|
||||
return await self.send_normal_message(raw_message_base)
|
||||
|
||||
async def send_normal_message(self, raw_message_base: MessageBase) -> None:
"""Handle sending of a normal message."""
|
||||
logger.info("处理普通信息中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: GroupInfo = message_info.group_info
|
||||
user_info: UserInfo = message_info.user_info
|
||||
target_id: int = None
|
||||
action: str = None
|
||||
id_name: str = None
|
||||
processed_message: list = []
|
||||
try:
|
||||
processed_message = await self.handle_seg_recursive(message_segment)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时发生错误: {e}")
|
||||
return
|
||||
|
||||
if not processed_message:
|
||||
logger.critical("现在暂时不支持解析此回复!")
|
||||
return None
|
||||
|
||||
if group_info and user_info:
|
||||
logger.debug("发送群聊消息")
|
||||
target_id = group_info.group_id
|
||||
action = "send_group_msg"
|
||||
id_name = "group_id"
|
||||
elif user_info:
|
||||
logger.debug("发送私聊消息")
|
||||
target_id = user_info.user_id
|
||||
action = "send_private_msg"
|
||||
id_name = "user_id"
|
||||
else:
|
||||
logger.error("无法识别的消息类型")
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
response = await self.send_message_to_napcat(
|
||||
action,
|
||||
{
|
||||
id_name: target_id,
|
||||
"message": processed_message,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "ok":
|
||||
logger.info("消息发送成功")
|
||||
qq_message_id = response.get("data", {}).get("message_id")
|
||||
await self.message_sent_back(raw_message_base, qq_message_id)
|
||||
else:
|
||||
logger.warning(f"消息发送失败,napcat返回:{str(response)}")
|
||||
|
||||
async def send_command(self, raw_message_base: MessageBase) -> None:
"""Handle command messages."""
|
||||
logger.info("处理命令中")
|
||||
message_info: BaseMessageInfo = raw_message_base.message_info
|
||||
message_segment: Seg = raw_message_base.message_segment
|
||||
group_info: GroupInfo = message_info.group_info
|
||||
seg_data: Dict[str, Any] = message_segment.data
|
||||
command_name: str = seg_data.get("name")
|
||||
try:
|
||||
match command_name:
|
||||
case CommandType.GROUP_BAN.name:
|
||||
command, args_dict = self.handle_ban_command(seg_data.get("args"), group_info)
|
||||
case CommandType.GROUP_WHOLE_BAN.name:
|
||||
command, args_dict = self.handle_whole_ban_command(seg_data.get("args"), group_info)
|
||||
case CommandType.GROUP_KICK.name:
|
||||
command, args_dict = self.handle_kick_command(seg_data.get("args"), group_info)
|
||||
case CommandType.SEND_POKE.name:
|
||||
command, args_dict = self.handle_poke_command(seg_data.get("args"), group_info)
|
||||
case CommandType.DELETE_MSG.name:
|
||||
command, args_dict = self.delete_msg_command(seg_data.get("args"))
|
||||
case CommandType.AI_VOICE_SEND.name:
|
||||
command, args_dict = self.handle_ai_voice_send_command(seg_data.get("args"), group_info)
|
||||
case _:
|
||||
logger.error(f"未知命令: {command_name}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时发生错误: {e}")
|
||||
return None
|
||||
|
||||
if not command or not args_dict:
|
||||
logger.error("命令或参数缺失")
|
||||
return None
|
||||
|
||||
response = await self.send_message_to_napcat(command, args_dict)
|
||||
if response.get("status") == "ok":
|
||||
logger.info(f"命令 {command_name} 执行成功")
|
||||
else:
|
||||
logger.warning(f"命令 {command_name} 执行失败,napcat返回:{str(response)}")
|
||||
|
||||
    def get_level(self, seg_data: Seg) -> int:
        if seg_data.type == "seglist":
            # default=0 keeps an empty seglist from raising ValueError in max()
            return 1 + max((self.get_level(seg) for seg in seg_data.data), default=0)
        else:
            return 1
|
||||
|
||||
async def handle_seg_recursive(self, seg_data: Seg) -> list:
|
||||
payload: list = []
|
||||
if seg_data.type == "seglist":
|
||||
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
|
||||
if not seg_data.data:
|
||||
return []
|
||||
for seg in seg_data.data:
|
||||
payload = self.process_message_by_type(seg, payload)
|
||||
else:
|
||||
payload = self.process_message_by_type(seg_data, payload)
|
||||
return payload
|
||||
|
||||
def process_message_by_type(self, seg: Seg, payload: list) -> list:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
new_payload = payload
|
||||
if seg.type == "reply":
|
||||
target_id = seg.data
|
||||
if target_id == "notice":
|
||||
return payload
|
||||
new_payload = self.build_payload(payload, self.handle_reply_message(target_id), True)
|
||||
elif seg.type == "text":
|
||||
text = seg.data
|
||||
if not text:
|
||||
return payload
|
||||
new_payload = self.build_payload(payload, self.handle_text_message(text), False)
|
||||
elif seg.type == "face":
|
||||
logger.warning("MaiBot 发送了qq原生表情,暂时不支持")
|
||||
elif seg.type == "image":
|
||||
image = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_image_message(image), False)
|
||||
elif seg.type == "emoji":
|
||||
emoji = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
|
||||
elif seg.type == "voice":
|
||||
voice = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
|
||||
elif seg.type == "voiceurl":
|
||||
voice_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
|
||||
elif seg.type == "music":
|
||||
song_id = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
|
||||
elif seg.type == "videourl":
|
||||
video_url = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
|
||||
elif seg.type == "file":
|
||||
file_path = seg.data
|
||||
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
|
||||
return new_payload
|
||||
|
||||
    def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list:
        # sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
        """Build the message body to send."""
        if not addon:
            # Handlers such as handle_voice_message return {} when there is nothing to send; skip it.
            return payload
        if is_reply:
            temp_list = []
            temp_list.append(addon)
            for i in payload:
                if i.get("type") == "reply":
                    logger.debug("检测到多个回复,使用最新的回复")
                    continue
                temp_list.append(i)
            return temp_list
        else:
            payload.append(addon)
            return payload
|
||||
|
||||
def handle_reply_message(self, id: str) -> dict:
|
||||
"""处理回复消息"""
|
||||
return {"type": "reply", "data": {"id": id}}
|
||||
|
||||
def handle_text_message(self, message: str) -> dict:
|
||||
"""处理文本消息"""
|
||||
return {"type": "text", "data": {"text": message}}
|
||||
|
||||
def handle_image_message(self, encoded_image: str) -> dict:
|
||||
"""处理图片消息"""
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 0,
|
||||
},
|
||||
} # base64 编码的图片
|
||||
|
||||
def handle_emoji_message(self, encoded_emoji: str) -> dict:
|
||||
"""处理表情消息"""
|
||||
encoded_image = encoded_emoji
|
||||
image_format = get_image_format(encoded_emoji)
|
||||
if image_format != "gif":
|
||||
encoded_image = convert_image_to_gif(encoded_emoji)
|
||||
return {
|
||||
"type": "image",
|
||||
"data": {
|
||||
"file": f"base64://{encoded_image}",
|
||||
"subtype": 1,
|
||||
"summary": "[动画表情]",
|
||||
},
|
||||
}
|
||||
|
||||
def handle_voice_message(self, encoded_voice: str) -> dict:
|
||||
"""处理语音消息"""
|
||||
if not global_config.voice.use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
if not encoded_voice:
|
||||
return {}
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": f"base64://{encoded_voice}"},
|
||||
}
|
||||
|
||||
def handle_voiceurl_message(self, voice_url: str) -> dict:
|
||||
"""处理语音链接消息"""
|
||||
return {
|
||||
"type": "record",
|
||||
"data": {"file": voice_url},
|
||||
}
|
||||
|
||||
def handle_music_message(self, song_id: str) -> dict:
|
||||
"""处理音乐消息"""
|
||||
return {
|
||||
"type": "music",
|
||||
"data": {"type": "163", "id": song_id},
|
||||
}
|
||||
def handle_videourl_message(self, video_url: str) -> dict:
|
||||
"""处理视频链接消息"""
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {"file": video_url},
|
||||
}
|
||||
|
||||
def handle_file_message(self, file_path: str) -> dict:
|
||||
"""处理文件消息"""
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {"file": f"file://{file_path}"},
|
||||
}
|
||||
|
||||
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理封禁命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]
|
||||
"""
|
||||
duration: int = int(args["duration"])
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if duration < 0:
|
||||
raise ValueError("封禁时间必须大于等于0")
|
||||
if not user_id or not group_id:
|
||||
raise ValueError("封禁命令缺少必要参数")
|
||||
if duration > 2592000:
|
||||
raise ValueError("封禁时间不能超过30天")
|
||||
return (
|
||||
CommandType.GROUP_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理全体禁言命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]
|
||||
"""
|
||||
enable = args["enable"]
|
||||
assert isinstance(enable, bool), "enable参数必须是布尔值"
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
return (
|
||||
CommandType.GROUP_WHOLE_BAN.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"enable": enable,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理群成员踢出命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.GROUP_KICK.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"reject_add_request": False, # 不拒绝加群请求
|
||||
},
|
||||
)
|
||||
|
||||
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理戳一戳命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
group_info (GroupInfo): 群聊信息(对应目标群聊)
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]
|
||||
"""
|
||||
user_id: int = int(args["qq_id"])
|
||||
if group_info is None:
|
||||
group_id = None
|
||||
else:
|
||||
group_id: int = int(group_info.group_id)
|
||||
if group_id <= 0:
|
||||
raise ValueError("群组ID无效")
|
||||
if user_id <= 0:
|
||||
raise ValueError("用户ID无效")
|
||||
return (
|
||||
CommandType.SEND_POKE.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""处理撤回消息命令
|
||||
|
||||
Args:
|
||||
args (Dict[str, Any]): 参数字典
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]
|
||||
"""
|
||||
try:
|
||||
message_id = int(args["message_id"])
|
||||
if message_id <= 0:
|
||||
raise ValueError("消息ID无效")
|
||||
except KeyError:
|
||||
raise ValueError("缺少必需参数: message_id") from None
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"消息ID无效: {args['message_id']} - {str(e)}") from None
|
||||
|
||||
return (
|
||||
CommandType.DELETE_MSG.value,
|
||||
{
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
|
||||
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""Handle the AI voice send command and return a NapCat-compatible (action, params) tuple."""
|
||||
if not group_info or not group_info.group_id:
|
||||
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
|
||||
if not args:
|
||||
raise ValueError("AI语音发送命令缺少参数")
|
||||
|
||||
group_id: int = int(group_info.group_id)
|
||||
character_id = args.get("character")
|
||||
text_content = args.get("text")
|
||||
|
||||
if not character_id or not text_content:
|
||||
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
|
||||
|
||||
return (
|
||||
CommandType.AI_VOICE_SEND.value,
|
||||
{
|
||||
"group_id": group_id,
|
||||
"text": text_content,
|
||||
"character": character_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
|
||||
await self.server_connection.send(payload)
|
||||
try:
|
||||
response = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("发送消息超时,未收到响应")
|
||||
return {"status": "error", "message": "timeout"}
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
return response
|
||||
|
||||
    async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
        # Update additional_config to add the echo flag
        if message_base.message_info.additional_config is None:
            message_base.message_info.additional_config = {}

        message_base.message_info.additional_config["echo"] = True

        # Keep the original mmc_message_id
        mmc_message_id = message_base.message_info.message_id

        # Replace message_segment with a notify segment
        message_base.message_segment = Seg(
            type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
        )
        await message_send_instance.message_send(message_base)
        logger.debug("已回送消息ID")
        return
|
||||
|
||||
|
||||
send_handler = SendHandler()
|
||||
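`SendHandler` assembles the outgoing message as a list of segment dictionaries, with a reply segment placed first and any earlier reply segment dropped. The following is a standalone sketch of that ordering rule, written as a pure function that returns a new list. It mirrors the shape of `build_payload` but is not the class method itself, and skipping empty segments is an assumption of the sketch.

```python
def build_payload(payload, addon, is_reply=False):
    """Return a new segment list with `addon` merged in."""
    if not addon:
        # Nothing to add, e.g. a handler returned an empty dict.
        return list(payload)
    if is_reply:
        # A reply segment leads the message and replaces any earlier reply.
        return [addon] + [seg for seg in payload if seg.get("type") != "reply"]
    return list(payload) + [addon]


segments = []
segments = build_payload(segments, {"type": "text", "data": {"text": "hello"}})
segments = build_payload(segments, {"type": "reply", "data": {"id": "42"}}, is_reply=True)
print([seg["type"] for seg in segments])  # prints ['reply', 'text']
```

Returning a new list on every call avoids the mixed behaviour where one branch mutates the input list and the other builds a fresh one.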
310
MaiBot-Napcat-Adapter-dev/src/utils.py
Normal file
@@ -0,0 +1,310 @@
|
||||
import websockets as Server
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
from src.database import BanUser, db_manager
|
||||
from .logger import logger
|
||||
from .response_pool import get_response
|
||||
|
||||
from PIL import Image
|
||||
from typing import Union, List, Tuple, Optional
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
context = ssl.create_default_context()
|
||||
context.set_ciphers("DEFAULT@SECLEVEL=1")
|
||||
context.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
kwargs["ssl_context"] = context
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群相关信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群聊信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
|
||||
"""
|
||||
获取群详细信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群详细信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取群详细信息超时,群号: {group_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群详细信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取群成员信息
|
||||
|
||||
返回值需要处理可能为空的情况
|
||||
"""
|
||||
logger.debug("获取群成员信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_group_member_info",
|
||||
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
socket_response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取成员信息失败: {e}")
|
||||
return None
|
||||
logger.debug(socket_response)
|
||||
return socket_response.get("data")
|
||||
|
||||
|
||||
async def get_image_base64(url: str) -> str:
|
||||
# sourcery skip: raise-specific-error
|
||||
"""获取图片/表情包的Base64"""
|
||||
logger.debug(f"下载图片: {url}")
|
||||
http = SSLAdapter()
|
||||
try:
|
||||
response = http.request("GET", url, timeout=10)
|
||||
if response.status != 200:
|
||||
raise Exception(f"HTTP Error: {response.status}")
|
||||
image_bytes = response.data
|
||||
return base64.b64encode(image_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片下载失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_image_to_gif(image_base64: str) -> str:
|
||||
# sourcery skip: extract-method
|
||||
"""
|
||||
将Base64编码的图片转换为GIF格式
|
||||
Parameters:
|
||||
image_base64: str: Base64编码的图片数据
|
||||
Returns:
|
||||
str: Base64编码的GIF图片数据
|
||||
"""
|
||||
logger.debug("转换图片为GIF格式")
|
||||
try:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="GIF")
|
||||
output_buffer.seek(0)
|
||||
return base64.b64encode(output_buffer.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为GIF失败: {str(e)}")
|
||||
return image_base64
|
||||
|
||||
|
||||
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
"""
|
||||
获取自身信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
Returns:
|
||||
data: dict: 返回的自身信息
|
||||
"""
|
||||
logger.debug("获取自身信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error("获取自身信息超时")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取自身信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
def get_image_format(raw_data: str) -> str:
|
||||
"""
|
||||
从Base64编码的数据中确定图片的格式。
|
||||
Parameters:
|
||||
raw_data: str: Base64编码的图片数据。
|
||||
Returns:
|
||||
format: str: 图片的格式(例如 'jpeg', 'png', 'gif')。
|
||||
"""
|
||||
image_bytes = base64.b64decode(raw_data)
|
||||
return Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
|
||||
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
|
||||
"""
|
||||
获取陌生人信息
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
user_id: 用户ID
|
||||
Returns:
|
||||
dict: 返回的陌生人信息
|
||||
"""
|
||||
logger.debug("获取陌生人信息中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid)
|
||||
except TimeoutError:
|
||||
logger.error(f"获取陌生人信息超时,用户ID: {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取陌生人信息失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
|
||||
"""
|
||||
获取消息详情,可能为空
|
||||
Parameters:
|
||||
websocket: WebSocket连接对象
|
||||
message_id: 消息ID
|
||||
Returns:
|
||||
dict: 返回的消息详情
|
||||
"""
|
||||
logger.debug("获取消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # Raise the timeout to 30 seconds
|
||||
except TimeoutError:
|
||||
logger.error(f"获取消息详情超时,消息ID: {message_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(response)
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def get_record_detail(
|
||||
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
|
||||
) -> dict | None:
|
||||
"""
|
||||
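Fetch the content of a voice message
|
||||
Parameters:
|
||||
websocket: the WebSocket connection object
|
||||
file: the file name
|
||||
file_id: the file ID
|
||||
Returns:
|
||||
dict: the voice message details, or None on timeout or failure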
|
||||
"""
|
||||
logger.debug("获取语音消息详情中")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
{
|
||||
"action": "get_record",
|
||||
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
|
||||
"echo": request_uuid,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await websocket.send(payload)
|
||||
response: dict = await get_response(request_uuid, 30) # Raise the timeout to 30 seconds
|
||||
except TimeoutError:
|
||||
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取语音消息详情失败: {e}")
|
||||
return None
|
||||
logger.debug(f"{str(response)[:200]}...") # Truncate so the very long base64 audio payload does not flood the log
|
||||
return response.get("data")
|
||||
|
||||
|
||||
async def read_ban_list(
|
||||
websocket: Server.ServerConnection,
|
||||
) -> Tuple[List[BanUser], List[BanUser]]:
|
||||
"""
|
||||
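Read the ban list from the file in the data folder under the project root (loaded through db_manager).
|
||||
Bans that have already expired are refreshed automatically along the way.
|
||||
Returns:
|
||||
|
||||
Tuple[
|
||||
a list of BanUser records still in effect (member bans and whole-group mutes),
|
||||
a list of BanUser records whose ban has already been lifted (members and groups),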
|
||||
]
|
||||
"""
|
||||
try:
|
||||
ban_list = db_manager.get_ban_records()
|
||||
lifted_list: List[BanUser] = []
|
||||
logger.info("已经读取禁言列表")
|
||||
for ban_record in list(ban_list):  # iterate over a copy, because records are removed from ban_list below
|
||||
if ban_record.user_id == 0:
|
||||
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
|
||||
if fetched_group_info is None:
|
||||
logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
group_all_shut: int = fetched_group_info.get("group_all_shut")
|
||||
if group_all_shut == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
else:
|
||||
fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
|
||||
if fetched_member_info is None:
|
||||
logger.warning(
|
||||
f"无法获取群成员信息,用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
|
||||
)
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
continue
|
||||
lift_ban_time: int = fetched_member_info.get("shut_up_timestamp")
|
||||
if lift_ban_time == 0:
|
||||
lifted_list.append(ban_record)
|
||||
ban_list.remove(ban_record)
|
||||
else:
|
||||
ban_record.lift_time = lift_ban_time
|
||||
db_manager.update_ban_record(ban_list)
|
||||
return ban_list, lifted_list
|
||||
except Exception as e:
|
||||
logger.error(f"读取禁言列表失败: {e}")
|
||||
return [], []
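
`read_ban_list` removes records from `ban_list` inside a loop over that list, and removing from a list while iterating over it makes Python skip the element that follows each removal. The sketch below contrasts that with a partition into two fresh lists. `BanRecord` is a hypothetical stand-in for `BanUser`, and the `is_lifted` callback replaces the Napcat lookups, so this is an illustration of the pitfall rather than the adapter's code.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class BanRecord:
    """Hypothetical stand-in for BanUser."""
    user_id: int
    group_id: int
    lift_time: int = -1


def partition_bans(
    records: List[BanRecord], is_lifted: Callable[[BanRecord], bool]
) -> Tuple[List[BanRecord], List[BanRecord]]:
    """Split records into (still_banned, lifted) without mutating the input."""
    still_banned: List[BanRecord] = []
    lifted: List[BanRecord] = []
    for record in records:
        (lifted if is_lifted(record) else still_banned).append(record)
    return still_banned, lifted


def naive_partition(
    records: List[BanRecord], is_lifted: Callable[[BanRecord], bool]
) -> Tuple[List[BanRecord], List[BanRecord]]:
    """Removes from the list being iterated, which skips elements."""
    records = list(records)
    lifted: List[BanRecord] = []
    for record in records:
        if is_lifted(record):
            lifted.append(record)
            records.remove(record)  # shifts later items left under the iterator
    return records, lifted


if __name__ == "__main__":
    sample = [BanRecord(user_id=i, group_id=100) for i in (1, 2, 3, 4)]
    lifted_ids = {1, 2}
    print(partition_bans(sample, lambda r: r.user_id in lifted_ids))
    print(naive_partition(sample, lambda r: r.user_id in lifted_ids))
```

In the naive version the record for user 2 is never examined, so it stays in the "still banned" list even though its ban has been lifted.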
|
||||
|
||||
|
||||
def save_ban_record(list: List[BanUser]):
|
||||
return db_manager.update_ban_record(list)
|
||||
34
MaiBot-Napcat-Adapter-dev/template/template_config.toml
Normal file
@@ -0,0 +1,34 @@
|
||||
[inner]
|
||||
version = "0.1.1" # Version number
|
||||
# Do not change the version number unless you know what you are doing
|
||||
|
||||
[nickname] # Currently unused
|
||||
nickname = ""
|
||||
|
||||
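[napcat_server] # Settings for the WebSocket service that Napcat connects to
|
||||
host = "localhost" # Host address configured in Napcat
|
||||
port = 8095 # Port configured in Napcat
|
||||
heartbeat_interval = 30 # Must match the heartbeat interval configured in Napcat (in seconds)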
|
||||
|
||||
[maibot_server] # Settings for the WebSocket connection to MaiBot
|
||||
host = "localhost" # Host address set in MaiBot's .env file (the HOST field)
|
||||
port = 8000 # Port set in MaiBot's .env file (the PORT field)
|
||||
|
||||
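[chat] # Whitelist and blacklist settings
|
||||
group_list_type = "whitelist" # Group list type; one of: whitelist, blacklist
|
||||
group_list = [] # Group list
|
||||
# When group_list_type is whitelist, only groups in the group list can chat
|
||||
# When group_list_type is blacklist, no group in the group list can chat
|
||||
private_list_type = "whitelist" # Private chat list type; one of: whitelist, blacklist
|
||||
private_list = [] # Private chat list
|
||||
# When private_list_type is whitelist, only users in the private chat list can chat
|
||||
# When private_list_type is blacklist, no user in the private chat list can chat
|
||||
ban_user_id = [] # Global ban list (users on this list cannot chat at all)
|
||||
ban_qq_bot = false # Whether to block official QQ bots
|
||||
enable_poke = true # Whether to enable the poke feature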
|
||||
|
||||
[voice] # Voice sending settings
|
||||
use_tts = false # Whether to use TTS voice (make sure TTS is configured and a matching adapter is available)
|
||||
|
||||
[debug]
|
||||
level = "INFO" # Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
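
The comments in the `[chat]` section describe how the whitelist, blacklist and global ban list interact. The template does not show how the adapter evaluates them, so the following is a minimal sketch of logic that matches those comments. `is_chat_allowed` and `should_process` are hypothetical helpers, the config is a plain dict mirroring the TOML keys, and checking the global ban list first is an assumption.

```python
from typing import List, Optional


def is_chat_allowed(list_type: str, id_list: List[int], target_id: int) -> bool:
    """Apply one whitelist/blacklist rule to a group ID or user ID."""
    if list_type == "whitelist":
        return target_id in id_list
    if list_type == "blacklist":
        return target_id not in id_list
    raise ValueError(f"unknown list type: {list_type!r}")


def should_process(config: dict, user_id: int, group_id: Optional[int] = None) -> bool:
    """Decide whether a message should be handled, per the [chat] comments."""
    chat = config["chat"]
    # Assumption: the global ban list wins over every other rule.
    if user_id in chat["ban_user_id"]:
        return False
    if group_id is not None:
        return is_chat_allowed(chat["group_list_type"], chat["group_list"], group_id)
    return is_chat_allowed(chat["private_list_type"], chat["private_list"], user_id)


if __name__ == "__main__":
    config = {
        "chat": {
            "group_list_type": "whitelist",
            "group_list": [123456],
            "private_list_type": "blacklist",
            "private_list": [42],
            "ban_user_id": [999],
        }
    }
    print(should_process(config, user_id=1, group_id=123456))
    print(should_process(config, user_id=42))
```

Under this reading, the template defaults (both types set to `whitelist` with empty lists) would allow no chats at all until IDs are added.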
|
||||
0
identifier.sqlite
Normal file
@@ -8,9 +8,6 @@ from src.plugin_system import (
|
||||
ComponentInfo,
|
||||
ActionActivationType,
|
||||
ConfigField,
|
||||
BaseEventHandler,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType
|
||||
)
|
||||
|
||||
@@ -126,22 +123,23 @@ class TimeCommand(BaseCommand):
|
||||
message = f"⏰ 当前时间:{time_str}"
|
||||
await self.send_text(message)
|
||||
|
||||
return True, f"显示了当前时间: {time_str}", True
|
||||
return True, f"显示了当前x时间: {time_str}", True
|
||||
|
||||
|
||||
class PrintMessage(BaseEventHandler):
|
||||
"""打印消息事件处理器 - 处理打印消息事件"""
|
||||
|
||||
event_type = EventType.ON_MESSAGE
|
||||
handler_name = "print_message_handler"
|
||||
handler_description = "打印接收到的消息"
|
||||
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
||||
"""执行打印消息事件处理"""
|
||||
# 打印接收到的消息
|
||||
if self.get_config("print_message.enabled", False):
|
||||
print(f"接收到消息: {message.raw_message}")
|
||||
return True, True, "消息已打印"
|
||||
# class PrintMessage(BaseEventHandler):
|
||||
# """打印消息事件处理器 - 处理打印消息事件"""
|
||||
#
|
||||
# event_type = EventType.ON_MESSAGE
|
||||
# handler_name = "print_message_handler"
|
||||
# handler_description = "打印接收到的消息"
|
||||
#
|
||||
# async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
|
||||
# """执行打印消息事件处理"""
|
||||
# # 打印接收到的消息
|
||||
#
|
||||
# if self.get_config("print_message.enabled", False):
|
||||
# print(f"接收到消息: {message.raw_message}")
|
||||
# return True, True, "消息已打印1"
|
||||
|
||||
|
||||
# ===== 插件注册 =====
|
||||
@@ -184,7 +182,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
|
||||
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
|
||||
(TimeCommand.get_command_info(), TimeCommand),
|
||||
(PrintMessage.get_handler_info(), PrintMessage),
|
||||
# (PrintMessage.get_handler_info(), PrintMessage),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.14",
|
||||
"aiohttp-cors>=0.8.1",
|
||||
"apscheduler>=3.11.0",
|
||||
"colorama>=0.4.6",
|
||||
"cryptography>=45.0.5",
|
||||
@@ -12,6 +13,8 @@ dependencies = [
|
||||
"dotenv>=0.9.9",
|
||||
"faiss-cpu>=1.11.0",
|
||||
"fastapi>=0.116.0",
|
||||
"google>=3.0.0",
|
||||
"google-genai>=1.29.0",
|
||||
"jieba>=0.42.1",
|
||||
"json-repair>=0.47.6",
|
||||
"jsonlines>=4.0.0",
|
||||
@@ -22,12 +25,12 @@ dependencies = [
|
||||
"openai>=1.95.0",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"psutil>=7.0.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pymongo>=4.13.2",
|
||||
"pymysql>=1.1.1",
|
||||
"pypinyin>=0.54.0",
|
||||
"python-dateutil>=2.9.0.post0",
|
||||
"python-dotenv>=1.1.1",
|
||||
@@ -41,6 +44,7 @@ dependencies = [
|
||||
"scipy>=1.15.3",
|
||||
"seaborn>=0.13.2",
|
||||
"setuptools>=80.9.0",
|
||||
"sqlalchemy>=2.0.42",
|
||||
"strawberry-graphql[fastapi]>=0.275.5",
|
||||
"structlog>=25.4.0",
|
||||
"toml>=0.10.2",
|
||||
@@ -50,9 +54,13 @@ dependencies = [
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"watchdog>=6.0.0",
|
||||
"websockets>=15.0.1",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
default = true
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -96,3 +104,8 @@ skip-magic-trailing-comma = false
|
||||
|
||||
# 自动检测合适的换行符
|
||||
line-ending = "auto"
|
||||
|
||||
[dependency-groups]
|
||||
lint = [
|
||||
"loguru>=0.7.3",
|
||||
]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
sqlalchemy
|
||||
APScheduler
|
||||
Pillow
|
||||
aiohttp
|
||||
@@ -47,3 +48,4 @@ reportportal-client
|
||||
scikit-learn
|
||||
seaborn
|
||||
structlog
|
||||
watchdog
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
|
||||
|
||||
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
# 直接从数据库查询ChatStreams表
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
# 如果有群组信息,显示群组名称
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
# 如果是私聊,显示用户昵称
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of last active time in days"""
|
||||
now = time.time()
|
||||
distribution = {
|
||||
'0-1天': 0,
|
||||
'1-3天': 0,
|
||||
'3-7天': 0,
|
||||
'7-14天': 0,
|
||||
'14-30天': 0,
|
||||
'30-60天': 0,
|
||||
'60-90天': 0,
|
||||
'90+天': 0
|
||||
}
|
||||
for expr in expressions:
|
||||
diff_days = (now - expr.last_active_time) / (24*3600)
|
||||
if diff_days < 1:
|
||||
distribution['0-1天'] += 1
|
||||
elif diff_days < 3:
|
||||
distribution['1-3天'] += 1
|
||||
elif diff_days < 7:
|
||||
distribution['3-7天'] += 1
|
||||
elif diff_days < 14:
|
||||
distribution['7-14天'] += 1
|
||||
elif diff_days < 30:
|
||||
distribution['14-30天'] += 1
|
||||
elif diff_days < 60:
|
||||
distribution['30-60天'] += 1
|
||||
elif diff_days < 90:
|
||||
distribution['60-90天'] += 1
|
||||
else:
|
||||
distribution['90+天'] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of count values"""
|
||||
distribution = {
|
||||
'0-1': 0,
|
||||
'1-2': 0,
|
||||
'2-3': 0,
|
||||
'3-4': 0,
|
||||
'4-5': 0,
|
||||
'5-10': 0,
|
||||
'10+': 0
|
||||
}
|
||||
for expr in expressions:
|
||||
cnt = expr.count
|
||||
if cnt < 1:
|
||||
distribution['0-1'] += 1
|
||||
elif cnt < 2:
|
||||
distribution['1-2'] += 1
|
||||
elif cnt < 3:
|
||||
distribution['2-3'] += 1
|
||||
elif cnt < 4:
|
||||
distribution['3-4'] += 1
|
||||
elif cnt < 5:
|
||||
distribution['4-5'] += 1
|
||||
elif cnt < 10:
|
||||
distribution['5-10'] += 1
|
||||
else:
|
||||
distribution['10+'] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
|
||||
"""Get top N most used expressions for a specific chat_id"""
|
||||
return (Expression.select()
|
||||
.where(Expression.chat_id == chat_id)
|
||||
.order_by(Expression.count.desc())
|
||||
.limit(top_n))
|
||||
|
||||
|
||||
def show_overall_statistics(expressions, total: int) -> None:
|
||||
"""Show overall statistics"""
|
||||
time_dist = calculate_time_distribution(expressions)
|
||||
count_dist = calculate_count_distribution(expressions)
|
||||
|
||||
print("\n=== 总体统计 ===")
|
||||
print(f"总表达式数量: {total}")
|
||||
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
print(f"{period}: {count} ({count/total*100:.2f}%)")
|
||||
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
print(f"{range_}: {count} ({count/total*100:.2f}%)")
|
||||
|
||||
|
||||
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
|
||||
"""Show statistics for a specific chat"""
|
||||
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
|
||||
chat_total = len(chat_exprs)
|
||||
|
||||
print(f"\n=== {chat_name} ===")
|
||||
print(f"表达式数量: {chat_total}")
|
||||
|
||||
if chat_total == 0:
|
||||
print("该聊天没有表达式数据")
|
||||
return
|
||||
|
||||
# Time distribution for this chat
|
||||
time_dist = calculate_time_distribution(chat_exprs)
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
if count > 0:
|
||||
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
# Count distribution for this chat
|
||||
count_dist = calculate_count_distribution(chat_exprs)
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
if count > 0:
|
||||
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
# Top expressions
|
||||
print("\nTop 10使用最多的表达式:")
|
||||
top_exprs = get_top_expressions_by_chat(chat_id, 10)
|
||||
for i, expr in enumerate(top_exprs, 1):
|
||||
print(f"{i}. [{expr.type}] Count: {expr.count}")
|
||||
print(f" Situation: {expr.situation}")
|
||||
print(f" Style: {expr.style}")
|
||||
print()
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for expression statistics"""
|
||||
# Get all expressions
|
||||
expressions = list(Expression.select())
|
||||
if not expressions:
|
||||
print("数据库中没有找到表达式")
|
||||
return
|
||||
|
||||
total = len(expressions)
|
||||
|
||||
# Get unique chat_ids and their names
|
||||
chat_ids = list(set(expr.chat_id for expr in expressions))
|
||||
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
|
||||
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("表达式统计分析")
|
||||
print("="*50)
|
||||
print("0. 显示总体统计")
|
||||
|
||||
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
|
||||
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
|
||||
print(f"{i}. {chat_name} ({chat_count}个表达式)")
|
||||
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
try:
|
||||
choice_num = int(choice)
|
||||
if choice_num == 0:
|
||||
show_overall_statistics(expressions, total)
|
||||
elif 1 <= choice_num <= len(chat_info):
|
||||
chat_id, chat_name = chat_info[choice_num - 1]
|
||||
show_chat_statistics(chat_id, chat_name)
|
||||
else:
|
||||
print("无效的选择,请重新输入")
|
||||
except ValueError:
|
||||
print("请输入有效的数字")
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
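
The two distribution functions in this script walk long `if`/`elif` chains over fixed bucket edges. As an alternative sketch (not the script's own approach), the same half-open buckets can be looked up with `bisect_right`. `DAY_EDGES`, `DAY_LABELS`, `bucket_label` and `distribution` are names introduced here for illustration.

```python
from bisect import bisect_right
from typing import Dict, Iterable, List, Sequence

# Upper edges of each bucket; a value equal to an edge falls into the next bucket,
# matching the strict "<" comparisons in the original chain.
DAY_EDGES: List[float] = [1, 3, 7, 14, 30, 60, 90]
DAY_LABELS: List[str] = ["0-1天", "1-3天", "3-7天", "7-14天", "14-30天", "30-60天", "60-90天", "90+天"]


def bucket_label(value: float, edges: Sequence[float], labels: Sequence[str]) -> str:
    """Return the label of the bucket that contains value."""
    return labels[bisect_right(edges, value)]


def distribution(values: Iterable[float], edges: Sequence[float], labels: Sequence[str]) -> Dict[str, int]:
    """Count values per bucket, keeping empty buckets in label order."""
    counts = {label: 0 for label in labels}
    for value in values:
        counts[bucket_label(value, edges, labels)] += 1
    return counts


if __name__ == "__main__":
    print(distribution([0.5, 1, 2.9, 45, 120], DAY_EDGES, DAY_LABELS))
```

Keeping edges and labels in two parallel lists means a new bucket is a one-line change instead of an extra `elif` branch.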
|
||||
@@ -1,217 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from threading import Lock, Event
|
||||
import sys
|
||||
import datetime
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
# 添加项目根目录到 sys.path
|
||||
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("LPMM知识库-信息提取")
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
os.makedirs(TEMP_DIR)
|
||||
logger.info(f"已创建临时目录: {TEMP_DIR}")
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
if not os.path.exists(RAW_DATA_PATH):
|
||||
os.makedirs(RAW_DATA_PATH)
|
||||
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
|
||||
# 创建一个事件标志,用于控制程序终止
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract,
|
||||
request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_rdf_build,
|
||||
request_type="lpmm.rdf_build"
|
||||
)
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
# 使用文件锁检查和读取缓存文件
|
||||
with file_lock:
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
# 存在对应的提取结果
|
||||
logger.info(f"找到缓存的提取结果:{pg_hash}")
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f), None
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON文件损坏,删除它并重新处理
|
||||
logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
|
||||
os.remove(temp_file_path)
|
||||
|
||||
entity_list, rdf_triple_list = info_extract_from_str(
|
||||
lpmm_entity_extract_llm,
|
||||
lpmm_rdf_build_llm,
|
||||
raw_data,
|
||||
)
|
||||
if entity_list is None or rdf_triple_list is None:
|
||||
return None, pg_hash
|
||||
doc_item = {
|
||||
"idx": pg_hash,
|
||||
"passage": raw_data,
|
||||
"extracted_entities": entity_list,
|
||||
"extracted_triples": rdf_triple_list,
|
||||
}
|
||||
# 保存临时提取结果
|
||||
with file_lock:
|
||||
try:
|
||||
with open(temp_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(doc_item, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
|
||||
# 如果保存失败,确保不会留下损坏的文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
sys.exit(0)
|
||||
return None, pg_hash
|
||||
return doc_item, None
|
||||
|
||||
|
||||
def signal_handler(_signum, _frame):
|
||||
"""处理Ctrl+C信号"""
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
|
||||
print("建议使用硅基流动的非Pro模型")
|
||||
print("或者使用可以用赠金抵扣的Pro模型")
|
||||
print("请确保账户余额充足,并且在执行前确认无误。")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
ensure_dirs() # 确保目录存在
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
|
||||
# 加载原始数据
|
||||
logger.info("正在加载原始数据")
|
||||
all_sha256_list, all_raw_datas = load_raw_data()
|
||||
|
||||
failed_sha256 = []
|
||||
open_ie_doc = []
|
||||
|
||||
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_hash = {
|
||||
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||
}
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task("正在进行提取:", total=len(future_to_hash))
|
||||
try:
|
||||
for future in as_completed(future_to_hash):
|
||||
if shutdown_event.is_set():
|
||||
for f in future_to_hash:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
doc_item, failed_hash = future.result()
|
||||
if failed_hash:
|
||||
failed_sha256.append(failed_hash)
|
||||
logger.error(f"提取失败:{failed_hash}")
|
||||
elif doc_item:
|
||||
with open_ie_doc_lock:
|
||||
open_ie_doc.append(doc_item)
|
||||
progress.update(task, advance=1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
shutdown_event.set()
|
||||
for f in future_to_hash:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
|
||||
# 合并所有文件的提取结果并保存
|
||||
if open_ie_doc:
|
||||
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
|
||||
openie_obj = OpenIE(
|
||||
open_ie_doc,
|
||||
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
|
||||
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
|
||||
)
|
||||
# 输出文件名格式:MM-DD-HH-ss-openie.json
|
||||
now = datetime.datetime.now()
|
||||
filename = now.strftime("%m-%d-%H-%S-openie.json")
|
||||
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
)
|
||||
logger.info(f"信息提取结果已保存到: {output_path}")
|
||||
else:
|
||||
logger.warning("没有可保存的信息提取结果")
|
||||
|
||||
logger.info("--------信息提取完成--------")
|
||||
logger.info(f"提取失败的文段SHA256:{failed_sha256}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
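
`process_single_text` caches each extraction result as JSON and deletes the file if it turns out to be corrupted or the write fails. One way to avoid half-written cache files in the first place is to write to a temporary file and rename it into place. The following is a sketch of that technique, not what the script does; `write_json_atomic` and `load_cached` are hypothetical helpers.

```python
import json
import os
import tempfile
from typing import Any, Optional


def write_json_atomic(path: str, obj: Any) -> None:
    """Write obj as JSON so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file lives in the target directory so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then let the caller see the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_cached(path: str) -> Optional[Any]:
    """Return the cached object, or None if the file is missing or not valid JSON."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        cache_path = os.path.join(tmp_dir, "abc123.json")
        write_json_atomic(cache_path, {"idx": "abc123", "extracted_entities": ["麦麦"]})
        print(load_cached(cache_path))
```

With this approach an interrupted write leaves only a stray `.tmp` file rather than a truncated cache entry, so the reader does not need to detect and delete corrupted JSON.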
|
||||
@@ -1,287 +0,0 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
|
||||
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of interest_value"""
|
||||
distribution = {
|
||||
'0.000-0.010': 0,
|
||||
'0.010-0.050': 0,
|
||||
'0.050-0.100': 0,
|
||||
'0.100-0.500': 0,
|
||||
'0.500-1.000': 0,
|
||||
'1.000-2.000': 0,
|
||||
'2.000-5.000': 0,
|
||||
'5.000-10.000': 0,
|
||||
'10.000+': 0
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.interest_value is None or msg.interest_value == 0.0:
|
||||
continue
|
||||
|
||||
value = float(msg.interest_value)
|
||||
if value < 0.010:
|
||||
distribution['0.000-0.010'] += 1
|
||||
elif value < 0.050:
|
||||
distribution['0.010-0.050'] += 1
|
||||
elif value < 0.100:
|
||||
distribution['0.050-0.100'] += 1
|
||||
elif value < 0.500:
|
||||
distribution['0.100-0.500'] += 1
|
||||
elif value < 1.000:
|
||||
distribution['0.500-1.000'] += 1
|
||||
elif value < 2.000:
|
||||
distribution['1.000-2.000'] += 1
|
||||
elif value < 5.000:
|
||||
distribution['2.000-5.000'] += 1
|
||||
elif value < 10.000:
|
||||
distribution['5.000-10.000'] += 1
|
||||
else:
|
||||
distribution['10.000+'] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for interest_value"""
|
||||
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
|
||||
|
||||
if not values:
|
||||
return {
|
||||
'count': 0,
|
||||
'min': 0,
|
||||
'max': 0,
|
||||
'avg': 0,
|
||||
'median': 0
|
||||
}
|
||||
|
||||
values.sort()
|
||||
count = len(values)
|
||||
|
||||
return {
|
||||
'count': count,
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'avg': sum(values) / count,
|
||||
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
).count()
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"获取聊天列表失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
now = time.time()
|
||||
|
||||
if choice == "1":
|
||||
return now - 24*3600, now
|
||||
elif choice == "2":
|
||||
return now - 3*24*3600, now
|
||||
elif choice == "3":
|
||||
return now - 7*24*3600, now
|
||||
elif choice == "4":
|
||||
return now - 30*24*3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print("时间格式错误,将不限制时间范围")
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||
"""Analyze interest values with optional filters"""
|
||||
|
||||
# 构建查询条件
|
||||
query = Messages.select().where(
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
)
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
messages = list(query)
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_interest_value_distribution(messages)
|
||||
stats = get_interest_value_stats(messages)
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Interest Value 分析结果 ===")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
||||
elif end_time:
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
||||
print(f"最小值: {stats['min']:.3f}")
|
||||
print(f"最大值: {stats['max']:.3f}")
|
||||
print(f"平均值: {stats['avg']:.3f}")
|
||||
print(f"中位数: {stats['median']:.3f}")
|
||||
|
||||
print("\nInterest Value 分布:")
|
||||
total = stats['count']
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name}: {count} ({percentage:.2f}%)")
|
||||
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for interest value analysis"""
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("Interest Value 分析工具")
|
||||
print("="*50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
chat_id = None
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到有interest_value数据的聊天")
|
||||
continue
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条有效消息)")
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
chat_id = chats[chat_choice - 1][0]
|
||||
else:
|
||||
print("无效选择")
|
||||
continue
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
# 执行分析
|
||||
analyze_interest_values(chat_id, start_time, end_time)
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
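
`get_interest_value_stats` computes the mean and median by hand after filtering out null and zero values. The standard library's `statistics` module gives the same results for this case; the sketch below shows the equivalent computation over plain numbers. `interest_stats` is a hypothetical name, and it takes raw values instead of message rows so that it runs without the database models.

```python
from statistics import mean, median
from typing import Dict, Iterable, Optional


def interest_stats(values: Iterable[Optional[float]]) -> Dict[str, float]:
    """Basic statistics over interest values, ignoring None and 0.0 as the script does."""
    valid = [float(v) for v in values if v is not None and v != 0.0]
    if not valid:
        return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
    return {
        "count": len(valid),
        "min": min(valid),
        "max": max(valid),
        "avg": mean(valid),
        # median() averages the two middle values for an even count,
        # which is what the hand-written expression does.
        "median": median(valid),
    }


if __name__ == "__main__":
    print(interest_stats([None, 0.0, 1.0, 3.0, 2.0, 4.0]))
```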
|
||||
File diff suppressed because it is too large
@@ -1,237 +0,0 @@
|
||||
"""
|
||||
插件Manifest管理命令行工具
|
||||
|
||||
提供插件manifest文件的创建、验证和管理功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
# Add the project root to the Python path before importing from src
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.common.logger import get_logger  # noqa: E402
|
||||
from src.plugin_system.utils.manifest_utils import (  # noqa: E402
|
||||
ManifestValidator,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("manifest_tool")
|
||||
|
||||
|
||||
def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool:
|
||||
"""创建最小化的manifest文件
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录
|
||||
plugin_name: 插件名称
|
||||
description: 插件描述
|
||||
author: 插件作者
|
||||
|
||||
Returns:
|
||||
bool: 是否创建成功
|
||||
"""
|
||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
|
||||
if os.path.exists(manifest_path):
|
||||
print(f"❌ Manifest文件已存在: {manifest_path}")
|
||||
return False
|
||||
|
||||
# 创建最小化manifest
|
||||
minimal_manifest = {
|
||||
"manifest_version": 1,
|
||||
"name": plugin_name,
|
||||
"version": "1.0.0",
|
||||
"description": description or f"{plugin_name}插件",
|
||||
"author": {"name": author or "Unknown"},
|
||||
}
|
||||
|
||||
try:
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
json.dump(minimal_manifest, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ 已创建最小化manifest文件: {manifest_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ 创建manifest文件失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool:
|
||||
"""创建完整的manifest模板文件
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否创建成功
|
||||
"""
|
||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
|
||||
if os.path.exists(manifest_path):
|
||||
print(f"❌ Manifest文件已存在: {manifest_path}")
|
||||
return False
|
||||
|
||||
# 创建完整模板
|
||||
complete_manifest = {
|
||||
"manifest_version": 1,
|
||||
"name": plugin_name,
|
||||
"version": "1.0.0",
|
||||
"description": f"{plugin_name}插件描述",
|
||||
"author": {"name": "插件作者", "url": "https://github.com/your-username"},
|
||||
"license": "MIT",
|
||||
"host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
|
||||
"homepage_url": "https://github.com/your-repo",
|
||||
"repository_url": "https://github.com/your-repo",
|
||||
"keywords": ["keyword1", "keyword2"],
|
||||
"categories": ["Category1"],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": False,
|
||||
"plugin_type": "general",
|
||||
"components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}],
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
json.dump(complete_manifest, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ 已创建完整manifest模板: {manifest_path}")
|
||||
print("💡 请根据实际情况修改manifest文件中的内容")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ 创建manifest文件失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def validate_manifest_file(plugin_dir: str) -> bool:
|
||||
"""验证manifest文件
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录
|
||||
|
||||
Returns:
|
||||
bool: 是否验证通过
|
||||
"""
|
||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
print(f"❌ 未找到manifest文件: {manifest_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest_data = json.load(f)
|
||||
|
||||
validator = ManifestValidator()
|
||||
is_valid = validator.validate_manifest(manifest_data)
|
||||
|
||||
# 显示验证结果
|
||||
print("📋 Manifest验证结果:")
|
||||
print(validator.get_validation_report())
|
||||
|
||||
if is_valid:
|
||||
print("✅ Manifest文件验证通过")
|
||||
else:
|
||||
print("❌ Manifest文件验证失败")
|
||||
|
||||
return is_valid
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"❌ Manifest文件格式错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ 验证过程中发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def scan_plugins_without_manifest(root_dir: str) -> None:
|
||||
"""扫描缺少manifest文件的插件
|
||||
|
||||
Args:
|
||||
root_dir: 扫描的根目录
|
||||
"""
|
||||
print(f"🔍 扫描目录: {root_dir}")
|
||||
|
||||
plugins_without_manifest = []
|
||||
|
||||
for root, dirs, files in os.walk(root_dir):
|
||||
# 跳过隐藏目录和__pycache__
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
|
||||
|
||||
# 检查是否包含plugin.py文件(标识为插件目录)
|
||||
if "plugin.py" in files:
|
||||
manifest_path = os.path.join(root, "_manifest.json")
|
||||
if not os.path.exists(manifest_path):
|
||||
plugins_without_manifest.append(root)
|
||||
|
||||
if plugins_without_manifest:
|
||||
print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:")
|
||||
for plugin_dir in plugins_without_manifest:
|
||||
plugin_name = os.path.basename(plugin_dir)
|
||||
print(f" - {plugin_name}: {plugin_dir}")
|
||||
print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件")
|
||||
else:
|
||||
print("✅ 所有插件都有manifest文件")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description="插件Manifest管理工具")
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 创建最小化manifest命令
|
||||
create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件")
|
||||
create_minimal_parser.add_argument("plugin_dir", help="插件目录路径")
|
||||
create_minimal_parser.add_argument("--name", help="插件名称")
|
||||
create_minimal_parser.add_argument("--description", help="插件描述")
|
||||
create_minimal_parser.add_argument("--author", help="插件作者")
|
||||
|
||||
# 创建完整manifest命令
|
||||
create_complete_parser = subparsers.add_parser("create-complete", help="创建完整manifest模板")
|
||||
create_complete_parser.add_argument("plugin_dir", help="插件目录路径")
|
||||
create_complete_parser.add_argument("--name", help="插件名称")
|
||||
|
||||
# 验证manifest命令
|
||||
validate_parser = subparsers.add_parser("validate", help="验证manifest文件")
|
||||
validate_parser.add_argument("plugin_dir", help="插件目录路径")
|
||||
|
||||
# 扫描插件命令
|
||||
scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件")
|
||||
scan_parser.add_argument("root_dir", help="扫描的根目录路径")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
try:
|
||||
if args.command == "create-minimal":
|
||||
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
|
||||
success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "")
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
elif args.command == "create-complete":
|
||||
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
|
||||
success = create_complete_manifest(args.plugin_dir, plugin_name)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
elif args.command == "validate":
|
||||
success = validate_manifest_file(args.plugin_dir)
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
elif args.command == "scan":
|
||||
scan_plugins_without_manifest(args.root_dir)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 执行命令时发生错误: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
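For orientation, the `create-minimal` and `validate` commands above can be approximated without the project's `ManifestValidator`. The sketch below writes the same minimal manifest to a temporary directory and checks that the expected keys are present. `REQUIRED_KEYS`, `write_minimal_manifest`, and `missing_keys` are assumptions made for illustration; they are not the project's actual validation rules.

```python
import json
import os
import tempfile

# Assumed for illustration: the keys that create_minimal_manifest writes.
REQUIRED_KEYS = ("manifest_version", "name", "version", "description", "author")


def write_minimal_manifest(plugin_dir: str, plugin_name: str) -> str:
    """Write a minimal _manifest.json and return its path."""
    manifest = {
        "manifest_version": 1,
        "name": plugin_name,
        "version": "1.0.0",
        "description": f"{plugin_name} plugin",
        "author": {"name": "Unknown"},
    }
    path = os.path.join(plugin_dir, "_manifest.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    return path


def missing_keys(path: str) -> list:
    """Return the required keys that are absent from the manifest at `path`."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [key for key in REQUIRED_KEYS if key not in data]


with tempfile.TemporaryDirectory() as tmp:
    manifest_path = write_minimal_manifest(tmp, "demo_plugin")
    missing = missing_keys(manifest_path)
```

A key-presence check like this is only a first gate; the real validator presumably also checks value types and version ranges.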
@@ -1,920 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import sys  # needed for the sys.path setup below
|
||||
|
||||
# import time
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure
|
||||
from peewee import Model, Field, IntegrityError
|
||||
|
||||
# Rich progress bar and display components
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
TimeElapsedColumn,
|
||||
SpinnerColumn,
|
||||
)
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
# from rich.text import Text
|
||||
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import (
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mongodb_to_sqlite")
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
@dataclass
class MigrationConfig:
    """Configuration for migrating one MongoDB collection."""

    mongo_collection: str
    target_model: Type[Model]
    field_mapping: Dict[str, str]
    batch_size: int = 500
    enable_validation: bool = True
    skip_duplicates: bool = True
    unique_fields: List[str] = field(default_factory=list)  # fields used for the duplicate check
|
||||
|
||||
|
||||
# The data validation classes were removed: the user asked for no data validation.
|
||||
|
||||
|
||||
@dataclass
class MigrationCheckpoint:
    """Checkpoint data for resuming a migration."""

    collection_name: str
    processed_count: int
    last_processed_id: Any
    timestamp: datetime
    batch_errors: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class MigrationStats:
    """Migration statistics."""

    total_documents: int = 0
    processed_count: int = 0
    success_count: int = 0
    error_count: int = 0
    skipped_count: int = 0
    duplicate_count: int = 0
    validation_errors: int = 0
    batch_insert_count: int = 0
    errors: List[Dict[str, Any]] = field(default_factory=list)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
        """Record an error."""
        self.errors.append(
            {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
        )
        self.error_count += 1

    def add_validation_error(self, doc_id: Any, field: str, error: str):
        """Record a validation error (also counted as a regular error)."""
        self.add_error(doc_id, f"Validation failed - {field}: {error}")
        self.validation_errors += 1
|
||||
|
||||
|
||||
class MongoToSQLiteMigrator:
    """MongoDB-to-SQLite data migrator built on the Peewee ORM."""
|
||||
|
||||
    def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
        self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
        self.mongo_uri = mongo_uri or self._build_mongo_uri()
        self.mongo_client: Optional[MongoClient] = None
        self.mongo_db = None

        # Migration configuration
        self.migration_configs = self._initialize_migration_configs()

        # Console used by the progress bars
        self.console = Console()
        # Checkpoint directory; parents=True also creates a missing data/ directory
        self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Validation rules are disabled
        self.validation_rules = self._initialize_validation_rules()
|
||||
|
||||
    def _build_mongo_uri(self) -> str:
        """Build the MongoDB connection URI."""
        if mongo_uri := os.getenv("MONGODB_URI"):
            return mongo_uri

        user = os.getenv("MONGODB_USER")
        password = os.getenv("MONGODB_PASS")
        host = os.getenv("MONGODB_HOST", "localhost")
        port = os.getenv("MONGODB_PORT", "27017")
        auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")

        if user and password:
            # Percent-escape the credentials so characters such as "@" or ":"
            # in a username or password do not break the URI
            from urllib.parse import quote_plus

            return (
                f"mongodb://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/"
                f"{self.database_name}?authSource={auth_source}"
            )
        else:
            return f"mongodb://{host}:{port}/{self.database_name}"
|
||||
|
||||
    def _initialize_migration_configs(self) -> List[MigrationConfig]:
        """Initialize the migration configurations."""
        return [  # Emoji migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="emoji",
|
||||
target_model=Emoji,
|
||||
field_mapping={
|
||||
"full_path": "full_path",
|
||||
"format": "format",
|
||||
"hash": "emoji_hash",
|
||||
"description": "description",
|
||||
"emotion": "emotion",
|
||||
"usage_count": "usage_count",
|
||||
"last_used_time": "last_used_time",
|
||||
                    # record_time is set to the current time automatically during conversion
                },
                enable_validation=False,  # data validation disabled
                unique_fields=["full_path", "emoji_hash"],
            ),
            # Chat stream migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="chat_streams",
|
||||
target_model=ChatStreams,
|
||||
                field_mapping={
                    "stream_id": "stream_id",
                    "create_time": "create_time",
                    # MongoDB stores group_info as null for private chats, but the new database
                    # does not allow null in the three group_* columns below, so private chat
                    # streams cannot be migrated yet; this awaits a later update.
                    "group_info.platform": "group_platform",
                    "group_info.group_id": "group_id",
                    "group_info.group_name": "group_name",
|
||||
"last_active_time": "last_active_time",
|
||||
"platform": "platform",
|
||||
"user_info.platform": "user_platform",
|
||||
"user_info.user_id": "user_id",
|
||||
"user_info.user_nickname": "user_nickname",
|
||||
"user_info.user_cardname": "user_cardname",
|
||||
},
|
||||
                enable_validation=False,  # data validation disabled
                unique_fields=["stream_id"],
            ),
            # Message migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="messages",
|
||||
target_model=Messages,
|
||||
field_mapping={
|
||||
"message_id": "message_id",
|
||||
"time": "time",
|
||||
"chat_id": "chat_id",
|
||||
"chat_info.stream_id": "chat_info_stream_id",
|
||||
"chat_info.platform": "chat_info_platform",
|
||||
"chat_info.user_info.platform": "chat_info_user_platform",
|
||||
"chat_info.user_info.user_id": "chat_info_user_id",
|
||||
"chat_info.user_info.user_nickname": "chat_info_user_nickname",
|
||||
"chat_info.user_info.user_cardname": "chat_info_user_cardname",
|
||||
"chat_info.group_info.platform": "chat_info_group_platform",
|
||||
"chat_info.group_info.group_id": "chat_info_group_id",
|
||||
"chat_info.group_info.group_name": "chat_info_group_name",
|
||||
"chat_info.create_time": "chat_info_create_time",
|
||||
"chat_info.last_active_time": "chat_info_last_active_time",
|
||||
"user_info.platform": "user_platform",
|
||||
"user_info.user_id": "user_id",
|
||||
"user_info.user_nickname": "user_nickname",
|
||||
"user_info.user_cardname": "user_cardname",
|
||||
"processed_plain_text": "processed_plain_text",
|
||||
"memorized_times": "memorized_times",
|
||||
},
|
||||
                enable_validation=False,  # data validation disabled
                unique_fields=["message_id"],
            ),
            # Image migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="images",
|
||||
target_model=Images,
|
||||
field_mapping={
|
||||
"hash": "emoji_hash",
|
||||
"description": "description",
|
||||
"path": "path",
|
||||
"timestamp": "timestamp",
|
||||
"type": "type",
|
||||
},
|
||||
unique_fields=["path"],
|
||||
),
|
||||
            # Image description migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="image_descriptions",
|
||||
target_model=ImageDescriptions,
|
||||
field_mapping={
|
||||
"type": "type",
|
||||
"hash": "image_description_hash",
|
||||
"description": "description",
|
||||
"timestamp": "timestamp",
|
||||
},
|
||||
unique_fields=["image_description_hash", "type"],
|
||||
),
|
||||
            # Person info migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="person_info",
|
||||
target_model=PersonInfo,
|
||||
field_mapping={
|
||||
"person_id": "person_id",
|
||||
"person_name": "person_name",
|
||||
"name_reason": "name_reason",
|
||||
"platform": "platform",
|
||||
"user_id": "user_id",
|
||||
"nickname": "nickname",
|
||||
"relationship_value": "relationship_value",
|
||||
"konw_time": "know_time",
|
||||
},
|
||||
unique_fields=["person_id"],
|
||||
),
|
||||
            # Knowledge base migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="knowledges",
|
||||
target_model=Knowledges,
|
||||
field_mapping={"content": "content", "embedding": "embedding"},
|
||||
                unique_fields=["content"],  # assumes content is unique
            ),
            # Thinking log migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="thinking_log",
|
||||
target_model=ThinkingLog,
|
||||
field_mapping={
|
||||
"chat_id": "chat_id",
|
||||
"trigger_text": "trigger_text",
|
||||
"response_text": "response_text",
|
||||
"trigger_info": "trigger_info_json",
|
||||
"response_info": "response_info_json",
|
||||
"timing_results": "timing_results_json",
|
||||
"chat_history": "chat_history_json",
|
||||
"chat_history_in_thinking": "chat_history_in_thinking_json",
|
||||
"chat_history_after_response": "chat_history_after_response_json",
|
||||
"heartflow_data": "heartflow_data_json",
|
||||
"reasoning_data": "reasoning_data_json",
|
||||
},
|
||||
unique_fields=["chat_id", "trigger_text"],
|
||||
),
|
||||
            # Graph node migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="graph_data.nodes",
|
||||
target_model=GraphNodes,
|
||||
field_mapping={
|
||||
"concept": "concept",
|
||||
"memory_items": "memory_items",
|
||||
"hash": "hash",
|
||||
"created_time": "created_time",
|
||||
"last_modified": "last_modified",
|
||||
},
|
||||
unique_fields=["concept"],
|
||||
),
|
||||
            # Graph edge migration config
|
||||
MigrationConfig(
|
||||
mongo_collection="graph_data.edges",
|
||||
target_model=GraphEdges,
|
||||
field_mapping={
|
||||
"source": "source",
|
||||
"target": "target",
|
||||
"strength": "strength",
|
||||
"hash": "hash",
|
||||
"created_time": "created_time",
|
||||
"last_modified": "last_modified",
|
||||
},
|
||||
                unique_fields=["source", "target"],  # unique in combination
|
||||
),
|
||||
]
|
||||
|
||||
    def _initialize_validation_rules(self) -> Dict[str, Any]:
        """Data validation is disabled; return an empty dict."""
        return {}
|
||||
|
||||
    def connect_mongodb(self) -> bool:
        """Connect to MongoDB."""
        try:
            self.mongo_client = MongoClient(
                self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
            )

            # Test the connection
            self.mongo_client.admin.command("ping")
            self.mongo_db = self.mongo_client[self.database_name]

            logger.info(f"Connected to MongoDB: {self.database_name}")
            return True

        except ConnectionFailure as e:
            logger.error(f"MongoDB connection failed: {e}")
            return False
        except Exception as e:
            logger.error(f"Unexpected error while connecting to MongoDB: {e}")
            return False
|
||||
|
||||
    def disconnect_mongodb(self):
        """Close the MongoDB connection."""
        if self.mongo_client:
            self.mongo_client.close()
            logger.info("MongoDB connection closed")
|
||||
|
||||
    def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
        """Get the value at a dotted field path such as "user_info.user_id"."""
        if "." not in field_path:
            return document.get(field_path)

        parts = field_path.split(".")
        value = document

        for part in parts:
            if isinstance(value, dict):
                value = value.get(part)
            else:
                return None

            if value is None:
                break

        return value
|
||||
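The dotted-path lookup above can be tried in isolation. The following standalone sketch uses a simplified loop and is not the migrator's method itself. It shows why a private-chat document with `group_info` set to `None` yields `None` for every `group_info.*` mapping. The sample document is made up for illustration.

```python
from typing import Any, Dict


def get_nested_value(document: Dict[str, Any], field_path: str) -> Any:
    """Resolve a dotted path such as "user_info.user_id" against nested dicts."""
    value: Any = document
    for part in field_path.split("."):
        if not isinstance(value, dict):
            return None  # the path runs through a non-dict value (including None)
        value = value.get(part)
    return value


# Made-up document shaped like a private-chat stream: group_info is null.
doc = {"stream_id": "s1", "group_info": None, "user_info": {"user_id": "42"}}

user_id = get_nested_value(doc, "user_info.user_id")
group_id = get_nested_value(doc, "group_info.group_id")
```

Here `user_id` resolves to the nested string, while `group_id` is `None` because the walk stops at the null `group_info`.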
|
||||
    def _convert_field_value(self, value: Any, target_field: Field) -> Any:
        """Convert a value to match the target field's type."""
        if value is None:
            return None

        field_type = target_field.__class__.__name__

        try:
            if target_field.name == "record_time" and field_type == "DateTimeField":
                return datetime.now()

            if field_type in ["CharField", "TextField"]:
                if isinstance(value, (list, dict)):
                    return json.dumps(value, ensure_ascii=False)
                return str(value) if value is not None else ""

            elif field_type == "IntegerField":
                if isinstance(value, str):
                    # Handle numeric strings
                    clean_value = value.strip()
                    if clean_value.replace(".", "").replace("-", "").isdigit():
                        return int(float(clean_value))
                    return 0
                return int(value) if value is not None else 0

            elif field_type in ["FloatField", "DoubleField"]:
                return float(value) if value is not None else 0.0

            elif field_type == "BooleanField":
                if isinstance(value, str):
                    return value.lower() in ("true", "1", "yes", "on")
                return bool(value)

            elif field_type == "DateTimeField":
                if isinstance(value, datetime):
                    # Keep real datetime values instead of replacing them with "now"
                    return value
                if isinstance(value, (int, float)):
                    return datetime.fromtimestamp(value)
                elif isinstance(value, str):
                    try:
                        # Try to parse an ISO-format date
                        return datetime.fromisoformat(value.replace("Z", "+00:00"))
                    except ValueError:
                        try:
                            # Try to parse a timestamp string
                            return datetime.fromtimestamp(float(value))
                        except (ValueError, OverflowError, OSError):
                            return datetime.now()
                return datetime.now()

            return value

        except (ValueError, TypeError, OverflowError, OSError) as e:
            # fromtimestamp() raises OverflowError or OSError for out-of-range timestamps
            logger.warning(f"Field value conversion failed ({field_type}): {value} -> {e}")
            return self._get_default_value_for_field(target_field)
|
||||
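The `DateTimeField` branch is the most intricate part of the conversion, so here is a minimal standalone sketch of the same fallback chain: Unix timestamp, then ISO 8601 string, then timestamp string. Unlike the migrator, which substitutes `datetime.now()`, this illustrative `to_datetime` helper (an assumption, not project code) returns `None` for unparseable input so the failure stays visible.

```python
from datetime import datetime
from typing import Any, Optional


def to_datetime(value: Any) -> Optional[datetime]:
    """Convert a Unix timestamp or ISO 8601 string to datetime; None if unparseable."""
    if isinstance(value, datetime):
        return value
    if isinstance(value, (int, float)):
        return datetime.fromtimestamp(value)  # naive local time, as in the migrator
    if isinstance(value, str):
        try:
            # fromisoformat() historically rejects a trailing "Z", so map it to an offset
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            try:
                return datetime.fromtimestamp(float(value))
            except ValueError:
                return None
    return None


iso_result = to_datetime("2024-01-02T03:04:05Z")
bad_result = to_datetime("not a date")
```

Note that the ISO branch yields a timezone-aware value while the timestamp branches yield naive local times; mixing the two in one column is worth watching for when comparing dates later.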
|
||||
    def _get_default_value_for_field(self, field: Field) -> Any:
        """Get the default value for a field."""
        field_type = field.__class__.__name__

        if hasattr(field, "default") and field.default is not None:
            return field.default

        if field.null:
            return None

        # Fall back to a default based on the field type
        if field_type in ["CharField", "TextField"]:
            return ""
        elif field_type == "IntegerField":
            return 0
        elif field_type in ["FloatField", "DoubleField"]:
            return 0.0
        elif field_type == "BooleanField":
            return False
        elif field_type == "DateTimeField":
            return datetime.now()

        return None
|
||||
|
||||
    def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
        """Data validation is disabled; always return True."""
        return True
|
||||
|
||||
    def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
        """Save a migration checkpoint."""
        checkpoint = MigrationCheckpoint(
            collection_name=collection_name,
            processed_count=processed_count,
            last_processed_id=last_id,
            timestamp=datetime.now(),
        )

        checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
        try:
            with open(checkpoint_file, "wb") as f:
                pickle.dump(checkpoint, f)
        except Exception as e:
            logger.warning(f"Failed to save checkpoint: {e}")
|
||||
|
||||
    def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
        """Load a migration checkpoint, if one exists."""
        checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
        if not checkpoint_file.exists():
            return None

        try:
            with open(checkpoint_file, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            logger.warning(f"Failed to load checkpoint: {e}")
            return None
|
||||
|
||||
    def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
        """Insert rows in bulk and return the number of rows inserted."""
        if not data_list:
            return 0

        success_count = 0
        try:
            with db.atomic():
                # Insert in chunks so no single SQL statement gets too long
                batch_size = 100
                for i in range(0, len(data_list), batch_size):
                    batch = data_list[i : i + batch_size]
                    model.insert_many(batch).execute()
                    success_count += len(batch)
        except Exception as e:
            logger.error(f"Bulk insert failed: {e}")
            # db.atomic() rolled back every chunk, so reset the count before
            # falling back to row-by-row inserts; otherwise rows are counted twice
            success_count = 0
            for data in data_list:
                try:
                    model.create(**data)
                    success_count += 1
                except Exception:
                    pass  # ignore individual insert failures

        return success_count
|
||||
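The bulk-insert-with-fallback pattern can be exercised without Peewee. The sketch below substitutes the standard-library `sqlite3` module for the ORM, which is a deliberate swap for illustration. The `messages` table and its columns are made up. It tries one transaction for the whole batch and, if that fails, retries row by row so that a single bad row does not discard the rest.

```python
import sqlite3
from typing import Dict, List


def batch_insert(conn: sqlite3.Connection, rows: List[Dict[str, str]]) -> int:
    """Insert rows in one transaction; on failure, retry row by row.

    Returns the number of rows actually stored.
    """
    sql = "INSERT INTO messages (message_id, text) VALUES (:message_id, :text)"
    try:
        with conn:  # commits on success, rolls back on exception
            conn.executemany(sql, rows)
        return len(rows)
    except sqlite3.Error:
        pass  # the whole batch was rolled back; fall through to per-row inserts

    stored = 0
    for row in rows:
        try:
            with conn:
                conn.execute(sql, row)
            stored += 1
        except sqlite3.IntegrityError:
            continue  # skip rows that violate a constraint
    return stored


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE messages (message_id TEXT PRIMARY KEY, text TEXT)")
rows = [
    {"message_id": "m1", "text": "hello"},
    {"message_id": "m1", "text": "duplicate id"},
    {"message_id": "m2", "text": "world"},
]
stored = batch_insert(conn, rows)
```

Because the failed batch is rolled back in full, the fallback count starts from zero; counting the rolled-back rows as well would overstate the number of stored rows.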
|
||||
    def _check_duplicate_by_unique_fields(
        self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
    ) -> bool:
        """Check whether an existing row matches the unique fields."""
        if not unique_fields:
            return False

        try:
            query = model.select()
            has_condition = False
            for field_name in unique_fields:
                if field_name in data and data[field_name] is not None:
                    field_obj = getattr(model, field_name)
                    query = query.where(field_obj == data[field_name])
                    has_condition = True

            # Without any condition the query would match every row and report
            # a duplicate whenever the table is non-empty
            if not has_condition:
                return False

            return query.exists()
        except Exception as e:
            logger.debug(f"Duplicate check failed: {e}")
            return False
|
||||
|
||||
    def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
        """Create a model instance through the ORM."""
        try:
            # Filter out fields that do not exist on the model
            valid_data = {}
            for field_name, value in data.items():
                if hasattr(model, field_name):
                    valid_data[field_name] = value
                else:
                    logger.debug(f"Skipping unknown field: {field_name}")

            # Create the instance
            instance = model.create(**valid_data)
            return instance

        except IntegrityError as e:
            # Handle integrity errors such as unique constraint violations
            logger.debug(f"Integrity constraint violation: {e}")
            return None
        except Exception as e:
            logger.error(f"Failed to create model instance: {e}")
            return None
|
||||
|
||||
    def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
        """Migrate one collection using batched inserts and a progress bar."""
        stats = MigrationStats()
        stats.start_time = datetime.now()

        # Check for a checkpoint
        checkpoint = self._load_checkpoint(config.mongo_collection)
        start_from_id = checkpoint.last_processed_id if checkpoint else None
        if checkpoint:
            stats.processed_count = checkpoint.processed_count
            logger.info(f"Resuming from checkpoint: {checkpoint.processed_count} records already processed")

        logger.info(f"Starting migration: {config.mongo_collection} -> {config.target_model._meta.table_name}")

        try:
            # Get the MongoDB collection
            mongo_collection = self.mongo_db[config.mongo_collection]

            # Build the query (used to resume from a checkpoint)
            query = {}
            if start_from_id:
                query = {"_id": {"$gt": start_from_id}}

            stats.total_documents = mongo_collection.count_documents(query)

            if stats.total_documents == 0:
                logger.warning(f"Collection {config.mongo_collection} is empty, skipping migration")
                return stats

            logger.info(f"Documents to migrate: {stats.total_documents}")

            # Create the Rich progress bar
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=self.console,
                refresh_per_second=10,
            ) as progress:
                task = progress.add_task(f"Migrating {config.mongo_collection}", total=stats.total_documents)
                # Process the data in batches
                batch_data = []
                batch_count = 0
                last_processed_id = None

                # Sort by _id so the "$gt" checkpoint filter resumes at the right document
                cursor = mongo_collection.find(query).sort("_id", 1).batch_size(config.batch_size)
                for mongo_doc in cursor:
|
||||
try:
|
||||
doc_id = mongo_doc.get("_id", "unknown")
|
||||
last_processed_id = doc_id
|
||||
|
||||
# 构建目标数据
|
||||
target_data = {}
|
||||
for mongo_field, sqlite_field in config.field_mapping.items():
|
||||
value = self._get_nested_value(mongo_doc, mongo_field)
|
||||
|
||||
# 获取目标字段对象并转换类型
|
||||
if hasattr(config.target_model, sqlite_field):
|
||||
field_obj = getattr(config.target_model, sqlite_field)
|
||||
converted_value = self._convert_field_value(value, field_obj)
|
||||
target_data[sqlite_field] = converted_value
|
||||
|
||||
# 数据验证已禁用
|
||||
# if config.enable_validation:
|
||||
# if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
|
||||
# stats.skipped_count += 1
|
||||
# continue
|
||||
|
||||
# 重复检查
|
||||
if config.skip_duplicates and self._check_duplicate_by_unique_fields(
|
||||
config.target_model, target_data, config.unique_fields
|
||||
):
|
||||
stats.duplicate_count += 1
|
||||
stats.skipped_count += 1
|
||||
logger.debug(f"跳过重复记录: {doc_id}")
|
||||
continue
|
||||
|
||||
# 添加到批量数据
|
||||
batch_data.append(target_data)
|
||||
stats.processed_count += 1
|
||||
|
||||
# 执行批量插入
|
||||
if len(batch_data) >= config.batch_size:
|
||||
success_count = self._batch_insert(config.target_model, batch_data)
|
||||
stats.success_count += success_count
|
||||
stats.batch_insert_count += 1
|
||||
|
||||
# 保存断点
|
||||
self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
|
||||
|
||||
batch_data.clear()
|
||||
batch_count += 1
|
||||
|
||||
# 更新进度条
|
||||
progress.update(task, advance=config.batch_size)
|
||||
|
||||
except Exception as e:
|
||||
doc_id = mongo_doc.get("_id", "unknown")
|
||||
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
|
||||
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
|
||||
|
||||
# 处理剩余的批量数据
|
||||
if batch_data:
|
||||
success_count = self._batch_insert(config.target_model, batch_data)
|
||||
stats.success_count += success_count
|
||||
stats.batch_insert_count += 1
|
||||
progress.update(task, advance=len(batch_data))
|
||||
|
||||
# 完成进度条
|
||||
progress.update(task, completed=stats.total_documents)
|
||||
|
||||
stats.end_time = datetime.now()
|
||||
duration = stats.end_time - stats.start_time
|
||||
|
||||
logger.info(
|
||||
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
|
||||
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
|
||||
f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
|
||||
f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
|
||||
)
|
||||
|
||||
# 清理断点文件
|
||||
checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
|
||||
if checkpoint_file.exists():
|
||||
checkpoint_file.unlink()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
|
||||
stats.add_error("collection_error", str(e))
|
||||
|
||||
return stats
|
||||
|
||||
def migrate_all(self) -> Dict[str, MigrationStats]:
|
||||
"""执行所有迁移任务"""
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
if not self.connect_mongodb():
|
||||
logger.error("无法连接到MongoDB,迁移终止")
|
||||
return {}
|
||||
|
||||
all_stats = {}
|
||||
|
||||
try:
|
||||
# 创建总体进度表格
|
||||
total_collections = len(self.migration_configs)
|
||||
self.console.print(
|
||||
Panel(
|
||||
f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
|
||||
f"[yellow]总集合数: {total_collections}[/yellow]",
|
||||
title="迁移开始",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
for idx, config in enumerate(self.migration_configs, 1):
|
||||
self.console.print(
|
||||
f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
|
||||
)
|
||||
stats = self.migrate_collection(config)
|
||||
all_stats[config.mongo_collection] = stats
|
||||
|
||||
# 显示单个集合的快速统计
|
||||
if stats.processed_count > 0:
|
||||
success_rate = stats.success_count / stats.processed_count * 100
|
||||
if success_rate >= 95:
|
||||
status_emoji = "✅"
|
||||
status_color = "bright_green"
|
||||
elif success_rate >= 80:
|
||||
status_emoji = "⚠️"
|
||||
status_color = "yellow"
|
||||
else:
|
||||
status_emoji = "❌"
|
||||
status_color = "red"
|
||||
|
||||
self.console.print(
|
||||
f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
|
||||
f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
|
||||
)
|
||||
|
||||
# 错误率检查
|
||||
if stats.processed_count > 0:
|
||||
error_rate = stats.error_count / stats.processed_count
|
||||
if error_rate > 0.1: # 错误率超过10%
|
||||
self.console.print(
|
||||
f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
|
||||
f"({stats.error_count}/{stats.processed_count})[/red]"
|
||||
)
|
||||
|
||||
finally:
|
||||
self.disconnect_mongodb()
|
||||
|
||||
self._print_migration_summary(all_stats)
|
||||
return all_stats
|
||||
|
||||
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
|
||||
"""使用Rich打印美观的迁移汇总信息"""
|
||||
# 计算总体统计
|
||||
total_processed = sum(stats.processed_count for stats in all_stats.values())
|
||||
total_success = sum(stats.success_count for stats in all_stats.values())
|
||||
total_errors = sum(stats.error_count for stats in all_stats.values())
|
||||
total_skipped = sum(stats.skipped_count for stats in all_stats.values())
|
||||
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
|
||||
total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
|
||||
total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
|
||||
|
||||
# 计算总耗时
|
||||
total_duration_seconds = 0
|
||||
for stats in all_stats.values():
|
||||
if stats.start_time and stats.end_time:
|
||||
duration = stats.end_time - stats.start_time
|
||||
total_duration_seconds += duration.total_seconds()
|
||||
|
||||
# 创建详细统计表格
|
||||
table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
|
||||
table.add_column("集合名称", style="cyan", width=20)
|
||||
table.add_column("文档总数", justify="right", style="blue")
|
||||
table.add_column("处理数量", justify="right", style="green")
|
||||
table.add_column("成功数量", justify="right", style="green")
|
||||
table.add_column("错误数量", justify="right", style="red")
|
||||
table.add_column("跳过数量", justify="right", style="yellow")
|
||||
table.add_column("重复数量", justify="right", style="bright_yellow")
|
||||
table.add_column("验证错误", justify="right", style="red")
|
||||
table.add_column("批次数", justify="right", style="purple")
|
||||
table.add_column("成功率", justify="right", style="bright_green")
|
||||
table.add_column("耗时(秒)", justify="right", style="blue")
|
||||
|
||||
for collection_name, stats in all_stats.items():
|
||||
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
|
||||
duration = 0
|
||||
if stats.start_time and stats.end_time:
|
||||
duration = (stats.end_time - stats.start_time).total_seconds()
|
||||
|
||||
# 根据成功率设置颜色
|
||||
if success_rate >= 95:
|
||||
success_rate_style = "[bright_green]"
|
||||
elif success_rate >= 80:
|
||||
success_rate_style = "[yellow]"
|
||||
else:
|
||||
success_rate_style = "[red]"
|
||||
|
||||
table.add_row(
|
||||
collection_name,
|
||||
str(stats.total_documents),
|
||||
str(stats.processed_count),
|
||||
str(stats.success_count),
|
||||
f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
|
||||
f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
|
||||
f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
|
||||
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
|
||||
str(stats.batch_insert_count),
|
||||
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
|
||||
f"{duration:.2f}",
|
||||
)
|
||||
|
||||
# 添加总计行
|
||||
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
|
||||
if total_success_rate >= 95:
|
||||
total_rate_style = "[bright_green]"
|
||||
elif total_success_rate >= 80:
|
||||
total_rate_style = "[yellow]"
|
||||
else:
|
||||
total_rate_style = "[red]"
|
||||
|
||||
table.add_section()
|
||||
table.add_row(
|
||||
"[bold]总计[/bold]",
|
||||
f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]",
|
||||
f"[bold]{total_processed}[/bold]",
|
||||
f"[bold]{total_success}[/bold]",
|
||||
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
|
||||
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
|
||||
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
|
||||
if total_duplicates > 0
|
||||
else "[bold]0[/bold]",
|
||||
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
|
||||
f"[bold]{total_batch_inserts}[/bold]",
|
||||
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
|
||||
f"[bold]{total_duration_seconds:.2f}[/bold]",
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
# 创建状态面板
|
||||
status_items = []
|
||||
if total_errors > 0:
|
||||
status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
|
||||
|
||||
if total_validation_errors > 0:
|
||||
status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
|
||||
|
||||
if total_duplicates > 0:
|
||||
status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
|
||||
|
||||
if total_success_rate >= 95:
|
||||
status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
|
||||
elif total_success_rate >= 80:
|
||||
status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
|
||||
else:
|
||||
status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
|
||||
|
||||
if status_items:
|
||||
status_panel = Panel(
|
||||
"\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
|
||||
)
|
||||
self.console.print(status_panel)
|
||||
|
||||
# 性能统计面板
|
||||
avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
|
||||
performance_info = (
|
||||
f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n"
|
||||
f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
|
||||
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
|
||||
)
|
||||
|
||||
performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
|
||||
self.console.print(performance_panel)
|
||||
|
||||
def add_migration_config(self, config: MigrationConfig):
|
||||
"""添加新的迁移配置"""
|
||||
self.migration_configs.append(config)
|
||||
|
||||
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
|
||||
"""迁移单个指定的集合"""
|
||||
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
|
||||
if not config:
|
||||
logger.error(f"未找到集合 {collection_name} 的迁移配置")
|
||||
return None
|
||||
|
||||
if not self.connect_mongodb():
|
||||
logger.error("无法连接到MongoDB")
|
||||
return None
|
||||
|
||||
try:
|
||||
stats = self.migrate_collection(config)
|
||||
self._print_migration_summary({collection_name: stats})
|
||||
return stats
|
||||
finally:
|
||||
self.disconnect_mongodb()
|
||||
|
||||
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
|
||||
"""导出错误报告"""
|
||||
error_report = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"summary": {
|
||||
collection: {
|
||||
"total": stats.total_documents,
|
||||
"processed": stats.processed_count,
|
||||
"success": stats.success_count,
|
||||
"errors": stats.error_count,
|
||||
"skipped": stats.skipped_count,
|
||||
"duplicates": stats.duplicate_count,
|
||||
}
|
||||
for collection, stats in all_stats.items()
|
||||
},
|
||||
"errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
|
||||
}
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(error_report, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"错误报告已导出到: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"导出错误报告失败: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主程序入口"""
|
||||
migrator = MongoToSQLiteMigrator()
|
||||
|
||||
# 执行迁移
|
||||
migration_results = migrator.migrate_all()
|
||||
|
||||
# 导出错误报告(如果有错误)
|
||||
if any(stats.error_count > 0 for stats in migration_results.values()):
|
||||
error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
migrator.export_error_report(migration_results, error_report_path)
|
||||
|
||||
logger.info("数据迁移完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,75 +0,0 @@
|
||||
import os
import sys
from pathlib import Path

# Add the project root to sys.path before importing anything from `src`.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.utils.hash import get_sha256  # noqa: E402
from src.common.logger import get_logger  # noqa: E402
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
|
||||
paragraphs = []
|
||||
paragraph = ""
|
||||
for line in raw.split("\n"):
|
||||
if line.strip() == "":
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
paragraph = ""
|
||||
else:
|
||||
paragraph += line + "\n"
|
||||
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
|
||||
return paragraphs
|
||||
|
||||
|
||||
def _process_multi_files() -> list:
|
||||
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||
if not raw_files:
|
||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||
sys.exit(1)
|
||||
# 处理所有文件
|
||||
all_paragraphs = []
|
||||
for file in raw_files:
|
||||
logger.info(f"正在处理文件: {file.name}")
|
||||
paragraphs = _process_text_file(file)
|
||||
all_paragraphs.extend(paragraphs)
|
||||
return all_paragraphs
|
||||
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
    """Load the raw data files.

    Reads every .txt file under RAW_DATA_PATH, splits it into paragraphs,
    and drops non-string items and duplicate paragraphs.

    Returns:
        - sha256_list: SHA256 hashes of the kept paragraphs, in order
        - raw_data: the deduplicated paragraphs, aligned with sha256_list
    """
    all_paragraphs = _process_multi_files()
    # Collect results in a new list; appending to the list being iterated
    # would re-visit every kept item and log it as a duplicate.
    raw_data = []
    sha256_list = []
    sha256_set = set()
    for item in all_paragraphs:
        if not isinstance(item, str):
            logger.warning(f"数据类型错误:{item}")
            continue
        pg_hash = get_sha256(item)
        if pg_hash in sha256_set:
            logger.warning(f"重复数据:{item}")
            continue
        sha256_set.add(pg_hash)
        sha256_list.append(pg_hash)
        raw_data.append(item)
    logger.info(f"共读取到{len(raw_data)}条数据")

    return sha256_list, raw_data
|
||||
556
scripts/run.sh
@@ -1,556 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987
|
||||
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
|
||||
# 请小心使用任何一键脚本!
|
||||
|
||||
INSTALLER_VERSION="0.0.5-refactor"
|
||||
LANG=C.UTF-8
|
||||
|
||||
# 如无法访问GitHub请修改此处镜像地址
|
||||
GITHUB_REPO="https://ghfast.top/https://github.com"
|
||||
|
||||
# 颜色输出
|
||||
GREEN="\e[32m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
# 需要的基本软件包
|
||||
|
||||
declare -A REQUIRED_PACKAGES=(
|
||||
["common"]="git sudo python3 curl gnupg"
|
||||
["debian"]="python3-venv python3-pip build-essential"
|
||||
["ubuntu"]="python3-venv python3-pip build-essential"
|
||||
["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make"
|
||||
["arch"]="python-virtualenv python-pip base-devel"
|
||||
)
|
||||
|
||||
# 默认项目目录
|
||||
DEFAULT_INSTALL_DIR="/opt/maicore"
|
||||
|
||||
# 服务名称
|
||||
SERVICE_NAME="maicore"
|
||||
SERVICE_NAME_WEB="maicore-web"
|
||||
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
|
||||
|
||||
IS_INSTALL_NAPCAT=false
|
||||
IS_INSTALL_DEPENDENCIES=false
|
||||
|
||||
# 检查是否已安装
|
||||
check_installed() {
|
||||
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
|
||||
}
|
||||
|
||||
# 加载安装信息
|
||||
load_install_info() {
|
||||
if [[ -f /etc/maicore_install.conf ]]; then
|
||||
source /etc/maicore_install.conf
|
||||
else
|
||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
||||
BRANCH="refactor"
|
||||
fi
|
||||
}
|
||||
|
||||
# 显示管理菜单
|
||||
show_menu() {
|
||||
while true; do
|
||||
choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
|
||||
"1" "启动MaiCore" \
|
||||
"2" "停止MaiCore" \
|
||||
"3" "重启MaiCore" \
|
||||
"4" "启动NapCat Adapter" \
|
||||
"5" "停止NapCat Adapter" \
|
||||
"6" "重启NapCat Adapter" \
|
||||
"7" "拉取最新MaiCore仓库" \
|
||||
"8" "切换分支" \
|
||||
"9" "退出" 3>&1 1>&2 2>&3)
|
||||
|
||||
[[ $? -ne 0 ]] && exit 0
|
||||
|
||||
case "$choice" in
|
||||
1)
|
||||
systemctl start ${SERVICE_NAME}
|
||||
whiptail --msgbox "✅MaiCore已启动" 10 60
|
||||
;;
|
||||
2)
|
||||
systemctl stop ${SERVICE_NAME}
|
||||
whiptail --msgbox "🛑MaiCore已停止" 10 60
|
||||
;;
|
||||
3)
|
||||
systemctl restart ${SERVICE_NAME}
|
||||
whiptail --msgbox "🔄MaiCore已重启" 10 60
|
||||
;;
|
||||
4)
|
||||
systemctl start ${SERVICE_NAME_NBADAPTER}
|
||||
whiptail --msgbox "✅NapCat Adapter已启动" 10 60
|
||||
;;
|
||||
5)
|
||||
systemctl stop ${SERVICE_NAME_NBADAPTER}
|
||||
whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
|
||||
;;
|
||||
6)
|
||||
systemctl restart ${SERVICE_NAME_NBADAPTER}
|
||||
whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
|
||||
;;
|
||||
7)
|
||||
update_dependencies
|
||||
;;
|
||||
8)
|
||||
switch_branch
|
||||
;;
|
||||
9)
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
whiptail --msgbox "无效选项!" 10 60
|
||||
;;
|
||||
esac
|
||||
done
|
||||
}
|
||||
|
||||
# 更新依赖
|
||||
update_dependencies() {
|
||||
whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
|
||||
systemctl stop ${SERVICE_NAME}
|
||||
cd "${INSTALL_DIR}/MaiBot" || {
|
||||
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
|
||||
return 1
|
||||
}
|
||||
if ! git pull origin "${BRANCH}"; then
|
||||
whiptail --msgbox "🚫 代码更新失败!" 10 60
|
||||
return 1
|
||||
fi
|
||||
source "${INSTALL_DIR}/venv/bin/activate"
|
||||
if ! pip install -r requirements.txt; then
|
||||
whiptail --msgbox "🚫 依赖安装失败!" 10 60
|
||||
deactivate
|
||||
return 1
|
||||
fi
|
||||
deactivate
|
||||
whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
|
||||
}
|
||||
|
||||
# 切换分支
|
||||
switch_branch() {
|
||||
new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
|
||||
[[ -z "$new_branch" ]] && {
|
||||
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
|
||||
return 1
|
||||
}
|
||||
|
||||
cd "${INSTALL_DIR}/MaiBot" || {
|
||||
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
|
||||
return 1
|
||||
}
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
|
||||
whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
|
||||
return 1
|
||||
fi
|
||||
|
||||
if ! git checkout "${new_branch}"; then
|
||||
whiptail --msgbox "🚫 分支切换失败!" 10 60
|
||||
return 1
|
||||
fi
|
||||
|
||||
if ! git pull origin "${new_branch}"; then
|
||||
whiptail --msgbox "🚫 代码拉取失败!" 10 60
|
||||
return 1
|
||||
fi
|
||||
systemctl stop ${SERVICE_NAME}
|
||||
source "${INSTALL_DIR}/venv/bin/activate"
|
||||
pip install -r requirements.txt
|
||||
deactivate
|
||||
|
||||
sed -i "s|^BRANCH=.*|BRANCH=${new_branch}|" /etc/maicore_install.conf
|
||||
BRANCH="${new_branch}"
|
||||
check_eula
|
||||
whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} !" 10 60
|
||||
}
|
||||
|
||||
check_eula() {
    # Compute the MD5 of the current EULA
    current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}')

    # Compute the MD5 of the current privacy policy
    current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}')

    # If either hash is empty, the file is missing: warn and return
    if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
        whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
        return 1
    fi

    # Read the previously confirmed EULA hash, if any
    if [[ -f "${INSTALL_DIR}/MaiBot/eula.confirmed" ]]; then
        confirmed_md5=$(cat "${INSTALL_DIR}/MaiBot/eula.confirmed")
    else
        confirmed_md5=""
    fi

    # Read the previously confirmed privacy policy hash, if any
    if [[ -f "${INSTALL_DIR}/MaiBot/privacy.confirmed" ]]; then
        confirmed_md5_privacy=$(cat "${INSTALL_DIR}/MaiBot/privacy.confirmed")
    else
        confirmed_md5_privacy=""
    fi

    # If the EULA or privacy policy has changed, ask the user to confirm again
    if [[ "$current_md5" != "$confirmed_md5" || "$current_md5_privacy" != "$confirmed_md5_privacy" ]]; then
        whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议? \n\n " 12 70
        if [[ $? -eq 0 ]]; then
            echo -n "$current_md5" > "${INSTALL_DIR}/MaiBot/eula.confirmed"
            echo -n "$current_md5_privacy" > "${INSTALL_DIR}/MaiBot/privacy.confirmed"
        else
            exit 1
        fi
    fi
}
|
||||
|
||||
# ----------- 主安装流程 -----------
|
||||
run_installation() {
|
||||
    # 1/6: make sure whiptail is installed
    if ! command -v whiptail &>/dev/null; then
        echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"

        if command -v apt-get &>/dev/null; then
            apt-get update && apt-get install -y whiptail
        elif command -v pacman &>/dev/null; then
            # On Arch, whiptail ships in the libnewt package
            pacman -Syu --noconfirm libnewt
        elif command -v yum &>/dev/null; then
            # On CentOS, whiptail ships in the newt package
            yum install -y newt
        else
            echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
            exit 1
        fi
    fi
|
||||
|
||||
whiptail --title "ℹ️ 提示" --msgbox "如果您没有特殊需求,请优先使用docker方式部署。" 10 60
|
||||
|
||||
# 协议确认
|
||||
if ! (whiptail --title "ℹ️ [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议?" 12 70); then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 欢迎信息
|
||||
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore,将自动进入安装流程,安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段,代码可能随时更改\n文档未完善,有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险,请自行了解,谨慎使用\n由于持续迭代,可能存在一些已知或未知的bug\n由于开发中,可能消耗较多token\n\n本脚本可能更新不及时,如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
|
||||
|
||||
# 系统检查
|
||||
check_system() {
|
||||
if [[ "$(id -u)" -ne 0 ]]; then
|
||||
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -f /etc/os-release ]]; then
|
||||
source /etc/os-release
|
||||
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "arch" ]]; then
|
||||
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
||||
return
|
||||
else
|
||||
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
check_system
|
||||
|
||||
# 设置包管理器
|
||||
case "$ID" in
|
||||
debian|ubuntu)
|
||||
PKG_MANAGER="apt"
|
||||
;;
|
||||
centos)
|
||||
PKG_MANAGER="yum"
|
||||
;;
|
||||
arch)
|
||||
# 添加arch包管理器
|
||||
PKG_MANAGER="pacman"
|
||||
;;
|
||||
esac
|
||||
|
||||
# 检查NapCat
|
||||
check_napcat() {
|
||||
if command -v napcat &>/dev/null; then
|
||||
NAPCAT_INSTALLED=true
|
||||
else
|
||||
NAPCAT_INSTALLED=false
|
||||
fi
|
||||
}
|
||||
check_napcat
|
||||
|
||||
# 安装必要软件包
|
||||
install_packages() {
|
||||
missing_packages=()
|
||||
# 检查 common 及当前系统专属依赖
|
||||
for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
|
||||
case "$PKG_MANAGER" in
|
||||
apt)
|
||||
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
yum)
|
||||
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
pacman)
|
||||
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ ${#missing_packages[@]} -gt 0 ]]; then
|
||||
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装?" 10 60
|
||||
if [[ $? -eq 0 ]]; then
|
||||
IS_INSTALL_DEPENDENCIES=true
|
||||
else
|
||||
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续?" 10 60 || exit 1
|
||||
fi
|
||||
fi
|
||||
}
|
||||
install_packages
|
||||
|
||||
# 安装NapCat
|
||||
install_napcat() {
|
||||
[[ $NAPCAT_INSTALLED == true ]] && return
|
||||
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat,是否安装?\n如果您想使用远程NapCat,请跳过此步。" 10 60 && {
|
||||
IS_INSTALL_NAPCAT=true
|
||||
}
|
||||
}
|
||||
|
||||
# 仅在非Arch系统上安装NapCat
|
||||
[[ "$ID" != "arch" ]] && install_napcat
|
||||
|
||||
# Python版本检查
|
||||
check_python() {
|
||||
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
|
||||
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 如果没安装python则不检查python版本
|
||||
if command -v python3 &>/dev/null; then
|
||||
check_python
|
||||
fi
|
||||
|
||||
|
||||
# 选择分支
|
||||
choose_branch() {
|
||||
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
|
||||
"main" "稳定版本(推荐)" ON \
|
||||
"dev" "开发版(不知道什么意思就别选)" OFF \
|
||||
"classical" "经典版(0.6.0以前的版本)" OFF \
|
||||
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
|
||||
RETVAL=$?
|
||||
if [ $RETVAL -ne 0 ]; then
|
||||
whiptail --msgbox "🚫 操作取消!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$BRANCH" == "custom" ]]; then
|
||||
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
|
||||
RETVAL=$?
|
||||
if [ $RETVAL -ne 0 ]; then
|
||||
whiptail --msgbox "🚫 输入取消!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$BRANCH" ]]; then
|
||||
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
}
|
||||
choose_branch
|
||||
|
||||
# 选择安装路径
|
||||
choose_install_dir() {
|
||||
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录:" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
|
||||
[[ -z "$INSTALL_DIR" ]] && {
|
||||
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
|
||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
||||
}
|
||||
}
|
||||
choose_install_dir
|
||||
|
||||
# 确认安装
|
||||
confirm_install() {
|
||||
local confirm_msg="请确认以下更改:\n\n"
|
||||
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
|
||||
confirm_msg+="🔀 分支: $BRANCH\n"
|
||||
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[*]}\n"
|
||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
|
||||
|
||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
|
||||
confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
|
||||
|
||||
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
|
||||
}
|
||||
confirm_install
|
||||
|
||||
    # Start installation: install missing system packages only if the user agreed
    if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
        echo -e "${GREEN}安装${missing_packages[*]}...${RESET}"
        case "$PKG_MANAGER" in
            apt)
                apt update && apt install -y "${missing_packages[@]}"
                ;;
            yum)
                yum install -y "${missing_packages[@]}" --nobest
                ;;
            pacman)
                pacman -S --noconfirm "${missing_packages[@]}"
                ;;
        esac
    fi
|
||||
|
||||
if [[ $IS_INSTALL_NAPCAT == true ]]; then
|
||||
echo -e "${GREEN}安装 NapCat...${RESET}"
|
||||
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}创建安装目录...${RESET}"
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
cd "$INSTALL_DIR" || exit 1
|
||||
|
||||
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
|
||||
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
|
||||
echo -e "${RED}克隆MaiCore仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
|
||||
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
|
||||
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo -e "${GREEN}克隆 MaiBot-Napcat-Adapter 仓库...${RESET}"
|
||||
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
|
||||
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
|
||||
echo -e "${GREEN}安装Python依赖...${RESET}"
|
||||
pip install -r MaiBot/requirements.txt
|
||||
cd MaiBot
|
||||
pip install uv
|
||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}安装maim_message依赖...${RESET}"
|
||||
cd maim_message
|
||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -e .
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
|
||||
cd MaiBot-Napcat-Adapter
|
||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}同意协议...${RESET}"
|
||||
|
||||
# 首先计算当前EULA的MD5值
|
||||
current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}')
|
||||
|
||||
# 首先计算当前隐私条款文件的哈希值
|
||||
current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}')
|
||||
|
||||
echo -n $current_md5 > MaiBot/eula.confirmed
|
||||
echo -n $current_md5_privacy > MaiBot/privacy.confirmed
|
||||
|
||||
echo -e "${GREEN}创建系统服务...${RESET}"
|
||||
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
|
||||
[Unit]
|
||||
Description=MaiCore
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=${INSTALL_DIR}/MaiBot
|
||||
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
|
||||
Restart=always
|
||||
RestartSec=10s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
|
||||
# [Unit]
|
||||
# Description=MaiCore WebUI
|
||||
# After=network.target ${SERVICE_NAME}.service
|
||||
|
||||
# [Service]
|
||||
# Type=simple
|
||||
# WorkingDirectory=${INSTALL_DIR}/MaiBot
|
||||
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
|
||||
# Restart=always
|
||||
# RestartSec=10s
|
||||
|
||||
# [Install]
|
||||
# WantedBy=multi-user.target
|
||||
# EOF
|
||||
|
||||
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
|
||||
[Unit]
|
||||
Description=MaiBot Napcat Adapter
|
||||
After=network.target mongod.service ${SERVICE_NAME}.service
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
|
||||
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
|
||||
Restart=always
|
||||
RestartSec=10s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
systemctl daemon-reload
|
||||
|
||||
# 保存安装信息
|
||||
echo "INSTALLER_VERSION=${INSTALLER_VERSION}" > /etc/maicore_install.conf
|
||||
echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf
|
||||
echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf
|
||||
|
||||
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建系统服务:${SERVICE_NAME}、${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务:\n启动服务:systemctl start ${SERVICE_NAME}\n查看状态:systemctl status ${SERVICE_NAME}" 14 60
|
||||
}
|
||||
|
||||
# ----------- 主执行流程 -----------
|
||||
# 检查root权限
|
||||
[[ $(id -u) -ne 0 ]] && {
|
||||
echo -e "${RED}请使用root用户运行此脚本!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# 如果已安装显示菜单,并检查协议是否更新
|
||||
if check_installed; then
|
||||
load_install_info
|
||||
check_eula
|
||||
show_menu
|
||||
else
|
||||
run_installation
|
||||
# 安装完成后询问是否启动
|
||||
if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务?" 10 60; then
|
||||
systemctl start ${SERVICE_NAME}
|
||||
whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
|
||||
fi
|
||||
fi
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# ==============================================
|
||||
# Environment Initialization
|
||||
# ==============================================
|
||||
|
||||
# Step 1: Locate project root directory
|
||||
SCRIPTS_DIR="scripts"
|
||||
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
|
||||
PROJECT_ROOT=$(cd "$SCRIPT_DIR/.." && pwd)
|
||||
|
||||
# Step 2: Verify scripts directory exists
|
||||
if [ ! -d "$PROJECT_ROOT/$SCRIPTS_DIR" ]; then
|
||||
echo "❌ Error: scripts directory not found in project root" >&2
|
||||
echo "Current path: $PROJECT_ROOT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Step 3: Set up Python environment
|
||||
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
|
||||
cd "$PROJECT_ROOT" || {
|
||||
echo "❌ Failed to cd to project root: $PROJECT_ROOT" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Debug info
|
||||
echo "============================"
|
||||
echo "Project Root: $PROJECT_ROOT"
|
||||
echo "Python Path: $PYTHONPATH"
|
||||
echo "Working Dir: $(pwd)"
|
||||
echo "============================"
|
||||
|
||||
# ==============================================
|
||||
# Python Script Execution
|
||||
# ==============================================
|
||||
|
||||
run_python_script() {
|
||||
local script_name=$1
|
||||
echo "🔄 Running $script_name"
|
||||
if ! python3 "$SCRIPTS_DIR/$script_name"; then
|
||||
echo "❌ $script_name failed" >&2
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Execute scripts in order
|
||||
run_python_script "raw_data_preprocessor.py"
|
||||
run_python_script "info_extraction.py"
|
||||
run_python_script "import_openie.py"
|
||||
|
||||
echo "✅ All scripts completed successfully"
|
||||
@@ -1,394 +0,0 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
|
||||
|
||||
def contains_emoji_or_image_tags(text: str) -> bool:
|
||||
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否包含 [表情包] 或 [图片] 标记
|
||||
emoji_pattern = r'\[表情包[^\]]*\]'
|
||||
image_pattern = r'\[图片[^\]]*\]'
|
||||
|
||||
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
|
||||
|
||||
|
||||
def clean_reply_text(text: str) -> str:
|
||||
"""Remove reply references like [回复 xxxx...] from text"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Match content of the form [回复 xxxx...]
|
||||
# The negated character class [^\]]* stops the match at the first ]
|
||||
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
|
||||
|
||||
# 去除多余的空白字符
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
return cleaned_text
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
||||
else:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
except Exception:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def format_timestamp(timestamp: float) -> str:
|
||||
"""Format timestamp to readable date string"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of processed_plain_text length"""
|
||||
distribution = {
|
||||
'0': 0, # 空文本
|
||||
'1-5': 0, # 极短文本
|
||||
'6-10': 0, # 很短文本
|
||||
'11-20': 0, # 短文本
|
||||
'21-30': 0, # 较短文本
|
||||
'31-50': 0, # 中短文本
|
||||
'51-70': 0, # 中等文本
|
||||
'71-100': 0, # 较长文本
|
||||
'101-150': 0, # 长文本
|
||||
'151-200': 0, # 很长文本
|
||||
'201-300': 0, # 超长文本
|
||||
'301-500': 0, # 极长文本
|
||||
'501-1000': 0, # 巨长文本
|
||||
'1000+': 0 # 超巨长文本
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
continue
|
||||
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
|
||||
if length == 0:
|
||||
distribution['0'] += 1
|
||||
elif length <= 5:
|
||||
distribution['1-5'] += 1
|
||||
elif length <= 10:
|
||||
distribution['6-10'] += 1
|
||||
elif length <= 20:
|
||||
distribution['11-20'] += 1
|
||||
elif length <= 30:
|
||||
distribution['21-30'] += 1
|
||||
elif length <= 50:
|
||||
distribution['31-50'] += 1
|
||||
elif length <= 70:
|
||||
distribution['51-70'] += 1
|
||||
elif length <= 100:
|
||||
distribution['71-100'] += 1
|
||||
elif length <= 150:
|
||||
distribution['101-150'] += 1
|
||||
elif length <= 200:
|
||||
distribution['151-200'] += 1
|
||||
elif length <= 300:
|
||||
distribution['201-300'] += 1
|
||||
elif length <= 500:
|
||||
distribution['301-500'] += 1
|
||||
elif length <= 1000:
|
||||
distribution['501-1000'] += 1
|
||||
else:
|
||||
distribution['1000+'] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_text_length_stats(messages) -> Dict[str, float]:
    """Calculate basic statistics for processed_plain_text length"""
    lengths = []
    null_count = 0
    excluded_count = 0  # number of messages that were excluded

    for msg in messages:
        if msg.processed_plain_text is None:
            null_count += 1
        elif contains_emoji_or_image_tags(msg.processed_plain_text):
            # Skip messages that contain emoji or image tags
            excluded_count += 1
        else:
            # Strip reply quotations from the text
            cleaned_text = clean_reply_text(msg.processed_plain_text)
            lengths.append(len(cleaned_text))

    if not lengths:
        return {
            'count': 0,
            'null_count': null_count,
            'excluded_count': excluded_count,
            'min': 0,
            'max': 0,
            'avg': 0,
            'median': 0
        }

    lengths.sort()
    count = len(lengths)

    return {
        'count': count,
        'null_count': null_count,
        'excluded_count': excluded_count,
        'min': min(lengths),
        'max': max(lengths),
        'avg': sum(lengths) / count,
        'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
    }
|
||||
|
||||
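The hand-written median expression in `get_text_length_stats` follows the usual rule: the middle value for an odd count, the mean of the two middle values for an even count. As a small sketch (the helper name `manual_median` is hypothetical), it can be checked against the standard library's `statistics.median`:

```python
import statistics


def manual_median(values):
    """Median computed the same way as in get_text_length_stats."""
    ordered = sorted(values)
    count = len(ordered)
    if count % 2 == 1:
        return ordered[count // 2]
    return (ordered[count // 2 - 1] + ordered[count // 2]) / 2


even_sample = [3, 12, 7, 40]
odd_sample = [5, 1, 9]

# Both approaches should agree on even- and odd-sized inputs.
assert manual_median(even_sample) == statistics.median(even_sample)
assert manual_median(odd_sample) == statistics.median(odd_sample)
```

Using `statistics.median(lengths)` directly would be a possible simplification; it raises `StatisticsError` on empty input, which the early return for an empty `lengths` list already guards against.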
def get_available_chats() -> List[Tuple[str, str, int]]:
    """Get all available chats with message counts"""
    try:
        # Collect every chat_id that has messages, excluding special message types
        chat_counts = {}
        for msg in Messages.select(Messages.chat_id).distinct():
            chat_id = msg.chat_id
            count = Messages.select().where(
                (Messages.chat_id == chat_id) &
                (Messages.is_emoji != 1) &
                (Messages.is_picid != 1) &
                (Messages.is_command != 1)
            ).count()
            if count > 0:
                chat_counts[chat_id] = count

        # Look up the chat names
        result = []
        for chat_id, count in chat_counts.items():
            chat_name = get_chat_name(chat_id)
            result.append((chat_id, chat_name, count))

        # Sort by message count, largest first
        result.sort(key=lambda x: x[2], reverse=True)
        return result
    except Exception as e:
        print(f"Failed to get chat list: {e}")
        return []
|
||||
|
||||
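`get_available_chats` issues one COUNT query per distinct `chat_id`. The sketch below shows the same counting done in a single grouped query. It uses the standard-library `sqlite3` module and a hypothetical `messages` table whose columns mirror the filters above; the real schema and the Peewee model are not shown in this diff, so the table layout is an assumption.

```python
import sqlite3

# In-memory database with a hypothetical, simplified messages table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE messages (chat_id TEXT, is_emoji INTEGER, is_picid INTEGER, is_command INTEGER)"
)
conn.executemany(
    "INSERT INTO messages VALUES (?, ?, ?, ?)",
    [
        ("a", 0, 0, 0),
        ("a", 0, 0, 0),
        ("a", 1, 0, 0),  # emoji message, filtered out
        ("b", 0, 0, 0),
        ("c", 0, 0, 1),  # command message, so chat "c" has no countable rows
    ],
)

# One grouped query instead of one COUNT query per chat_id.
rows = conn.execute(
    """
    SELECT chat_id, COUNT(*) AS n
    FROM messages
    WHERE is_emoji != 1 AND is_picid != 1 AND is_command != 1
    GROUP BY chat_id
    ORDER BY n DESC
    """
).fetchall()

for chat_id, n in rows:
    print(chat_id, n)
```

Chats whose messages are all filtered out never appear in the result, which matches the `if count > 0` check in the loop-based version.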
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
    """Get time range input from user"""
    print("\nTime range options:")
    print("1. Last 1 day")
    print("2. Last 3 days")
    print("3. Last 7 days")
    print("4. Last 30 days")
    print("5. Custom time range")
    print("6. No time limit")

    choice = input("Select a time range (1-6): ").strip()

    now = time.time()

    if choice == "1":
        return now - 24*3600, now
    elif choice == "2":
        return now - 3*24*3600, now
    elif choice == "3":
        return now - 7*24*3600, now
    elif choice == "4":
        return now - 30*24*3600, now
    elif choice == "5":
        print("Enter the start time (format: YYYY-MM-DD HH:MM:SS):")
        start_str = input().strip()
        print("Enter the end time (format: YYYY-MM-DD HH:MM:SS):")
        end_str = input().strip()

        try:
            start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
            end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
            return start_time, end_time
        except ValueError:
            print("Invalid time format; no time limit will be applied")
            return None, None
    else:
        return None, None
|
||||
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
    """Get top N longest messages"""
    message_lengths = []

    for msg in messages:
        if msg.processed_plain_text is not None:
            # Skip messages that contain emoji or image tags
            if contains_emoji_or_image_tags(msg.processed_plain_text):
                continue

            # Strip reply quotations from the text
            cleaned_text = clean_reply_text(msg.processed_plain_text)
            length = len(cleaned_text)
            chat_name = get_chat_name(msg.chat_id)
            time_str = format_timestamp(msg.time)
            # Use the first 100 characters as a preview
            preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
            message_lengths.append((chat_name, length, time_str, preview))

    # Sort by length and keep the top N
    message_lengths.sort(key=lambda x: x[1], reverse=True)
    return message_lengths[:top_n]
|
||||
|
||||
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
    """Analyze processed_plain_text lengths with optional filters"""

    # Build the query, excluding special message types
    query = Messages.select().where(
        (Messages.is_emoji != 1) &
        (Messages.is_picid != 1) &
        (Messages.is_command != 1)
    )

    if chat_id:
        query = query.where(Messages.chat_id == chat_id)

    if start_time:
        query = query.where(Messages.time >= start_time)

    if end_time:
        query = query.where(Messages.time <= end_time)

    messages = list(query)

    if not messages:
        print("No messages matched the given filters")
        return

    # Compute the statistics
    distribution = calculate_text_length_distribution(messages)
    stats = get_text_length_stats(messages)
    top_longest = get_top_longest_messages(messages, 10)

    # Display the results
    print("\n=== Processed Plain Text Length Analysis Results ===")
    print("(Emoji, picture-ID, and command messages excluded; messages tagged [表情包] or [图片] excluded; reply quotations stripped)")
    if chat_id:
        print(f"Chat: {get_chat_name(chat_id)}")
    else:
        print("Chat: all chats")

    if start_time and end_time:
        print(f"Time range: {format_timestamp(start_time)} to {format_timestamp(end_time)}")
    elif start_time:
        print(f"Time range: after {format_timestamp(start_time)}")
    elif end_time:
        print(f"Time range: before {format_timestamp(end_time)}")
    else:
        print("Time range: unlimited")

    print("\nBasic statistics:")
    print(f"Total messages: {len(messages)}")
    print(f"Messages with text: {stats['count']}")
    print(f"Messages with null text: {stats['null_count']}")
    print(f"Excluded messages: {stats['excluded_count']}")
    if stats['count'] > 0:
        print(f"Shortest length: {stats['min']} characters")
        print(f"Longest length: {stats['max']} characters")
        print(f"Average length: {stats['avg']:.2f} characters")
        print(f"Median length: {stats['median']:.2f} characters")

    print("\nText length distribution:")
    total = stats['count']
    if total > 0:
        for range_name, count in distribution.items():
            if count > 0:
                percentage = count / total * 100
                print(f"{range_name} characters: {count} ({percentage:.2f}%)")

    # Show the longest messages
    if top_longest:
        print(f"\nTop {len(top_longest)} longest messages:")
        for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
            print(f"{i}. [{chat_name}] {time_str}")
            print(f"   Length: {length} characters")
            print(f"   Preview: {preview}")
            print()
|
||||
|
||||
def interactive_menu() -> None:
    """Interactive menu for text length analysis"""

    while True:
        print("\n" + "="*50)
        print("Processed Plain Text Length Analysis Tool")
        print("="*50)
        print("1. Analyze all chats")
        print("2. Analyze a specific chat")
        print("q. Quit")

        choice = input("\nSelect an analysis mode (1-2, q): ").strip()

        if choice.lower() == 'q':
            print("Goodbye!")
            break

        chat_id = None

        if choice == "2":
            # Show the list of available chats
            chats = get_available_chats()
            if not chats:
                print("No chat data found")
                continue

            print(f"\nAvailable chats ({len(chats)} total):")
            for i, (_cid, name, count) in enumerate(chats, 1):
                print(f"{i}. {name} ({count} messages)")

            try:
                chat_choice = int(input(f"\nSelect a chat (1-{len(chats)}): ").strip())
                if 1 <= chat_choice <= len(chats):
                    chat_id = chats[chat_choice - 1][0]
                else:
                    print("Invalid choice")
                    continue
            except ValueError:
                print("Please enter a valid number")
                continue

        elif choice != "1":
            print("Invalid choice")
            continue

        # Ask for the time range
        start_time, end_time = get_time_range_input()

        # Run the analysis
        analyze_text_lengths(chat_id, start_time, end_time)

        input("\nPress Enter to continue...")
|
||||
|
||||
if __name__ == "__main__":
    interactive_menu()
@@ -12,9 +12,10 @@ import binascii
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.common.database.database import db as peewee_db
|
||||
from sqlalchemy import select
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
session = get_session()
|
||||
|
||||
"""
|
||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||
|
||||
@@ -151,7 +154,7 @@ class MaiEmoji:
|
||||
# 准备数据库记录 for emoji collection
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
Emoji.create(
|
||||
emoji = Emoji(
|
||||
emoji_hash=self.hash,
|
||||
full_path=self.full_path,
|
||||
format=self.format,
|
||||
@@ -165,6 +168,8 @@ class MaiEmoji:
|
||||
usage_count=self.usage_count,
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -200,7 +205,7 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
|
||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||
except Emoji.DoesNotExist: # type: ignore
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
|
||||
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
# data is now an iterable of Peewee Emoji model instances
|
||||
emoji_data_list = list(data)
|
||||
|
||||
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
|
||||
@@ -393,12 +397,17 @@ class EmojiManager:
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""初始化数据库连接和表情目录"""
|
||||
peewee_db.connect(reuse_if_open=True)
|
||||
if peewee_db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
Emoji.create_table(safe=True) # Ensures table exists
|
||||
self._initialized = True
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
if db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
self._initialized = True # 标记为已初始化
|
||||
logger.info("EmojiManager初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"EmojiManager初始化失败: {e}")
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
def _ensure_db(self) -> None:
|
||||
"""确保数据库已初始化"""
|
||||
@@ -410,7 +419,7 @@ class EmojiManager:
|
||||
def record_usage(self, emoji_hash: str) -> None:
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
emoji_update.save() # Persist changes to DB
|
||||
@@ -644,10 +653,10 @@ class EmojiManager:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||
|
||||
emoji_peewee_instances = Emoji.select()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = session.execute(select(Emoji)).scalars().all()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
self.emoji_objects = emoji_objects
|
||||
@@ -675,15 +684,15 @@ class EmojiManager:
|
||||
self._ensure_db()
|
||||
|
||||
if emoji_hash:
|
||||
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
|
||||
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
query = Emoji.select()
|
||||
query = session.execute(select(Emoji)).scalars().all()
|
||||
|
||||
emoji_peewee_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
@@ -760,7 +769,7 @@ class EmojiManager:
|
||||
# 如果内存中没有,从数据库查找
|
||||
self._ensure_db()
|
||||
try:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
@@ -921,9 +930,10 @@ class EmojiManager:
|
||||
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||
existing_description = None
|
||||
try:
|
||||
from src.common.database.database_model import Images
|
||||
# from src.common.database.database_model_compat import Images
|
||||
|
||||
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
existing_image = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
@@ -7,7 +7,9 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
@@ -20,7 +22,7 @@ DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
session = get_session()
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
@@ -168,30 +170,50 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
# """
|
||||
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
|
||||
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
# """
|
||||
# learnt_style_expressions = []
|
||||
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
"""
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
|
||||
for expr in style_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "style",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
|
||||
for expr in grammar_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
|
||||
|
||||
|
||||
# # 直接从数据库查询
|
||||
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||
# for expr in style_query:
|
||||
# # 确保create_date存在,如果不存在则使用last_active_time
|
||||
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
# learnt_style_expressions.append(
|
||||
# {
|
||||
# "situation": expr.situation,
|
||||
# "style": expr.style,
|
||||
# "count": expr.count,
|
||||
# "last_active_time": expr.last_active_time,
|
||||
# "source_id": self.chat_id,
|
||||
# "type": "style",
|
||||
# "create_date": create_date,
|
||||
# }
|
||||
# )
|
||||
# return learnt_style_expressions
|
||||
|
||||
|
||||
|
||||
@@ -201,7 +223,7 @@ class ExpressionLearner:
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
all_expressions = session.execute(select(Expression)).scalars()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
@@ -217,18 +239,20 @@ class ExpressionLearner:
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
session.delete(expr)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
@@ -297,23 +321,22 @@ class ExpressionLearner:
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
new_expression = Expression(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
@@ -322,16 +345,18 @@ class ExpressionLearner:
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
session.add(new_expression)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())
|
||||
session.execute(select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())).scalars()
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
@@ -509,54 +534,35 @@ class ExpressionLearnerManager:
|
||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
# 查重:同chat_id+type+situation+style
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 标记迁移完成
|
||||
try:
|
||||
# 确保done.done文件的父目录存在
|
||||
done_parent_dir = os.path.dirname(done_flag)
|
||||
if not os.path.exists(done_parent_dir):
|
||||
os.makedirs(done_parent_dir, exist_ok=True)
|
||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||
except OSError as e:
|
||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
session.add(new_expression)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 检查并处理grammar表达删除
|
||||
if not os.path.exists(done_flag2):
|
||||
@@ -581,18 +587,20 @@ class ExpressionLearnerManager:
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
def delete_all_grammar_expressions(self) -> int:
|
||||
|
||||
@@ -9,8 +9,11 @@ from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
session = get_session()
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -131,9 +134,12 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
style_query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
))
|
||||
grammar_query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
|
||||
))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -146,9 +152,24 @@ class ExpressionSelector:
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
for expr in style_query.scalars()
|
||||
]
|
||||
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in grammar_query.scalars()
|
||||
]
|
||||
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
@@ -174,19 +195,19 @@ class ExpressionSelector:
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
session.commit()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||
from .lpmmconfig import global_config
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
|
||||
|
||||
class MemoryActiveManager:
|
||||
def __init__(
|
||||
self,
|
||||
embed_manager: EmbeddingManager,
|
||||
llm_client_embedding: LLMClient,
|
||||
):
|
||||
self.embed_manager = embed_manager
|
||||
self.embedding_client = llm_client_embedding
|
||||
|
||||
def get_activation(self, question: str) -> float:
|
||||
"""获取记忆激活度"""
|
||||
# 生成问题的Embedding
|
||||
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
|
||||
# 查询关系库中的相似度
|
||||
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
|
||||
|
||||
# 动态过滤阈值
|
||||
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
|
||||
if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||
# 未找到相关关系
|
||||
return 0.0
|
||||
|
||||
# 计算激活度
|
||||
activation = sum([item[2] for item in rel_scores]) * 10
|
||||
|
||||
return activation
|
||||
@@ -16,8 +16,10 @@ from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
|
||||
from sqlalchemy import select, insert, update, text, delete
|
||||
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
@@ -37,7 +39,7 @@ def cosine_similarity(v1, v2):
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
session = get_session()
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
@@ -731,13 +733,14 @@ class Hippocampus:
|
||||
memory_items = node_data.get("memory_items", "")
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
logger.debug("节点包含完整记忆")
|
||||
# 计算记忆与关键词的相似度
|
||||
memory_words = set(jieba.cut(memory_items))
|
||||
text_words = set(keywords)
|
||||
all_words = memory_words | text_words
|
||||
if all_words:
|
||||
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
|
||||
logger.debug(f"节点包含 {len(memory_items)} 条记忆")
|
||||
# 计算每条记忆与输入文本的相似度
|
||||
memory_similarities = []
|
||||
for memory in memory_items:
|
||||
# 计算与输入文本的相似度
|
||||
memory_words = set(jieba.cut(memory))
|
||||
text_words = set(jieba.cut(text))
|
||||
all_words = memory_words | text_words
|
||||
v1 = [1 if word in memory_words else 0 for word in all_words]
|
||||
v2 = [1 if word in text_words else 0 for word in all_words]
|
||||
_ = cosine_similarity(v1, v2) # 计算但不使用,用_表示
|
||||
@@ -844,11 +847,6 @@ class Hippocampus:
|
||||
else:
|
||||
activate_map[node] = activation_value
|
||||
|
||||
# 输出激活映射
|
||||
# logger.info("激活映射统计:")
|
||||
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
||||
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
||||
|
||||
# 计算激活节点数与总节点数的比值
|
||||
total_activation = sum(activate_map.values())
|
||||
# logger.debug(f"总激活值: {total_activation:.2f}")
|
||||
@@ -942,10 +940,13 @@ class EntorhinalCortex:
|
||||
for message in messages:
|
||||
# 确保在更新前获取最新的 memorized_times
|
||||
current_memorized_times = message.get("memorized_times", 0)
|
||||
# 使用 Peewee 更新记录
|
||||
Messages.update(memorized_times=current_memorized_times + 1).where(
|
||||
Messages.message_id == message["message_id"]
|
||||
).execute()
|
||||
# 使用 SQLAlchemy 2.0 更新记录
|
||||
session.execute(
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
session.commit()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||
@@ -959,7 +960,7 @@ class EntorhinalCortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# Fetch all nodes in the database and all nodes in memory
|
||||
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# Prepare node data in batches
|
||||
@@ -1025,22 +1026,27 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
session.commit()
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
|
||||
GraphNodes.concept == node_data["concept"]
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphNodes)
|
||||
.where(GraphNodes.concept == node_data["concept"])
|
||||
.values(**{k: v for k, v in node_data.items() if k != "concept"})
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if nodes_to_delete:
|
||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
|
||||
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
session.commit()
|
||||
|
||||
# Process edge information
|
||||
db_edges = list(GraphEdges.select())
|
||||
db_edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# Build a dictionary of edge hashes
|
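The slice-and-commit batching used for node writes in the hunk above can be sketched as follows. This is a hedged illustration only: it swaps SQLAlchemy's `session.execute(insert(GraphNodes), batch)` for the standard-library `sqlite3` module so that it runs without extra dependencies, and the `graph_nodes` schema with its `concept` and `weight` columns is an assumption made for the example.

```python
import sqlite3


def insert_in_batches(conn, rows, batch_size=100):
    """Insert rows in fixed-size slices, committing after each slice.

    Returns the number of batches written.
    """
    batches = 0
    for i in range(0, len(rows), batch_size):
        batch = rows[i : i + batch_size]
        conn.executemany(
            "INSERT INTO graph_nodes (concept, weight) VALUES (:concept, :weight)",
            batch,
        )
        conn.commit()  # one commit per batch, mirroring the diff
        batches += 1
    return batches


def demo():
    conn = sqlite3.connect(":memory:")
    # Hypothetical two-column schema, for illustration only.
    conn.execute("CREATE TABLE graph_nodes (concept TEXT PRIMARY KEY, weight REAL)")
    rows = [{"concept": f"node-{n}", "weight": 1.0} for n in range(250)]
    batches = insert_in_batches(conn, rows, batch_size=100)
    count = conn.execute("SELECT COUNT(*) FROM graph_nodes").fetchone()[0]
    conn.close()
    return batches, count
```

Committing per batch keeps each transaction small; committing once after the loop would instead make the whole write atomic. Which trade-off the project wants is not stated in the diff.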
||||
@@ -1092,20 +1098,29 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphEdges)
|
||||
.where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
)
|
||||
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
|
||||
session.execute(
|
||||
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
)
|
||||
session.commit()
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1118,8 +1133,9 @@ class EntorhinalCortex:
|
||||
|
||||
# Clear the database
|
||||
clear_start = time.time()
|
||||
GraphNodes.delete().execute()
|
||||
GraphEdges.delete().execute()
|
||||
session.execute(delete(GraphNodes))
|
||||
session.execute(delete(GraphEdges))
|
||||
session.commit()
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
|
||||
@@ -1186,12 +1202,27 @@ class EntorhinalCortex:
|
||||
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
|
||||
continue
|
||||
|
||||
# Bulk-insert edges
|
||||
# Bulk-write nodes
|
||||
node_start = time.time()
|
||||
if nodes_data:
|
||||
batch_size = 500 # larger batch size
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
session.commit()
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
|
||||
# Bulk-write edges
|
||||
edge_start = time.time()
|
||||
if edges_data:
|
||||
batch_size = 100
|
||||
batch_size = 500 # larger batch size
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1211,9 +1242,7 @@ class EntorhinalCortex:
|
||||
skipped_nodes = 0
|
||||
|
||||
# Load all nodes from the database
|
||||
nodes = list(GraphNodes.select())
|
||||
total_nodes = len(nodes)
|
||||
|
||||
nodes = list(session.execute(select(GraphNodes)).scalars())
|
||||
for node in nodes:
|
||||
concept = node.concept
|
||||
try:
|
||||
@@ -1235,8 +1264,10 @@ class EntorhinalCortex:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
if update_data:
|
||||
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
|
||||
session.execute(
|
||||
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Get time information (use the current time if missing)
|
||||
created_time = node.created_time or current_time
|
||||
@@ -1256,7 +1287,7 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
# Load all edges from the database
|
||||
edges = list(GraphEdges.select())
|
||||
edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
for edge in edges:
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
@@ -1272,9 +1303,12 @@ class EntorhinalCortex:
|
||||
if not edge.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
GraphEdges.update(**update_data).where(
|
||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphEdges)
|
||||
.where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
.values(**update_data)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Get time information (use the current time if missing)
|
||||
created_time = edge.created_time or current_time
|
||||
@@ -1398,7 +1432,6 @@ class ParahippocampalGyrus:
|
||||
all_words = topic_words | existing_words
|
||||
v1 = [1 if word in topic_words else 0 for word in all_words]
|
||||
v2 = [1 if word in existing_words else 0 for word in all_words]
|
||||
|
||||
similarity = cosine_similarity(v1, v2)
|
||||
|
||||
if similarity >= 0.7:
|
||||
@@ -1502,7 +1535,7 @@ class ParahippocampalGyrus:
|
||||
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
|
||||
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
|
||||
|
||||
# Only sample once there are enough nodes and edges
|
||||
# Sample only when there are enough nodes and edges
|
||||
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
|
||||
try:
|
||||
nodes_to_check = random.sample(all_nodes, check_nodes_count)
|
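The clamped sample size and its guard in the hunk above can be sketched for a single list as follows. `sample_for_check` and its `rng` parameter are illustrative names, not part of the project; the diff applies the same arithmetic to nodes and edges separately and requires both guards to hold.

```python
import random


def sample_for_check(items, percentage, rng=random):
    """Return a random subset sized like the diff's check_*_count, or [] if too few items."""
    # Clamp the count to the range [1, len(items)], as the diff does.
    count = max(1, min(len(items), int(len(items) * percentage)))
    if len(items) < count:  # only true for an empty list, where count is clamped up to 1
        return []
    return rng.sample(items, count)
```

Because the count is clamped to at least 1, the guard only fails for an empty collection; for any non-empty list at least one element is always sampled, even with a very small percentage.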
||||
@@ -1548,6 +1581,11 @@ class ParahippocampalGyrus:
|
||||
|
||||
logger.info("[遗忘] 开始检查节点...")
|
||||
node_check_start = time.time()
|
||||
|
||||
# Initialize consolidation-related variables
|
||||
merged_count = 0
|
||||
nodes_modified = set()
|
||||
|
||||
for node in nodes_to_check:
|
||||
# Check that the node still exists, in case it was removed during iteration (e.g. by an edge removal)
|
||||
if node not in self.memory_graph.G:
|
||||
@@ -1567,64 +1605,91 @@ class ParahippocampalGyrus:
|
||||
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
|
||||
continue # move on to the next node
|
||||
|
||||
# --- If the node is not empty, run the original inactivity check and random-removal logic ---
|
||||
# Check the node's last-modified time; if it is too old, try to forget it
|
||||
last_modified = node_data.get("last_modified", current_time)
|
||||
node_weight = node_data.get("weight", 1.0)
|
||||
|
||||
# Condition 1: check whether the node has gone unmodified for a long time (using the configured forget time)
|
||||
time_threshold = 3600 * global_config.memory.memory_forget_time
|
||||
|
||||
# Scale the forgetting threshold by weight: the higher the weight, the longer it takes to be forgotten
|
||||
# A weight of 1 uses the default threshold; higher weights raise the threshold (harder to forget)
|
||||
adjusted_threshold = time_threshold * node_weight
|
||||
|
||||
if current_time - last_modified > adjusted_threshold and memory_items:
|
||||
# Each node now holds one complete memory, so delete the whole node
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
|
||||
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
|
||||
continue
|
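The weight-scaled forgetting check added just above can be written out as a small pure function. This is a sketch under assumptions: `should_forget` is an illustrative name, and `forget_hours` stands in for `global_config.memory.memory_forget_time`, which the diff multiplies by 3600 and is therefore presumably expressed in hours.

```python
def should_forget(current_time, last_modified, node_weight, forget_hours):
    """True when a node has been idle longer than its weight-scaled threshold."""
    time_threshold = 3600 * forget_hours  # base threshold in seconds
    adjusted_threshold = time_threshold * node_weight  # heavier nodes survive longer
    return current_time - last_modified > adjusted_threshold
```

With a 24-hour setting, a node of weight 1.0 becomes eligible after 86,400 idle seconds and a node of weight 2.0 after 172,800. The same arithmetic means a weight below 1 shortens the window, and a weight of 0 would make any idle node eligible immediately, so a lower bound on the weight may be worth considering.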
||||
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
|
||||
# Randomly forget one memory
|
||||
if len(memory_items) > 1:
|
||||
removed_item = self.memory_graph.forget_topic(node)
|
||||
if removed_item:
|
||||
node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)")
|
||||
elif len(memory_items) == 1:
|
||||
# If only one memory remains, check whether the node should be removed entirely
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node} (最后记忆)")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}")
|
||||
|
||||
# Check whether the node contains similar memory items that should be consolidated
|
||||
if len(memory_items) > 1:
|
||||
merged_in_this_node = False
|
||||
items_to_remove = []
|
||||
|
||||
for i in range(len(memory_items)):
|
||||
for j in range(i + 1, len(memory_items)):
|
||||
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
|
||||
if similarity > 0.8: # similarity threshold
|
||||
# Merge similar memory items
|
||||
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
|
||||
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
|
||||
|
||||
# Keep the longer memory item and mark the shorter one for removal
|
||||
if shorter_item not in items_to_remove:
|
||||
items_to_remove.append(shorter_item)
|
||||
merged_count += 1
|
||||
merged_in_this_node = True
|
||||
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
|
||||
|
||||
# Remove the merged memory items
|
||||
if items_to_remove:
|
||||
for item in items_to_remove:
|
||||
if item in memory_items:
|
||||
memory_items.remove(item)
|
||||
nodes_modified.add(node)
|
||||
# Update the node's memory items
|
||||
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
|
||||
self.memory_graph.G.nodes[node]["last_modified"] = current_time
|
||||
|
||||
node_check_end = time.time()
|
||||
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒")
|
||||
|
||||
if any(edge_changes.values()) or any(node_changes.values()):
|
||||
# Output change statistics
|
||||
if edge_changes["weakened"]:
|
||||
logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接")
|
||||
if edge_changes["removed"]:
|
||||
logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接")
|
||||
if node_changes["reduced"]:
|
||||
logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆")
|
||||
if node_changes["removed"]:
|
||||
logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点")
|
||||
|
||||
# Check whether any changes need to be synced to the database
|
||||
has_changes = (
|
||||
edge_changes["weakened"] or
|
||||
edge_changes["removed"] or
|
||||
node_changes["reduced"] or
|
||||
node_changes["removed"] or
|
||||
merged_count > 0
|
||||
)
|
||||
|
||||
if has_changes:
|
||||
logger.info("[遗忘] 开始将变更同步到数据库...")
|
||||
sync_start = time.time()
|
||||
|
||||
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
|
||||
|
||||
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
|
||||
sync_end = time.time()
|
||||
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒")
|
||||
|
||||
# Summarize all changes
|
||||
logger.info("[遗忘] 遗忘操作统计:")
|
||||
if edge_changes["weakened"]:
|
||||
logger.info(
|
||||
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
|
||||
)
|
||||
|
||||
if edge_changes["removed"]:
|
||||
logger.info(
|
||||
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
|
||||
)
|
||||
|
||||
if node_changes["reduced"]:
|
||||
logger.info(
|
||||
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
|
||||
)
|
||||
|
||||
if node_changes["removed"]:
|
||||
logger.info(
|
||||
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
|
||||
)
|
||||
if merged_count > 0:
|
||||
logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。")
|
||||
sync_start = time.time()
|
||||
logger.info("[整合] 开始将变更同步到数据库...")
|
||||
# Use resync to handle deletions and additions more safely
|
||||
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
|
||||
sync_end = time.time()
|
||||
logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}秒")
|
||||
else:
|
||||
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")
|
||||
|
||||
logger.info("[整合] 本次检查未发现需要合并的记忆项。")
|
||||
|
||||
|
||||
|
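The pairwise consolidation loop in the hunk above (compare every pair of memory items, keep the longer one when similarity exceeds 0.8) can be sketched as follows. The project's `_calculate_item_similarity` is not shown in this diff, so `item_similarity` below is a stand-in based on `difflib.SequenceMatcher`; the real metric may differ, and `merge_similar_items` is an illustrative name.

```python
from difflib import SequenceMatcher


def item_similarity(a: str, b: str) -> float:
    """Stand-in similarity in [0, 1]; the project's own metric is not shown in the diff."""
    return SequenceMatcher(None, a, b).ratio()


def merge_similar_items(items, threshold=0.8):
    """Drop the shorter item of every pair whose similarity exceeds the threshold.

    Returns (kept_items, merged_pair_count).
    """
    to_remove = []
    merged = 0
    for i in range(len(items)):
        for j in range(i + 1, len(items)):
            if item_similarity(items[i], items[j]) > threshold:
                if len(items[i]) > len(items[j]):
                    shorter = items[j]
                else:
                    shorter = items[i]  # ties drop the earlier item, as in the diff
                if shorter not in to_remove:
                    to_remove.append(shorter)
                    merged += 1
    kept = list(items)
    for item in to_remove:
        if item in kept:
            kept.remove(item)  # removes one occurrence, like list.remove in the diff
    return kept, merged
```

The comparison is quadratic in the number of items per node, which is fine for short lists but worth keeping in mind if nodes accumulate many memory items.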
||||
@@ -1734,10 +1799,7 @@ class HippocampusManager:
|
||||
"""获取所有节点名称的公共接口"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus.get_all_node_names()
|
||||
|
||||
|
||||
# Create the global instance
|
||||
hippocampus_manager = HippocampusManager()
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,13 @@ from datetime import datetime, timedelta
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Memory # Peewee models import
|
||||
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy models import
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
from sqlalchemy import select
|
||||
logger = get_logger(__name__)
|
||||
|
||||
session = get_session()
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
@@ -120,7 +121,8 @@ class InstantMemory:
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory.save()
|
||||
session.add(memory)
|
||||
session.commit()
|
||||
|
||||
async def get_memory(self, target: str):
|
||||
from json_repair import repair_json
|
||||
@@ -166,13 +168,13 @@ class InstantMemory:
|
||||
if start_time and end_time:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
query = Memory.select().where(
|
||||
query = session.execute(select(Memory).where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts) # type: ignore
|
||||
& (Memory.create_time < end_ts) # type: ignore
|
||||
)
|
||||
& (Memory.create_time >= start_ts)
|
||||
& (Memory.create_time < end_ts)
|
||||
)).scalars()
|
||||
else:
|
||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
|
||||
|
||||
for mem in query:
|
||||
# For each memory
|
||||
|
||||
@@ -8,8 +8,12 @@ from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import ChatStreams # newly added import
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # newly added import
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.config.config import global_config # newly added import
|
||||
# Avoid circular imports; use TYPE_CHECKING for type hints
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
@@ -19,7 +23,7 @@ install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("chat_stream")
|
||||
|
||||
session = get_session()
|
||||
|
||||
class ChatMessageContext:
|
||||
"""聊天消息上下文,存储消息的上下文信息"""
|
||||
@@ -131,7 +135,8 @@ class ChatManager:
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# Make sure the ChatStreams table exists
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
|
||||
@@ -231,7 +236,7 @@ class ChatManager:
|
||||
|
||||
# Check whether it exists in the database
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
@@ -342,7 +347,28 @@ class ChatManager:
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
# Choose the insert statement based on the database type
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
elif global_config.database.database_type == "mysql":
|
||||
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_duplicate_key_update(
|
||||
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
# Default to a generic insert, trying the SQLite syntax
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
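The SQLite branch of the upsert in the hunk above corresponds to an `INSERT ... ON CONFLICT DO UPDATE` statement. The sketch below shows that statement directly with the standard-library `sqlite3` module rather than SQLAlchemy's `sqlite_insert`; the three-column `chat_streams` table and the `save_stream` helper are simplified assumptions for illustration, and the syntax needs SQLite 3.24 or newer.

```python
import sqlite3


def save_stream(conn, stream_id, fields):
    """Insert a row, or update the given columns if stream_id already exists.

    Assumes `fields` is non-empty and that its keys are trusted column names,
    since they are interpolated into the SQL text.
    """
    columns = ["stream_id", *fields]
    placeholders = ", ".join(f":{c}" for c in columns)
    updates = ", ".join(f"{c} = excluded.{c}" for c in fields)
    sql = (
        f"INSERT INTO chat_streams ({', '.join(columns)}) VALUES ({placeholders}) "
        f"ON CONFLICT(stream_id) DO UPDATE SET {updates}"
    )
    conn.execute(sql, {"stream_id": stream_id, **fields})
    conn.commit()


def demo():
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, group_name TEXT)"
    )
    save_stream(conn, "s1", {"platform": "qq", "group_name": "old"})
    save_stream(conn, "s1", {"platform": "qq", "group_name": "new"})  # same key: update
    rows = conn.execute("SELECT stream_id, platform, group_name FROM chat_streams").fetchall()
    conn.close()
    return rows
```

Unlike Peewee's `replace`, which the diff removes, an upsert of this form updates the existing row in place instead of deleting and re-inserting it.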
||||
@@ -361,7 +387,7 @@ class ChatManager:
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
for model_instance in session.execute(select(ChatStreams)).scalars():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from src.common.database.database_model import Messages, Images
|
||||
from src.common.database.sqlalchemy_models import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, update, desc
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords) -> str:
|
||||
@@ -33,15 +35,11 @@ class MessageStorage:
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 莫越权 救世啊
|
||||
# Regex pattern for filtering sensitive information
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# print(message)
|
||||
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
# print(processed_plain_text)
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
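The filtering step in the hunk above relies on non-greedy matching together with `re.DOTALL`. A minimal sketch of its behaviour, using the pattern exactly as it appears in the diff (`strip_tagged_blocks` is an illustrative name, not a project function):

```python
import re

# Same pattern as in the diff: three tag pairs, each matched non-greedily.
PATTERN = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"


def strip_tagged_blocks(text: str) -> str:
    """Remove every tagged block, including blocks that span several lines."""
    return re.sub(PATTERN, "", text, flags=re.DOTALL)
```

`re.DOTALL` lets `.` match newlines, so multi-line blocks are removed, and the non-greedy `.*?` keeps text between two separate blocks from being swallowed.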
||||
@@ -93,11 +91,16 @@ class MessageStorage:
|
||||
# Safely get user_info; treat None as an empty dict (just in case)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
|
||||
Messages.create(
|
||||
# Serialize the priority_info dict to a JSON string so it can be stored in the database's Text field
|
||||
priority_info_json = json.dumps(priority_info) if priority_info else None
|
||||
|
||||
# Get a database session
|
||||
session = get_session()
|
||||
|
||||
new_message = Messages(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time), # type: ignore
|
||||
time=float(message.message_info.time),
|
||||
chat_id=chat_stream.stream_id,
|
||||
# Flattened chat_info
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
@@ -111,18 +114,16 @@ class MessageStorage:
|
||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||
# Flattened user_info (message sender)
|
||||
user_platform=user_info_dict.get("platform"),
|
||||
user_id=user_info_dict.get("user_id"),
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
# Text content
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info,
|
||||
priority_info=priority_info_json,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
@@ -131,35 +132,44 @@ class MessageStorage:
|
||||
key_words_lite=key_words_lite,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
session.add(new_message)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
logger.error(f"消息:{message}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Add other storage-related functions here if needed
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
message: MessageRecv,
|
||||
) -> None: # updates the database in real time with the IDs of messages sent by the bot itself; currently handles text, reply, image and emoji
|
||||
"""更新最新一条匹配消息的message_id"""
|
||||
async def update_message(message):
|
||||
"""更新消息ID"""
|
||||
try:
|
||||
if message.message_segment.type == "notify":
|
||||
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||
mmc_message_id = message.message_info.message_id # fix: access message_id correctly
|
||||
if message.message_segment.type == "text":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "reply":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
else:
|
||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||
return
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return
|
||||
if matched_message := (
|
||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||
):
|
||||
# Update the matched message record
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
|
||||
# Use a context manager to ensure the session is managed correctly
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
matched_message = session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
).scalar()
|
||||
|
||||
if matched_message:
|
||||
session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
# session.commit() is called automatically by the context manager
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
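The hunk above relies on `get_db_session()` committing automatically when the `with` block exits. The diff does not show that helper, so the following is only a sketch of how such a commit-on-success context manager can behave, written with the standard-library `sqlite3` module; `db_session`, `update_message_id` and the `messages` schema are illustrative assumptions.

```python
import sqlite3
from contextlib import contextmanager


@contextmanager
def db_session(conn):
    """Yield the connection; commit on success, roll back on error."""
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise


def update_message_id(conn, old_id, new_id):
    """Re-key the most recent row whose message_id equals old_id; True if one was found."""
    with db_session(conn) as session:
        row = session.execute(
            "SELECT id FROM messages WHERE message_id = ? ORDER BY time DESC LIMIT 1",
            (old_id,),
        ).fetchone()
        if row is None:
            return False
        session.execute("UPDATE messages SET message_id = ? WHERE id = ?", (new_id, row[0]))
        return True


def demo():
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE messages (id INTEGER PRIMARY KEY, message_id TEXT, time REAL)")
    conn.executemany(
        "INSERT INTO messages (message_id, time) VALUES (?, ?)",
        [("tmp", 1.0), ("tmp", 2.0)],
    )
    conn.commit()
    found = update_message_id(conn, "tmp", "real")
    rows = conn.execute("SELECT message_id, time FROM messages ORDER BY time").fetchall()
    conn.close()
    return found, rows
```

Returning from inside the `with` block still runs the commit, because the context manager resumes after its `yield` on a normal exit; only an exception triggers the rollback path.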
||||
@@ -178,10 +188,12 @@ class MessageStorage:
|
||||
def replace_match(match):
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
image_record = (
|
||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||
)
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
image_record = session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
).scalar()
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
|
||||
@@ -7,13 +7,14 @@ from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import Person,get_person_id
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
session = get_session()
|
||||
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# Reverse the list after fetching so the final output stays in ascending time order
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# Reverse the list after fetching so the final output stays in ascending time order
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
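The two functions in the hunk above share one idea: take either the newest or the oldest `limit` rows in a time window, but always return them in ascending time order. The sketch below expresses that with the standard-library `sqlite3` module instead of SQLAlchemy, building the statement once and executing it a single time; `get_actions` and the three-column `action_records` schema are assumptions for illustration.

```python
import sqlite3


def get_actions(conn, chat_id, start, end, limit=0, limit_mode="latest"):
    """Rows with start < time < end for one chat, always returned in ascending time order."""
    sql = "SELECT id, chat_id, time FROM action_records WHERE chat_id = ? AND time > ? AND time < ?"
    params = [chat_id, start, end]
    newest_first = limit > 0 and limit_mode == "latest"
    sql += " ORDER BY time DESC" if newest_first else " ORDER BY time ASC"
    if limit > 0:
        sql += " LIMIT ?"
        params.append(limit)
    rows = conn.execute(sql, params).fetchall()  # the query runs exactly once
    if newest_first:
        rows.reverse()  # the newest rows were selected; flip back to ascending order
    return [{"id": r[0], "chat_id": r[1], "time": r[2]} for r in rows]


def demo():
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE action_records (id INTEGER PRIMARY KEY, chat_id TEXT, time REAL)")
    conn.executemany(
        "INSERT INTO action_records (chat_id, time) VALUES (?, ?)",
        [("c1", float(t)) for t in range(1, 6)] + [("c2", 3.0)],
    )
    latest = [a["time"] for a in get_actions(conn, "c1", 0, 10, limit=2, limit_mode="latest")]
    earliest = [a["time"] for a in get_actions(conn, "c1", 0, 10, limit=2, limit_mode="earliest")]
    everything = [a["time"] for a in get_actions(conn, "c1", 1, 5)]
    conn.close()
    return latest, earliest, everything
```

Two points may be worth a look in the migrated code itself. It executes an unordered query first and then a second, ordered one in every branch, so the first result is discarded. And `action.__dict__` on a SQLAlchemy ORM instance also contains the internal `_sa_instance_state` entry, which Peewee's `__data__` did not, so callers that serialize or compare these dictionaries may see an extra key.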
||||
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# Get the image description from the database
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -813,7 +854,7 @@ def build_readable_messages(
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
@@ -846,21 +887,21 @@ def build_readable_messages(
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
# Get the action records within this time range, matching chat_id
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
actions_in_range = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time,
|
||||
ActionRecords.time <= max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
).order_by(ActionRecords.time)).scalars()
|
||||
|
||||
# Get the first action record after the latest message
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
action_after_latest = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time > max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
).order_by(ActionRecords.time).limit(1)).scalars()
|
||||
|
||||
# Merge the two sets of action records
|
||||
actions = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
# Synchronous wrapper for calling the async database API from non-async code
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""Synchronous version of db_get, intended to be called from a thread pool."""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# The event loop is already running, so use a new event loop in a separate thread
|
||||
import threading
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(
|
||||
db_get(model_class=model_class, filters=filters, order_by=order_by, limit=limit, single_result=single_result)
|
||||
)
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
db_get(model_class=model_class, filters=filters, order_by=order_by, limit=limit, single_result=single_result)
|
||||
)
|
||||
except RuntimeError:
|
||||
# No event loop is available, so create a new one
|
||||
return asyncio.run(db_get(model_class=model_class, filters=filters, order_by=order_by, limit=limit, single_result=single_result))
|
||||
|
||||
# Keys for the statistics data
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
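The `_sync_db_get` wrapper above bridges synchronous callers to an async database API. The following is a minimal, standard-library-only sketch of the same idea. `fake_db_get` and `run_coro_sync` are hypothetical names that do not appear in this commit, and the sketch detects a running loop with `asyncio.get_running_loop()` rather than `asyncio.get_event_loop()`.

```python
import asyncio
import threading


async def fake_db_get(table, limit=None):
    """Hypothetical stand-in for an async database call."""
    await asyncio.sleep(0)
    rows = [{"id": 1, "table": table}, {"id": 2, "table": table}]
    return rows[:limit] if limit is not None else rows


def run_coro_sync(coro_factory):
    """Run a coroutine to completion from synchronous code.

    coro_factory is a zero-argument callable returning the coroutine, so the
    coroutine object is created in the thread that actually runs it.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: asyncio.run can own a fresh one.
        return asyncio.run(coro_factory())

    # A loop is already running here, so it cannot be re-entered.
    # Run the coroutine on its own loop in a worker thread instead.
    outcome = {}

    def worker():
        try:
            outcome["result"] = asyncio.run(coro_factory())
        except Exception as exc:  # re-raised in the calling thread below
            outcome["error"] = exc

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    if "error" in outcome:
        raise outcome["error"]
    return outcome["result"]
```

Note that `thread.join()` blocks the calling thread, so when this is called from inside a running loop, that loop is stalled until the query finishes. From async code, awaiting the coroutine directly avoids that.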
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
self.record_id: int | None = None
|
||||
"""Record ID"""
|
||||
|
||||
self._init_database() # Initialize the database
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""Initialize the database."""
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self): # sourcery skip: use-named-expression
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
|
||||
if self.record_id:
|
||||
# A record exists, so update its end time
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||
updated_rows = query.execute()
|
||||
updated_rows = await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
self.record_id = None
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# No record yet: check whether one already exists within the last minute
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
if not self.record_id:
|
||||
# Look for a record from the last minute
|
||||
recent_threshold = current_time - timedelta(minutes=1)
|
||||
recent_records = await db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# A record exists, so update its end time
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# No record found, so insert a new online-time record
|
||||
new_record = OnlineTime.create(
|
||||
timestamp=current_time.timestamp(), # added this line
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
duration=5, # initial duration is 5 minutes
|
||||
|
||||
if recent_records:
|
||||
# Found a recent record, so update it
|
||||
self.record_id = recent_records['id']
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
else:
|
||||
# Create a new record
|
||||
new_record = await db_save(
|
||||
model_class=OnlineTime,
|
||||
data={
|
||||
"timestamp": str(current_time),
|
||||
"duration": 5, # initial duration is 5 minutes
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
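The calls above pass filters such as `{"end_timestamp": {"$gte": recent_threshold}}` together with an `order_by` string like `"-end_timestamp"`. How `db_get` interprets them is not shown in this part of the diff, so the sketch below is only one plausible reading, applied to plain dictionaries. The helper names `matches` and `query` and the operator set beyond `$gte` are assumptions.

```python
import operator
from datetime import datetime, timedelta

# Assumed operator vocabulary; only "$gte" actually appears in the diff.
_OPS = {
    "$gt": operator.gt,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$lte": operator.le,
    "$ne": operator.ne,
}


def matches(record, filters):
    """Return True if the record satisfies every condition in filters."""
    for field, cond in filters.items():
        value = record.get(field)
        if isinstance(cond, dict):
            if value is None:
                return False  # a missing value cannot satisfy a comparison
            if not all(_OPS[op](value, target) for op, target in cond.items()):
                return False
        elif value != cond:
            return False
    return True


def query(records, filters=None, order_by=None, limit=None, single_result=False):
    """Filter, sort, and truncate a list of dict records."""
    rows = [r for r in records if matches(r, filters or {})]
    if order_by:
        # A leading "-" is read as descending order.
        descending = order_by.startswith("-")
        rows.sort(key=lambda r: r[order_by.lstrip("-")], reverse=descending)
    if limit is not None:
        rows = rows[:limit]
    if single_result:
        return rows[0] if rows else None
    return rows
```

Under this reading, the "find a record from the last minute" lookup becomes `query(records, {"end_timestamp": {"$gte": threshold}}, order_by="-end_timestamp", limit=1, single_result=True)`.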
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
# Fetch records starting from the earliest period start time
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_timestamp = record['timestamp'] # read from the dict
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
|
||||
# Extract the module name: if the request type contains ".", take the part before the first "."
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
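The nested loop above counts each record toward the first period whose start it is not older than, and toward every wider period after it in `collect_period`. The sketch below shows that accumulation on its own. It assumes `collect_period` is ordered from the most recent start time to the earliest, which the use of `collect_period[-1][1]` as the query start suggests. The function name, the sample periods, and the `break` are illustrative assumptions.

```python
from datetime import datetime, timedelta


def count_by_period(timestamps, collect_period):
    """Count records per period.

    collect_period is a list of (key, start) pairs sorted so that later
    entries start earlier, i.e. each period contains the ones before it.
    """
    counts = {key: 0 for key, _ in collect_period}
    for ts in timestamps:
        for idx, (_, period_start) in enumerate(collect_period):
            if ts >= period_start:
                # The record falls in this period and in every wider one.
                for key, _ in collect_period[idx:]:
                    counts[key] += 1
                break  # already counted for all enclosing periods
    return counts
```

A record from five minutes ago is counted for every period, while one from two days ago is counted only for periods that reach back at least that far.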
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_end_timestamp = record['end_timestamp']
|
||||
record_start_timestamp = record['start_timestamp']
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time'] # This is a float timestamp
|
||||
|
||||
chat_id = None
|
||||
chat_name = None
|
||||
|
||||
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||
if message.chat_info_group_id:
|
||||
chat_id = f"g{message.chat_info_group_id}"
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
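After this change the rows come back as dictionaries rather than model instances, so attribute access such as `message.user_id` becomes `message.get('user_id')`. The chat identification logic in the hunk above can be sketched as a small function over a plain dict. `resolve_chat` is a hypothetical name; the key names are the ones used in the diff.

```python
def resolve_chat(message):
    """Return (chat_id, chat_name) for a message dict, or None if it cannot be identified."""
    group_id = message.get("chat_info_group_id")
    if group_id:
        # Group chats are keyed by group id, with a generated name as fallback.
        name = message.get("chat_info_group_name") or f"群{group_id}"
        return f"g{group_id}", name
    user_id = message.get("user_id")
    if user_id:
        # Falls back to the sender's id and nickname, as in the diff.
        return f"u{user_id}", message.get("user_nickname")
    return None
```

Using `.get()` for optional keys keeps a missing column from raising `KeyError`, which matters once the data is no longer a model object with declared fields.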
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# Query LLM usage records
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
|
||||
# Find the index of the matching time interval
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# Accumulate the total cost
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# Accumulate the cost per model
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# Accumulate the cost per module
|
||||
request_type = record.request_type or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
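The hunk above adds each record's cost to the time interval it falls in. The line that computes `interval_index` lies outside the shown context, so the sketch below assumes it is the elapsed time divided by the interval length, rounded down. `bucket_costs` and its parameters are hypothetical names.

```python
from datetime import datetime, timedelta


def bucket_costs(records, start_time, interval_minutes, num_points):
    """Sum record costs into fixed-width time buckets, overall and per model."""
    total = [0.0] * num_points
    by_model = {}
    interval_seconds = interval_minutes * 60
    for record in records:
        time_diff = (record["timestamp"] - start_time).total_seconds()
        # Assumed bucketing rule: floor of elapsed time over interval length.
        interval_index = int(time_diff // interval_seconds)
        if 0 <= interval_index < num_points:
            cost = record.get("cost") or 0.0
            total[interval_index] += cost
            model_name = record.get("model_name") or "unknown"
            series = by_model.setdefault(model_name, [0.0] * num_points)
            series[interval_index] += cost
    return total, by_model
```

Floor division keeps records from before `start_time` at a negative index, so the range check drops them instead of folding them into the first bucket.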
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# Query message records
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
|
||||
# Find the index of the matching time interval
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# Determine the chat stream name
|
||||
chat_name = None
|
||||
if message.chat_info_group_id:
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id:
|
||||
chat_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -13,10 +13,12 @@ from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
@@ -41,9 +43,10 @@ class ImageManager:
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
# With SQLAlchemy, table creation has already been done during initialization
|
||||
logger.debug("使用SQLAlchemy进行表管理")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@@ -63,12 +66,13 @@ class ImageManager:
|
||||
Optional[str]: the description text, or None if it does not exist
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -82,16 +86,28 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # the record already exists, so update it
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
with get_db_session() as session:
|
||||
# Look up the existing record
|
||||
existing = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
|
||||
if existing:
|
||||
# Update the existing record
|
||||
existing.description = description
|
||||
existing.timestamp = current_timestamp
|
||||
else:
|
||||
# Create a new record
|
||||
new_desc = ImageDescriptions(
|
||||
image_description_hash=image_hash,
|
||||
type=description_type,
|
||||
description=description,
|
||||
timestamp=current_timestamp
|
||||
)
|
||||
session.add(new_desc)
|
||||
# session.commit() is called automatically by the context manager
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -214,19 +230,29 @@ class ImageManager:
|
||||
|
||||
# Save to the database (Images table), including the detailed description for a possible registration flow
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = detailed_description # save the detailed description
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
Images.create(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # save the detailed description
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(select(Images).where(
|
||||
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
|
||||
)).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # save the detailed description
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # save the detailed description
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
# session.commit() is called automatically by the context manager
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
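Both `_save_description_to_db` and the emoji branch above replace Peewee's `get_or_create` with an explicit select followed by an update or an insert. The same shape can be shown with the standard-library `sqlite3` module. The table layout and the `make_demo_db` and `save_description` helpers here are illustrative only and are not the project's schema or API.

```python
import sqlite3
import time


def make_demo_db():
    """Create an in-memory database with an illustrative table."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE image_descriptions (
            id INTEGER PRIMARY KEY,
            image_description_hash TEXT NOT NULL,
            type TEXT NOT NULL,
            description TEXT,
            timestamp REAL,
            UNIQUE (image_description_hash, type)
        )
        """
    )
    return conn


def save_description(conn, image_hash, description_type, description):
    """Update the row for (hash, type) if it exists, otherwise insert it."""
    now = time.time()
    with conn:  # commits on success, rolls back if an exception escapes
        row = conn.execute(
            "SELECT id FROM image_descriptions "
            "WHERE image_description_hash = ? AND type = ?",
            (image_hash, description_type),
        ).fetchone()
        if row:
            conn.execute(
                "UPDATE image_descriptions SET description = ?, timestamp = ? WHERE id = ?",
                (description, now, row[0]),
            )
        else:
            conn.execute(
                "INSERT INTO image_descriptions "
                "(image_description_hash, type, description, timestamp) VALUES (?, ?, ?, ?)",
                (image_hash, description_type, description, now),
            )
```

A select followed by a write is not atomic across concurrent writers, which is why the demo table carries a unique constraint on the hash and type pair as a backstop against duplicate rows.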
@@ -249,19 +275,19 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# First check whether the Images table already has a complete description
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
if existing_image:
|
||||
# Update the count
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
existing_image.save()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# Update the count
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
|
||||
# If a description already exists, return it directly
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
# If a description already exists, return it directly
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
@@ -300,10 +326,10 @@ class ImageManager:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -313,6 +339,8 @@ class ImageManager:
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -465,31 +493,32 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# Check for missing required fields; if any are missing, fill them in on the existing record
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# Check for missing required fields; if any are missing, fill them in on the existing record
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
existing_image.count += 1
|
||||
session.commit()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# Save the new image
|
||||
current_timestamp = time.time()
|
||||
@@ -503,7 +532,7 @@ class ImageManager:
|
||||
f.write(image_bytes)
|
||||
|
||||
# Save to the database
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -512,6 +541,8 @@ class ImageManager:
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
|
||||
# Start asynchronous VLM processing
|
||||
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
|
||||
@@ -536,60 +567,64 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
# Fetch the current image record
|
||||
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
|
||||
|
||||
# Fetch the current image record
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
# First check whether another image record with the same hash already has a description
|
||||
existing_with_description = session.execute(select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id
|
||||
)
|
||||
)).scalar()
|
||||
if existing_with_description:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
# Also save to the ImageDescriptions table as a fallback cache
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# First check whether another image record with the same hash already has a description
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
# Check the ImageDescriptions table for a cached description
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Get the image format
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# Build the prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# Get the VLM description
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# Update the database
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
# Also save to the ImageDescriptions table as a fallback cache
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# Check the ImageDescriptions table for a cached description
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
return
|
||||
# Save the description to the ImageDescriptions table as a fallback cache
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
# Get the image format
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# Build the prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# Get the VLM description
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# Update the database
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
|
||||
# Save the description to the ImageDescriptions table as a fallback cache
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
@@ -1,14 +1,103 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy-related imports
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_session
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
_sql_engine = None
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
# Compatibility: to avoid breaking existing code, keep the db variable but point it at SQLAlchemy
|
||||
class DatabaseProxy:
|
||||
"""Database proxy class providing a Peewee-to-SQLAlchemy compatibility interface."""
|
||||
|
||||
def __init__(self):
|
||||
self._engine = None
|
||||
self._session = None
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
"""Initialize the database connection."""
|
||||
return initialize_database_compat()
|
||||
|
||||
def connect(self, reuse_if_open=True):
|
||||
"""Connect to the database (compatibility method)."""
|
||||
try:
|
||||
self._engine = get_engine()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return False
|
||||
|
||||
def is_closed(self):
|
||||
"""Check whether the database is closed (compatibility method)."""
|
||||
return self._engine is None
|
||||
|
||||
def create_tables(self, models, safe=True):
|
||||
"""Create tables (compatibility method)."""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Base
|
||||
engine = get_engine()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建表失败: {e}")
|
||||
return False
|
||||
|
||||
def table_exists(self, model):
|
||||
"""Check whether a table exists (compatibility method)."""
|
||||
try:
|
||||
from sqlalchemy import inspect
|
||||
engine = get_engine()
|
||||
inspector = inspect(engine)
|
||||
table_name = getattr(model, '__tablename__', model.__name__.lower())
|
||||
return table_name in inspector.get_table_names()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def execute_sql(self, sql):
|
||||
"""Execute SQL (compatibility method)."""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
session = get_session()
|
||||
result = session.execute(text(sql))
|
||||
session.close()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行SQL失败: {e}")
|
||||
raise
|
||||
|
||||
def atomic(self):
|
||||
"""Transaction context manager (compatibility method)."""
|
||||
return SQLAlchemyTransaction()
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
"""SQLAlchemy transaction context manager."""
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
def __enter__(self):
|
||||
self.session = get_session()
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
|
||||
# Create the global database proxy instance
|
||||
db = DatabaseProxy()
|
||||
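`SQLAlchemyTransaction` commits on a clean exit, rolls back when an exception escapes the block, and then closes the session. The sketch below expresses the same contract with `contextlib.contextmanager` and a `try`/`finally`, so that the session is closed even if `commit()` itself raises. `FakeSession` is a hypothetical stand-in that only records the calls made on it; it is not SQLAlchemy's `Session`, and `transaction` is an illustrative name.

```python
from contextlib import contextmanager


class FakeSession:
    """Hypothetical session double that records which methods were called."""

    def __init__(self):
        self.calls = []

    def commit(self):
        self.calls.append("commit")

    def rollback(self):
        self.calls.append("rollback")

    def close(self):
        self.calls.append("close")


@contextmanager
def transaction(session_factory):
    """Yield a session; commit on success, roll back on error, always close."""
    session = session_factory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise  # let the caller see the original error
    finally:
        session.close()
```

Putting `close()` in the `finally` clause means a failing commit still releases the session, at the cost of also attempting a rollback after that failed commit.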
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
@@ -39,7 +128,7 @@ def __create_database_instance():
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Get the database connection instance, initialized lazily."""
|
||||
"""Get the MongoDB connection instance, initialized lazily."""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
@@ -47,6 +136,47 @@ def get_db():
|
||||
return _db
|
||||
|
||||
|
||||
def initialize_sql_database(database_config):
|
||||
"""
|
||||
Initialize the SQL database connection from the configuration (SQLAlchemy version).
|
||||
|
||||
Args:
|
||||
database_config: a DatabaseConfig object
|
||||
"""
|
||||
global _sql_engine
|
||||
|
||||
try:
|
||||
logger.info("使用SQLAlchemy初始化SQL数据库...")
|
||||
|
||||
# Log the database configuration
|
||||
if database_config.database_type == "mysql":
|
||||
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
|
||||
logger.info("MySQL数据库连接配置:")
|
||||
logger.info(f" 连接信息: {connection_info}")
|
||||
logger.info(f" 字符集: {database_config.mysql_charset}")
|
||||
else:
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if not os.path.isabs(database_config.sqlite_path):
|
||||
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
|
||||
else:
|
||||
db_path = database_config.sqlite_path
|
||||
logger.info("SQLite数据库连接配置:")
|
||||
logger.info(f" 数据库文件: {db_path}")
|
||||
|
||||
# Initialize via SQLAlchemy
|
||||
success = initialize_database_compat()
|
||||
if success:
|
||||
_sql_engine = get_engine()
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
else:
|
||||
logger.error("SQLAlchemy数据库初始化失败")
|
||||
|
||||
return _sql_engine
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化SQL数据库失败: {e}")
|
||||
return None
|
||||
|
||||
class DBWrapper:
|
||||
"""Database proxy class that keeps the interface compatible while loading lazily."""
|
||||
|
||||
@@ -57,26 +187,6 @@ class DBWrapper:
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# Global database access point
|
||||
# Global MongoDB database access point
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# Define the database file path
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
_DB_DIR = os.path.join(ROOT_PATH, "data")
|
||||
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
|
||||
|
||||
# Ensure the database directory exists
|
||||
os.makedirs(_DB_DIR, exist_ok=True)
|
||||
|
||||
# Global Peewee SQLite database access point
|
||||
db = SqliteDatabase(
|
||||
_DB_FILE,
|
||||
pragmas={
|
||||
"journal_mode": "wal",  # WAL mode improves concurrency
"cache_size": -64 * 1000,  # 64 MB cache
"foreign_keys": 1,
"ignore_check_constraints": 0,
"synchronous": 0,  # synchronous=OFF: faster writes at the cost of durability
"busy_timeout": 1000,  # 1-second timeout instead of 3 seconds
|
||||
},
|
||||
)
|
||||
|
||||
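The `SQLAlchemyTransaction` class in this file commits when the `with` body finishes cleanly and rolls back otherwise. The sketch below shows the same pattern with the standard-library `sqlite3` module standing in for a SQLAlchemy session, so it runs without extra dependencies. `Transaction`, `connect`, `DB_PATH`, and the table `t` are hypothetical names used only for illustration. The `try`/`finally` around the commit is an assumption about what is desirable, not something taken from the commit: it closes the connection even if `commit()` itself raises.

```python
import os
import sqlite3
import tempfile

# Hypothetical stand-in for get_session(): opens a connection to a scratch file.
DB_PATH = os.path.join(tempfile.mkdtemp(), "demo.db")


def connect() -> sqlite3.Connection:
    return sqlite3.connect(DB_PATH)


class Transaction:
    """Commit on success, roll back on error, always close."""

    def __init__(self) -> None:
        self.conn = None

    def __enter__(self) -> sqlite3.Connection:
        self.conn = connect()
        return self.conn

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        try:
            if exc_type is None:
                self.conn.commit()
            else:
                self.conn.rollback()
        finally:
            self.conn.close()
        return False  # never swallow the caller's exception


with Transaction() as conn:
    conn.execute("CREATE TABLE t (x INTEGER)")

with Transaction() as conn:
    conn.execute("INSERT INTO t VALUES (1)")  # committed on exit

try:
    with Transaction() as conn:
        conn.execute("INSERT INTO t VALUES (2)")
        raise RuntimeError("simulated failure")  # triggers the rollback
except RuntimeError:
    pass

check = connect()
rows = check.execute("SELECT x FROM t").fetchall()
check.close()
```

Only the first insert survives, because the second `with` body raised before the commit.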
420
src/common/database/sqlalchemy_database_api.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""SQLAlchemy database API module.
|
||||
|
||||
Provides SQLAlchemy-based database operations, replacing Peewee to resolve MySQL connection problems.
|
||||
Supports automatic reconnection, connection pool management, and better error handling.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
|
||||
from sqlalchemy import desc, asc, func, and_, or_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, get_session
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# Model mapping table, used to look up a model class by name
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
def get_db_session():
    """Database session context manager: commits on success, rolls back on error.

    Only session acquisition is retried. A generator-based context manager
    may yield once, so an error raised inside the ``with`` body is rolled
    back and re-raised instead of being retried.
    """
    max_retries = 3
    retry_delay = 1.0
    session = None

    for attempt in range(max_retries):
        try:
            session = get_session()
            break
        except (DisconnectionError, OperationalError) as e:
            logger.warning(f"Database connection error (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay * (attempt + 1))

    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
||||
|
||||
|
||||
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""Build query filter conditions."""
|
||||
conditions = []
|
||||
|
||||
for field_name, value in filters.items():
|
||||
if not hasattr(model_class, field_name):
|
||||
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
|
||||
continue
|
||||
|
||||
field = getattr(model_class, field_name)
|
||||
|
||||
if isinstance(value, dict):
|
||||
# Handle MongoDB-style operators
|
||||
for op, op_value in value.items():
|
||||
if op == "$gt":
|
||||
conditions.append(field > op_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < op_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= op_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= op_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != op_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(op_value))
|
||||
elif op == "$nin":
|
||||
conditions.append(~field.in_(op_value))
|
||||
else:
|
||||
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
|
||||
else:
|
||||
# Plain equality comparison
|
||||
conditions.append(field == value)
|
||||
|
||||
return conditions
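The operator translation in `build_filters` can also be expressed as a lookup table. The sketch below is a dependency-free illustration that applies the same MongoDB-style operators to plain dictionaries rather than SQLAlchemy columns. `OPERATORS`, `matches`, and the sample rows are hypothetical names and data, not part of the commit. Unknown fields and unknown operators are skipped, mirroring the warnings-and-continue behaviour above.

```python
import operator
from typing import Any, Callable, Dict

# Hypothetical lookup table: MongoDB-style operator -> comparison function.
OPERATORS: Dict[str, Callable[[Any, Any], bool]] = {
    "$gt": operator.gt,
    "$lt": operator.lt,
    "$gte": operator.ge,
    "$lte": operator.le,
    "$ne": operator.ne,
    "$in": lambda value, options: value in options,
    "$nin": lambda value, options: value not in options,
}


def matches(row: Dict[str, Any], filters: Dict[str, Any]) -> bool:
    """Return True when the row satisfies every filter (implicit AND)."""
    for field_name, expected in filters.items():
        if field_name not in row:
            continue  # unknown fields are skipped
        actual = row[field_name]
        if isinstance(expected, dict):
            for op, operand in expected.items():
                check = OPERATORS.get(op)
                if check is None:
                    continue  # unknown operators are ignored
                if not check(actual, operand):
                    return False
        elif actual != expected:
            return False
    return True


rows = [
    {"chat_id": "a", "time": 10.0},
    {"chat_id": "a", "time": 30.0},
    {"chat_id": "b", "time": 20.0},
]
selected = [r for r in rows if matches(r, {"chat_id": "a", "time": {"$gte": 20.0}})]
```

A table keeps each operator on one line and makes adding a new one a single-entry change, at the cost of a less explicit control flow than the `if`/`elif` chain.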
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Base],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], int, None]:
|
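||||
"""Execute a database query operation.

Args:
model_class: SQLAlchemy model class
data: dictionary of data used for create or update
query_type: query type ("get", "create", "update", "delete", "count")
filters: dictionary of filter conditions
limit: maximum number of results
order_by: fields to sort by; a '-' prefix means descending
single_result: whether to return only a single result

Returns:
"get": a list of row dictionaries, or a single dictionary (or None) when single_result is set
"create": the created row as a dictionary
"update" / "delete": the number of affected rows
"count": the number of matching rows
On error: an empty list for a multi-row "get", otherwise None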
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
|
||||
|
||||
with get_db_session() as session:
|
||||
if query_type == "get":
|
||||
query = session.query(model_class)
|
||||
|
||||
# Apply filter conditions
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# Apply ordering
|
||||
if order_by:
|
||||
for field_name in order_by:
|
||||
if field_name.startswith("-"):
|
||||
field_name = field_name[1:]
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(desc(getattr(model_class, field_name)))
|
||||
else:
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(asc(getattr(model_class, field_name)))
|
||||
|
||||
# Apply the limit
|
||||
if limit and limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Run the query
|
||||
results = query.all()
|
||||
|
||||
# Convert to dictionaries
|
||||
result_dicts = []
|
||||
for result in results:
|
||||
result_dict = {}
|
||||
for column in result.__table__.columns:
|
||||
result_dict[column.name] = getattr(result, column.name)
|
||||
result_dicts.append(result_dict)
|
||||
|
||||
if single_result:
|
||||
return result_dicts[0] if result_dicts else None
|
||||
return result_dicts
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("Creating a record requires the data argument")
|
||||
|
||||
# Create a new record
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush()  # Populate the auto-generated ID
|
||||
|
||||
# Convert to a dictionary and return
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("Updating records requires the data argument")
|
||||
|
||||
query = session.query(model_class)
|
||||
|
||||
# Apply filter conditions
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# Run the update
|
||||
affected_rows = query.update(data)
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "delete":
|
||||
query = session.query(model_class)
|
||||
|
||||
# Apply filter conditions
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# Run the delete
|
||||
affected_rows = query.delete()
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "count":
|
||||
query = session.query(func.count(model_class.id))
|
||||
|
||||
# Apply filter conditions
|
||||
if filters:
|
||||
base_query = session.query(model_class)
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
base_query = base_query.filter(and_(*conditions))
|
||||
query = session.query(func.count()).select_from(base_query.subquery())
|
||||
|
||||
return query.scalar()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# Return a suitable default for the query type
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Base],
|
||||
data: Dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
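||||
"""Save data to the database (create or update).

Args:
model_class: SQLAlchemy model class
data: dictionary of data to save
key_field: name of the field used to find an existing record
key_value: value of the field used to find an existing record

Returns:
The saved record as a dictionary, or None on failure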
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# If key_field and key_value are given, try to update an existing record
|
||||
if key_field and key_value is not None:
|
||||
if hasattr(model_class, key_field):
|
||||
existing_record = session.query(model_class).filter(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).first()
|
||||
|
||||
if existing_record:
|
||||
# Update the existing record
|
||||
for field, value in data.items():
|
||||
if hasattr(existing_record, field):
|
||||
setattr(existing_record, field, value)
|
||||
|
||||
session.flush()
|
||||
|
||||
# Convert to a dictionary and return
|
||||
result_dict = {}
|
||||
for column in existing_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(existing_record, column.name)
|
||||
return result_dict
|
||||
|
||||
# Create a new record
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush()
|
||||
|
||||
# Convert to a dictionary and return
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Base],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""Fetch records from the database.

Args:
model_class: SQLAlchemy model class
filters: filter conditions
limit: maximum number of results
order_by: field to sort by; a '-' prefix means descending
single_result: whether to return only a single result

Returns:
The matching record data, or None
|
||||
"""
|
||||
order_by_list = [order_by] if order_by else None
|
||||
return await db_query(
|
||||
model_class=model_class,
|
||||
query_type="get",
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
order_by=order_by_list,
|
||||
single_result=single_result
|
||||
)
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Store action information in the database.

Args:
chat_stream: chat stream object
action_build_into_prompt: whether to build this action into the prompt
action_prompt_display: display text of the action in the prompt
action_done: whether the action has completed
thinking_id: ID of the associated thinking record
action_data: dictionary of action data
action_name: name of the action

Returns:
The saved record data, or None
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
|
||||
# Build the action record data
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# Take chat information from chat_stream
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
})
|
||||
else:
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
|
||||
# Save the record
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# Compatibility helper to ease migration from Peewee
|
||||
def get_model_class(model_name: str) -> Optional[Type[Base]]:
|
||||
"""Return the model class for a model name."""
|
||||
return MODEL_MAPPING.get(model_name)
|
||||
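A generator-based context manager (`@contextmanager`) may yield only once, so a failed unit of work can only be retried if the work itself is passed in as a callable. The sketch below shows that shape with linear backoff. `run_with_retry`, `flaky`, and the use of `ConnectionError` as the retryable error are assumptions made for illustration; in this module the retryable errors would more plausibly be SQLAlchemy's `DisconnectionError` and `OperationalError`.

```python
import time
from typing import Callable, List, Tuple, Type, TypeVar

T = TypeVar("T")


def run_with_retry(
    operation: Callable[[], T],
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError,),
) -> T:
    """Call operation() until it succeeds or max_retries attempts have failed."""
    for attempt in range(max_retries):
        try:
            return operation()
        except retry_on:
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the last error
            time.sleep(retry_delay * (attempt + 1))  # linear backoff
    raise ValueError("max_retries must be at least 1")


# Hypothetical operation that fails twice before succeeding.
calls: List[int] = []


def flaky() -> str:
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("simulated dropped connection")
    return "ok"


result = run_with_retry(flaky, retry_delay=0.0)
```

Each attempt should open and close its own session inside `operation`, so that a retry never reuses a session left in a failed state.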
158
src/common/database/sqlalchemy_init.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""SQLAlchemy database initialization module.
|
||||
|
||||
Replaces the Peewee database initialization logic.
|
||||
Provides a unified database initialization interface.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_engine, get_session, initialize_database
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
|
||||
def initialize_sqlalchemy_database() -> bool:
|
||||
"""
|
||||
Initialize the SQLAlchemy database.
|
||||
Create all table structures.
|
||||
|
||||
Returns:
|
||||
bool: whether initialization succeeded
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化SQLAlchemy数据库...")
|
||||
|
||||
# Initialize the database engine and session factory
|
||||
engine, session_local = initialize_database()
|
||||
|
||||
if engine is None:
|
||||
logger.error("数据库引擎初始化失败")
|
||||
return False
|
||||
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def create_all_tables() -> bool:
|
||||
"""
|
||||
Create all database tables.
|
||||
|
||||
Returns:
|
||||
bool: whether creation succeeded
|
||||
"""
|
||||
try:
|
||||
logger.info("开始创建数据库表...")
|
||||
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
logger.error("无法获取数据库引擎")
|
||||
return False
|
||||
|
||||
# Create all tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
logger.info("数据库表创建成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"创建数据库表失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库表过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_database_connection() -> bool:
|
||||
"""
|
||||
Check whether the database connection works.
|
||||
|
||||
Returns:
|
||||
bool: whether the connection works
|
||||
"""
|
||||
try:
|
||||
session = get_session()
|
||||
if session is None:
|
||||
logger.error("无法获取数据库会话")
|
||||
return False
|
||||
|
||||
# Run a trivial query: obtaining the session proxy alone does not prove the connection works
from sqlalchemy import text
session.execute(text("SELECT 1"))
|
||||
|
||||
session.close()
|
||||
|
||||
logger.info("数据库连接检查通过")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"数据库连接检查失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_database_info() -> Optional[dict]:
|
||||
"""
|
||||
Return database information.
|
||||
|
||||
Returns:
|
||||
dict: database information, including engine details
|
||||
"""
|
||||
try:
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
return None
|
||||
|
||||
info = {
|
||||
'engine_name': engine.name,
|
||||
'driver': engine.driver,
|
||||
'url': engine.url.render_as_string(hide_password=True),  # Hide the password
|
||||
'pool_size': getattr(engine.pool, 'size', None),
|
||||
'max_overflow': getattr(engine.pool, 'max_overflow', None),
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
_database_initialized = False
|
||||
|
||||
def initialize_database_compat() -> bool:
|
||||
"""
|
||||
Compatibility database initialization function.
|
||||
Replaces the original Peewee initialization code.
|
||||
|
||||
Returns:
|
||||
bool: whether initialization succeeded
|
||||
"""
|
||||
global _database_initialized
|
||||
|
||||
if _database_initialized:
|
||||
return True
|
||||
|
||||
success = initialize_sqlalchemy_database()
|
||||
if success:
|
||||
success = create_all_tables()
|
||||
|
||||
if success:
|
||||
success = check_database_connection()
|
||||
|
||||
if success:
|
||||
_database_initialized = True
|
||||
|
||||
return success
|
||||
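Masking a password by plain string replacement is fragile: when the password is empty, `str.replace('', '***')` inserts the placeholder between every character of the URL. The sketch below masks the password by parsing the URL instead, using only `urllib.parse`. `mask_password` and the sample URLs are hypothetical and not part of the commit; where an engine is available, SQLAlchemy's `URL.render_as_string(hide_password=True)` is a more direct option.

```python
from urllib.parse import urlsplit, urlunsplit


def mask_password(url: str, placeholder: str = "***") -> str:
    """Return url with the password component replaced by placeholder."""
    parts = urlsplit(url)
    # netloc looks like "user:password@host:port"; split off the credentials.
    userinfo, at, hostport = parts.netloc.rpartition("@")
    user, colon, _password = userinfo.partition(":")
    if not at or not colon:
        return url  # no credentials, or no password: nothing to mask
    masked_netloc = f"{user}:{placeholder}@{hostport}"
    return urlunsplit(parts._replace(netloc=masked_netloc))


masked_mysql = mask_password("mysql+pymysql://bot:s3cret@localhost:3306/maibot")
masked_sqlite = mask_password("sqlite:///data/MaiBot.db")
```

The MySQL-style URL comes back with `***` in place of the password, while the SQLite URL, which carries no credentials, is returned unchanged.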
555
src/common/database/sqlalchemy_models.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""SQLAlchemy database model definitions.
|
||||
|
||||
Replaces the Peewee ORM with SQLAlchemy for better connection pool management and error recovery.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = get_logger("sqlalchemy_models")
|
||||
|
||||
# Create the declarative base class
|
||||
Base = declarative_base()
|
||||
|
||||
# Helper for MySQL-compatible column types
|
||||
def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
Return a suitable string column type for the configured database.
|
||||
MySQL requires a length-bounded VARCHAR for indexed columns; SQLite can use Text.
|
||||
"""
|
||||
if global_config.database.database_type == "mysql":
|
||||
return String(max_length, **kwargs)
|
||||
else:
|
||||
return Text(**kwargs)
|
||||
|
||||
class SessionProxy:
|
||||
"""Thread-safe Session proxy that manages the session lifecycle automatically."""
|
||||
|
||||
def __init__(self):
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_current_session(self):
|
||||
"""Return the current thread's session, creating one if needed."""
|
||||
if not hasattr(self._local, 'session') or self._local.session is None:
|
||||
_, SessionLocal = initialize_database()
|
||||
self._local.session = SessionLocal()
|
||||
return self._local.session
|
||||
|
||||
def _close_current_session(self):
|
||||
"""Close the current thread's session."""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
self._local.session.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._local.session = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Proxy all session attributes and methods."""
|
||||
session = self._get_current_session()
|
||||
attr = getattr(session, name)
|
||||
|
||||
# For methods, a few key ones need special handling
|
||||
if callable(attr):
|
||||
if name in ['commit', 'rollback']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = attr(*args, **kwargs)
|
||||
if name == 'commit':
|
||||
# Do not discard the session after a commit
|
||||
pass  # Keep the session alive
|
||||
return result
|
||||
except Exception as e:
|
||||
try:
|
||||
if session and hasattr(session, 'rollback'):
|
||||
session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# On error, discard the session so that a new one is created on next use
|
||||
self._close_current_session()
|
||||
raise
|
||||
return wrapper
|
||||
elif name == 'close':
|
||||
def wrapper(*args, **kwargs):
|
||||
result = attr(*args, **kwargs)
|
||||
self._close_current_session()
|
||||
return result
|
||||
return wrapper
|
||||
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return attr(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# For connection-related errors, recreate the session and retry once
|
||||
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
|
||||
logger.warning(f"Session问题,重新创建session: {e}")
|
||||
self._close_current_session()
|
||||
new_session = self._get_current_session()
|
||||
new_attr = getattr(new_session, name)
|
||||
return new_attr(*args, **kwargs)
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
return attr
|
||||
|
||||
def new_session(self):
|
||||
"""Force a new session (close the current one, then create another)."""
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
def ensure_fresh_session(self):
|
||||
"""Ensure a usable session, recreating it if the current one is broken."""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
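# Check that the session is still usable
from sqlalchemy import text  # local import; SQLAlchemy 2.0 rejects raw SQL strings not wrapped in text()
self._local.session.execute(text("SELECT 1"))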
|
||||
except Exception:
|
||||
# The session is broken; recreate it
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
# Create the global session proxy instance
|
||||
_global_session_proxy = SessionProxy()
|
||||
|
||||
def get_session():
|
||||
"""Return the thread-safe session proxy, which manages the lifecycle automatically."""
|
||||
return _global_session_proxy
|
||||
|
||||
|
||||
class ChatStreams(Base):
|
||||
"""Chat stream model."""
|
||||
__tablename__ = 'chat_streams'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||
create_time = Column(Float, nullable=False)
|
||||
group_platform = Column(Text, nullable=True)
|
||||
group_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
group_name = Column(Text, nullable=True)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_nickname = Column(Text, nullable=False)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_chatstreams_stream_id', 'stream_id'),
|
||||
Index('idx_chatstreams_user_id', 'user_id'),
|
||||
Index('idx_chatstreams_group_id', 'group_id'),
|
||||
)
|
||||
|
||||
|
||||
class LLMUsage(Base):
|
||||
"""LLM usage record model."""
|
||||
__tablename__ = 'llm_usage'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
request_type = Column(get_string_field(50), nullable=False, index=True)
|
||||
endpoint = Column(Text, nullable=False)
|
||||
prompt_tokens = Column(Integer, nullable=False)
|
||||
completion_tokens = Column(Integer, nullable=False)
|
||||
total_tokens = Column(Integer, nullable=False)
|
||||
cost = Column(Float, nullable=False)
|
||||
status = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_llmusage_model_name', 'model_name'),
|
||||
Index('idx_llmusage_user_id', 'user_id'),
|
||||
Index('idx_llmusage_request_type', 'request_type'),
|
||||
Index('idx_llmusage_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class Emoji(Base):
|
||||
"""Emoji (sticker) model."""
|
||||
__tablename__ = 'emoji'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||
format = Column(Text, nullable=False)
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
query_count = Column(Integer, nullable=False, default=0)
|
||||
is_registered = Column(Boolean, nullable=False, default=False)
|
||||
is_banned = Column(Boolean, nullable=False, default=False)
|
||||
emotion = Column(Text, nullable=True)
|
||||
record_time = Column(Float, nullable=False)
|
||||
register_time = Column(Float, nullable=True)
|
||||
usage_count = Column(Integer, nullable=False, default=0)
|
||||
last_used_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_emoji_full_path', 'full_path'),
|
||||
Index('idx_emoji_hash', 'emoji_hash'),
|
||||
)
|
||||
|
||||
|
||||
class Messages(Base):
|
||||
"""Message model."""
|
||||
__tablename__ = 'messages'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
reply_to = Column(Text, nullable=True)
|
||||
interest_value = Column(Float, nullable=True)
|
||||
is_mentioned = Column(Boolean, nullable=True)
|
||||
|
||||
# Fields flattened from chat_info
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
chat_info_user_platform = Column(Text, nullable=False)
|
||||
chat_info_user_id = Column(Text, nullable=False)
|
||||
chat_info_user_nickname = Column(Text, nullable=False)
|
||||
chat_info_user_cardname = Column(Text, nullable=True)
|
||||
chat_info_group_platform = Column(Text, nullable=True)
|
||||
chat_info_group_id = Column(Text, nullable=True)
|
||||
chat_info_group_name = Column(Text, nullable=True)
|
||||
chat_info_create_time = Column(Float, nullable=False)
|
||||
chat_info_last_active_time = Column(Float, nullable=False)
|
||||
|
||||
# Fields flattened from the top-level user_info
|
||||
user_platform = Column(Text, nullable=True)
|
||||
user_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
user_nickname = Column(Text, nullable=True)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
processed_plain_text = Column(Text, nullable=True)
|
||||
display_message = Column(Text, nullable=True)
|
||||
memorized_times = Column(Integer, nullable=False, default=0)
|
||||
priority_mode = Column(Text, nullable=True)
|
||||
priority_info = Column(Text, nullable=True)
|
||||
additional_config = Column(Text, nullable=True)
|
||||
is_emoji = Column(Boolean, nullable=False, default=False)
|
||||
is_picid = Column(Boolean, nullable=False, default=False)
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_messages_message_id', 'message_id'),
|
||||
Index('idx_messages_chat_id', 'chat_id'),
|
||||
Index('idx_messages_time', 'time'),
|
||||
Index('idx_messages_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class ActionRecords(Base):
|
||||
"""Action record model."""
|
||||
__tablename__ = 'action_records'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
action_name = Column(Text, nullable=False)
|
||||
action_data = Column(Text, nullable=False)
|
||||
action_done = Column(Boolean, nullable=False, default=False)
|
||||
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
||||
action_prompt_display = Column(Text, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_actionrecords_action_id', 'action_id'),
|
||||
Index('idx_actionrecords_chat_id', 'chat_id'),
|
||||
Index('idx_actionrecords_time', 'time'),
|
||||
)
|
||||
|
||||
|
||||
class Images(Base):
|
||||
"""Image information model."""
|
||||
__tablename__ = 'images'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
image_id = Column(Text, nullable=False, default="")
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
path = Column(get_string_field(500), nullable=False, unique=True)
|
||||
count = Column(Integer, nullable=False, default=1)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
type = Column(Text, nullable=False)
|
||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_images_emoji_hash', 'emoji_hash'),
|
||||
Index('idx_images_path', 'path'),
|
||||
)
|
||||
|
||||
|
||||
class ImageDescriptions(Base):
|
||||
"""Image description model."""
|
||||
__tablename__ = 'image_descriptions'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(Text, nullable=False)
|
||||
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_imagedesc_hash', 'image_description_hash'),
|
||||
)
|
||||
|
||||
|
||||
class OnlineTime(Base):
|
||||
"""Online time record model."""
|
||||
__tablename__ = 'online_time'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
|
||||
duration = Column(Integer, nullable=False)
|
||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class PersonInfo(Base):
|
||||
"""Person information model."""
|
||||
__tablename__ = 'person_info'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||
person_name = Column(Text, nullable=True)
|
||||
name_reason = Column(Text, nullable=True)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
nickname = Column(Text, nullable=True)
|
||||
impression = Column(Text, nullable=True)
|
||||
short_impression = Column(Text, nullable=True)
|
||||
points = Column(Text, nullable=True)
|
||||
forgotten_points = Column(Text, nullable=True)
|
||||
info_list = Column(Text, nullable=True)
|
||||
know_times = Column(Float, nullable=True)
|
||||
know_since = Column(Float, nullable=True)
|
||||
last_know = Column(Float, nullable=True)
|
||||
attitude = Column(Integer, nullable=True, default=50)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_personinfo_person_id', 'person_id'),
|
||||
Index('idx_personinfo_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""Memory model."""
|
||||
__tablename__ = 'memory'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_id = Column(Text, nullable=True)
|
||||
memory_text = Column(Text, nullable=True)
|
||||
keywords = Column(Text, nullable=True)
|
||||
create_time = Column(Float, nullable=True)
|
||||
last_view_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_memory_memory_id', 'memory_id'),
|
||||
)
|
||||
|
||||
|
||||
class Expression(Base):
|
||||
"""Expression style model."""
|
||||
__tablename__ = 'expression'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
situation = Column(Text, nullable=False)
|
||||
style = Column(Text, nullable=False)
|
||||
count = Column(Float, nullable=False)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
type = Column(Text, nullable=False)
|
||||
create_date = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_expression_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class ThinkingLog(Base):
|
||||
"""思考日志模型"""
|
||||
__tablename__ = 'thinking_logs'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
trigger_text = Column(Text, nullable=True)
|
||||
response_text = Column(Text, nullable=True)
|
||||
trigger_info_json = Column(Text, nullable=True)
|
||||
response_info_json = Column(Text, nullable=True)
|
||||
timing_results_json = Column(Text, nullable=True)
|
||||
chat_history_json = Column(Text, nullable=True)
|
||||
chat_history_in_thinking_json = Column(Text, nullable=True)
|
||||
chat_history_after_response_json = Column(Text, nullable=True)
|
||||
heartflow_data_json = Column(Text, nullable=True)
|
||||
reasoning_data_json = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_thinkinglog_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class GraphNodes(Base):
|
||||
"""记忆图节点模型"""
|
||||
__tablename__ = 'graph_nodes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||
memory_items = Column(Text, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphnodes_concept', 'concept'),
|
||||
)
|
||||
|
||||
|
||||
class GraphEdges(Base):
|
||||
"""记忆图边模型"""
|
||||
__tablename__ = 'graph_edges'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
source = Column(get_string_field(255), nullable=False, index=True)
|
||||
target = Column(get_string_field(255), nullable=False, index=True)
|
||||
strength = Column(Integer, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphedges_source', 'source'),
|
||||
Index('idx_graphedges_target', 'target'),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def get_database_url():
    """Build the database connection URL."""
    config = global_config.database

    if config.database_type == "mysql":
        # URL-encode the username and password so special characters are handled
        from urllib.parse import quote_plus
        encoded_user = quote_plus(config.mysql_user)
        encoded_password = quote_plus(config.mysql_password)

        return (
            f"mysql+pymysql://{encoded_user}:{encoded_password}"
            f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
            f"?charset={config.mysql_charset}"
        )
    else:  # SQLite
        # A relative path is resolved against the project root
        if not os.path.isabs(config.sqlite_path):
            ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_path = os.path.join(ROOT_PATH, config.sqlite_path)
        else:
            db_path = config.sqlite_path

        # Make sure the database directory exists
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        return f"sqlite:///{db_path}"
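
The MySQL branch of get_database_url percent-encodes the credentials before placing them in the URL. The sketch below isolates that step so the effect of quote_plus is visible. The function name build_mysql_url and the sample password are assumptions made for illustration and are not part of the commit.

```python
from urllib.parse import quote_plus


def build_mysql_url(user: str, password: str, host: str, port: int,
                    database: str, charset: str = "utf8mb4") -> str:
    """Assemble a SQLAlchemy-style MySQL URL with percent-encoded credentials.

    Hypothetical helper: it mirrors the string layout used in the diff,
    but takes plain arguments instead of a config object.
    """
    # quote_plus escapes characters such as '@', ':' and '/' that would
    # otherwise be read as URL delimiters.
    return (
        f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}?charset={charset}"
    )


if __name__ == "__main__":
    # The password deliberately contains '@', ':' and '/'.
    print(build_mysql_url("root", "p@ss:w/rd", "localhost", 3306, "maibot"))
```

Without the encoding step, a password containing '@' would probably split the URL at the wrong place, which seems to be the case the original comment has in mind.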
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""初始化数据库引擎和会话"""
|
||||
global _engine, _SessionLocal
|
||||
|
||||
if _engine is not None:
|
||||
return _engine, _SessionLocal
|
||||
|
||||
database_url = get_database_url()
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs = {
|
||||
'echo': False, # 生产环境关闭SQL日志
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# MySQL连接池配置
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': config.connection_pool_size,
|
||||
'max_overflow': config.connection_pool_size * 2,
|
||||
'pool_timeout': config.connection_timeout,
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'autocommit': config.mysql_autocommit,
|
||||
'charset': config.mysql_charset,
|
||||
'connect_timeout': config.connection_timeout,
|
||||
'read_timeout': 30,
|
||||
'write_timeout': 30,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# SQLite配置 - 添加连接池设置以避免连接耗尽
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': 20, # 增加池大小
|
||||
'max_overflow': 30, # 增加溢出连接数
|
||||
'pool_timeout': 60, # 增加超时时间
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'check_same_thread': False,
|
||||
'timeout': 30,
|
||||
}
|
||||
})
|
||||
|
||||
_engine = create_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=_engine)
|
||||
|
||||
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
def get_db_session():
    """Database session context manager. Prefer this over get_session()."""
    session = None
    try:
        _, SessionLocal = initialize_database()
        session = SessionLocal()
        yield session
        session.commit()
    except Exception:
        if session:
            session.rollback()
        raise
    finally:
        if session:
            session.close()


def get_engine():
    """Return the database engine, initializing it on first use."""
    engine, _ = initialize_database()
    return engine
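
The get_db_session context manager follows a commit-on-success, rollback-on-error, always-close pattern. The sketch below reproduces that control flow with a stand-in session so it runs without a database. FakeSession, session_scope, run_ok and run_failing are hypothetical names introduced only for this illustration.

```python
from contextlib import contextmanager


class FakeSession:
    """Hypothetical stand-in that records the calls a real session would receive."""

    def __init__(self):
        self.calls = []

    def commit(self):
        self.calls.append("commit")

    def rollback(self):
        self.calls.append("rollback")

    def close(self):
        self.calls.append("close")


@contextmanager
def session_scope(factory):
    """Commit on success, roll back on error, and always close."""
    session = factory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def run_ok():
    """Use the scope without raising and return the recorded calls."""
    with session_scope(FakeSession) as s:
        pass
    return s.calls


def run_failing():
    """Raise inside the scope and return the recorded calls."""
    seen = []
    try:
        with session_scope(FakeSession) as s:
            seen.append(s)
            raise ValueError("boom")
    except ValueError:
        pass
    return seen[0].calls
```

Callers that go through such a scope do not need their own commit, rollback, or close calls, which is presumably why the docstring recommends it over get_session().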
|
||||
@@ -373,6 +373,7 @@ MODULE_COLORS = {
|
||||
"base_command": "\033[38;5;208m", # 橙色
|
||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||
"stream_api": "\033[38;5;220m", # 黄色
|
"plugin_hot_reload": "\033[38;5;226m", # bright yellow
|
||||
"config_api": "\033[38;5;226m", # 亮黄色
|
||||
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"action_apis": "\033[38;5;118m", # 绿色
|
||||
@@ -406,6 +407,7 @@ MODULE_COLORS = {
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
"database_model": "\033[38;5;94m", # 橙褐色
|
"database": "\033[38;5;46m", # bright green
|
||||
"maim_message": "\033[38;5;140m", # 紫褐色
|
||||
# 日志系统
|
||||
"logger": "\033[38;5;8m", # 深灰色
|
||||
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
"expressor": "表达方式",
|
||||
"plugin_hot_reload": "热重载",
|
||||
"database": "数据库",
|
||||
"database_model": "数据库",
|
||||
"mood": "情绪",
|
||||
"memory": "记忆",
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
from typing import List, Optional, Any, Dict
|
||||
from sqlalchemy import not_, select, func
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.common.database.database_model import Messages
|
||||
# from src.common.database.database_model import Messages
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return model_instance.__data__
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
@@ -38,7 +44,8 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -77,42 +84,57 @@ def find_messages(
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
query = query.where(not_(Messages.is_command))
|
||||
|
||||
if limit > 0:
|
||||
# 确保limit是正整数
|
||||
limit = max(1, int(limit))
|
||||
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
peewee_results = list(query)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
try:
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
logger.error(f"执行latest查询失败: {e}")
|
||||
results = []
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
peewee_sort_terms = []
|
||||
sort_terms = []
|
||||
for field_name, direction in sort:
|
||||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
sort_terms.append(field.asc())
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
sort_terms.append(field.desc())
|
||||
else:
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
if peewee_sort_terms:
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
return [_model_to_dict(msg) for msg in peewee_results]
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = query.count()
|
||||
return count
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
||||
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
DatabaseConfig,
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
ExpressionConfig,
|
||||
@@ -340,6 +341,7 @@ class Config(ConfigBase):
|
||||
|
||||
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
|
||||
|
||||
database: DatabaseConfig
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
relationship: RelationshipConfig
|
||||
@@ -466,4 +468,25 @@ update_model_config()
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
||||
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
|
||||
|
||||
# 初始化数据库连接
|
||||
logger.info("正在初始化数据库连接...")
|
||||
from src.common.database.database import initialize_sql_database
|
||||
try:
|
||||
initialize_sql_database(global_config.database)
|
||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
# 初始化数据库表结构
|
||||
logger.info("正在初始化数据库表结构...")
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
try:
|
||||
init_db()
|
||||
logger.info("数据库表结构初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库表结构初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
||||
@@ -13,6 +13,65 @@ from src.config.config_base import ConfigBase
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig(ConfigBase):
|
||||
"""数据库配置类"""
|
||||
|
||||
database_type: Literal["sqlite", "mysql"] = "sqlite"
|
||||
"""数据库类型,支持 sqlite 或 mysql"""
|
||||
|
||||
# SQLite 配置
|
||||
sqlite_path: str = "data/MaiBot.db"
|
||||
"""SQLite数据库文件路径"""
|
||||
|
||||
# MySQL 配置
|
||||
mysql_host: str = "localhost"
|
||||
"""MySQL服务器地址"""
|
||||
|
||||
mysql_port: int = 3306
|
||||
"""MySQL服务器端口"""
|
||||
|
||||
mysql_database: str = "maibot"
|
||||
"""MySQL数据库名"""
|
||||
|
||||
mysql_user: str = "root"
|
||||
"""MySQL用户名"""
|
||||
|
||||
mysql_password: str = ""
|
||||
"""MySQL密码"""
|
||||
|
||||
mysql_charset: str = "utf8mb4"
|
||||
"""MySQL字符集"""
|
||||
|
||||
mysql_unix_socket: str = ""
|
||||
"""MySQL Unix套接字路径(可选,用于本地连接,优先于host/port)"""
|
||||
|
||||
# MySQL SSL 配置
|
||||
mysql_ssl_mode: str = "DISABLED"
|
||||
"""SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""
|
||||
|
||||
mysql_ssl_ca: str = ""
|
||||
"""SSL CA证书路径"""
|
||||
|
||||
mysql_ssl_cert: str = ""
|
||||
"""SSL客户端证书路径"""
|
||||
|
||||
mysql_ssl_key: str = ""
|
||||
"""SSL客户端密钥路径"""
|
||||
|
||||
# MySQL 高级配置
|
||||
mysql_autocommit: bool = True
|
||||
"""自动提交事务"""
|
||||
|
||||
mysql_sql_mode: str = "TRADITIONAL"
|
||||
"""SQL模式"""
|
||||
|
||||
# 连接池配置
|
||||
connection_pool_size: int = 10
|
||||
"""连接池大小(仅MySQL有效)"""
|
||||
|
||||
connection_timeout: int = 10
|
||||
"""连接超时时间(秒)"""
|
||||
|
||||
@dataclass
|
||||
class BotConfig(ConfigBase):
|
||||
@@ -72,6 +131,19 @@ class ChatConfig(ConfigBase):
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
|
||||
replyer_random_probability: float = 0.5
|
||||
"""
|
||||
Probability (between 0 and 1) of choosing the reasoning model when replying.
The normal model is chosen with probability 1 - replyer_random_probability.
|
||||
"""
|
||||
|
||||
thinking_timeout: int = 40
|
||||
"""麦麦最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)"""
|
||||
|
||||
talk_frequency: float = 1
|
||||
"""回复频率阈值"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
@@ -93,17 +165,17 @@ class ChatConfig(ConfigBase):
|
||||
"""
|
||||
统一的活跃度和专注度配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
|
||||
|
||||
全局配置示例:
|
||||
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||
|
||||
|
||||
特定聊天流配置示例:
|
||||
[
|
||||
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||
]
|
||||
|
||||
|
||||
说明:
|
||||
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||
@@ -155,72 +227,11 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
# 检查全局时段配置(第一个元素为空字符串的配置)
|
||||
global_frequency = self._get_global_frequency()
|
||||
return self.talk_frequency if global_frequency is None else global_frequency
|
||||
|
||||
def _get_global_focus_value(self) -> Optional[float]:
|
||||
"""
|
||||
获取全局默认专注度配置
|
||||
if global_frequency is not None:
|
||||
return global_frequency
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.focus_value_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
# 检查是否为全局默认配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
return self._get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
def _get_time_based_focus_value(self, time_focus_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的专注度
|
||||
|
||||
Args:
|
||||
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
current_hour, current_minute = map(int, current_time.split(":"))
|
||||
current_minutes = current_hour * 60 + current_minute
|
||||
|
||||
# 解析时间专注度配置
|
||||
time_focus_pairs = []
|
||||
for time_focus_str in time_focus_list:
|
||||
try:
|
||||
time_str, focus_str = time_focus_str.split(",")
|
||||
hour, minute = map(int, time_str.split(":"))
|
||||
focus_value = float(focus_str)
|
||||
minutes = hour * 60 + minute
|
||||
time_focus_pairs.append((minutes, focus_value))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if not time_focus_pairs:
|
||||
return None
|
||||
|
||||
# 按时间排序
|
||||
time_focus_pairs.sort(key=lambda x: x[0])
|
||||
|
||||
# 查找当前时间对应的专注度
|
||||
current_focus_value = None
|
||||
for minutes, focus_value in time_focus_pairs:
|
||||
if current_minutes >= minutes:
|
||||
current_focus_value = focus_value
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
|
||||
if current_focus_value is None and time_focus_pairs:
|
||||
current_focus_value = time_focus_pairs[-1][1]
|
||||
|
||||
return current_focus_value
|
||||
# 如果都没有匹配,返回默认值
|
||||
return self.talk_frequency
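
The time-segment lookup in this hunk parses "HH:MM,value" strings, picks the latest segment that has already started, and wraps to the last segment of the day when the current time is earlier than every configured start. The sketch below is a standalone version of that logic. The name pick_value and the explicit now_minutes argument (instead of reading the clock) are assumptions made so the behaviour is easy to check.

```python
from typing import List, Optional


def pick_value(entries: List[str], now_minutes: int) -> Optional[float]:
    """Return the value of the segment active at now_minutes (minutes since midnight)."""
    pairs = []
    for entry in entries:
        try:
            time_str, value_str = entry.split(",")
            hour, minute = map(int, time_str.split(":"))
            pairs.append((hour * 60 + minute, float(value_str)))
        except ValueError:
            # Malformed entries are skipped, as in the original loop.
            continue

    if not pairs:
        return None

    pairs.sort(key=lambda pair: pair[0])

    current = None
    for start, value in pairs:
        if now_minutes >= start:
            current = value
        else:
            break

    # Earlier than every configured start: reuse the last segment,
    # which is treated as carrying over from the previous day.
    return pairs[-1][1] if current is None else current
```

With the global example from the docstring, `["8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]`, a call at 13:00 (`now_minutes=780`) falls into the segment that starts at 12:00.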
|
||||
|
||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
@@ -395,6 +406,14 @@ class MessageReceiveConfig(ConfigBase):
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalChatConfig(ConfigBase):
|
||||
"""普通聊天配置类"""
|
||||
|
||||
willing_mode: str = "classical"
|
||||
"""意愿模式"""
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
@@ -403,14 +422,14 @@ class ExpressionConfig(ConfigBase):
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
]
|
||||
|
||||
|
||||
说明:
|
||||
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
@@ -475,14 +494,14 @@ class ExpressionConfig(ConfigBase):
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_expression_config is not None:
|
||||
return specific_expression_config
|
||||
specific_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_config is not None:
|
||||
return specific_config
|
||||
|
||||
# 检查全局配置(第一个元素为空字符串的配置)
|
||||
global_expression_config = self._get_global_config()
|
||||
if global_expression_config is not None:
|
||||
return global_expression_config
|
||||
global_config = self._get_global_config()
|
||||
if global_config is not None:
|
||||
return global_config
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
@@ -518,10 +537,10 @@ class ExpressionConfig(ConfigBase):
|
||||
|
||||
# 解析配置
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity: float = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
@@ -541,10 +560,10 @@ class ExpressionConfig(ConfigBase):
|
||||
# 检查是否为全局配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
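
The expression configuration is resolved in three steps: a stream-specific rule first, then the global rule whose first element is an empty string, then a hard-coded default. The sketch below shows that order on plain lists. It is a simplification under stated assumptions: resolve_expression_config is a hypothetical name, and rules are matched on the raw "platform:id:type" string, whereas the real code may compare a derived stream ID.

```python
from typing import List, Optional, Tuple


def _parse_rule(item: list) -> Optional[Tuple[bool, bool, float]]:
    """Turn [id, "enable"/"disable", "enable"/"disable", intensity] into a tuple."""
    try:
        use_expression = item[1].lower() == "enable"
        enable_learning = item[2].lower() == "enable"
        learning_intensity = float(item[3])
        return use_expression, enable_learning, learning_intensity
    except (ValueError, IndexError, AttributeError):
        return None


def resolve_expression_config(
    rules: List[list], chat_stream_id: Optional[str]
) -> Tuple[bool, bool, float]:
    """Stream-specific rule first, then the global rule, then the default."""
    for wanted in (chat_stream_id, ""):
        if wanted is None:
            continue
        for item in rules:
            if item and item[0] == wanted:
                parsed = _parse_rule(item)
                if parsed is not None:
                    return parsed
    # Same fallback values as the `return True, True, 300` line in the diff.
    return True, True, 300.0
```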
|
||||
|
||||
@@ -558,7 +577,6 @@ class ToolConfig(ConfigBase):
|
||||
enable_tool: bool = False
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -703,7 +721,6 @@ class KeywordReactionConfig(ConfigBase):
|
||||
if not isinstance(rule, KeywordRuleConfig):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomPromptConfig(ConfigBase):
|
||||
"""自定义提示词配置类"""
|
||||
@@ -852,3 +869,4 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, get_session
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
from .model_client.base_client import UsageRecord
|
||||
@@ -143,16 +142,9 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
|
||||
class LLMUsageRecorder:
|
||||
"""
|
||||
LLM使用情况记录器
|
||||
LLM使用情况记录器(SQLAlchemy版本)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
# logger.debug("LLMUsage 表已初始化/确保存在。")
|
||||
except Exception as e:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
|
||||
@@ -160,9 +152,13 @@ class LLMUsageRecorder:
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
total_cost = round(input_cost + output_cost, 6)
|
||||
|
||||
session = None
|
||||
try:
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
# 使用 SQLAlchemy 会话创建记录
|
||||
session = get_session()
|
||||
|
||||
usage_record = LLMUsage(
|
||||
model_name=model_info.model_identifier,
|
||||
model_assign_name=model_info.name,
|
||||
model_api_provider=model_info.api_provider,
|
||||
@@ -175,8 +171,12 @@ class LLMUsageRecorder:
|
||||
cost=total_cost or 0.0,
|
||||
time_cost = round(time_cost or 0.0, 3),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
|
||||
)
|
||||
|
||||
session.add(usage_record)
|
||||
session.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
@@ -184,6 +184,11 @@ class LLMUsageRecorder:
|
||||
f"总计: {model_usage.total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
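
The cost computed in record_usage_to_database divides token counts by 1,000,000 before multiplying by the price, which suggests that price_in and price_out are quoted per million tokens. The sketch below isolates that arithmetic. The function name estimate_cost and the sample prices are assumptions for illustration only.

```python
def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  price_in: float, price_out: float) -> float:
    """Return the total cost, assuming prices are per one million tokens."""
    input_cost = (prompt_tokens / 1_000_000) * price_in
    output_cost = (completion_tokens / 1_000_000) * price_out
    # Rounded to six decimals, as in the diff.
    return round(input_cost + output_cost, 6)


if __name__ == "__main__":
    # One million prompt tokens at 2.0 plus half a million completion tokens at 8.0.
    print(estimate_cost(1_000_000, 500_000, 2.0, 8.0))
```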
|
||||
40
src/main.py
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
from maim_message import MessageServer
|
||||
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
@@ -17,8 +19,9 @@ from rich.traceback import install
|
||||
from src.migrate_helper.migrate import check_and_run_migrations
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
# 导入新的插件管理器和热重载管理器
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
@@ -48,6 +51,28 @@ class MainSystem:
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
|
||||
# 设置信号处理器用于优雅退出
|
||||
self._setup_signal_handlers()
|
||||
|
||||
def _setup_signal_handlers(self):
|
||||
"""设置信号处理器"""
|
||||
def signal_handler(signum, frame):
|
||||
logger.info("收到退出信号,正在优雅关闭系统...")
|
||||
self._cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理资源"""
|
||||
try:
|
||||
# 停止插件热重载系统
|
||||
hot_reload_manager.stop()
|
||||
logger.info("🛑 插件热重载系统已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止热重载系统时出错: {e}")
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||
@@ -58,14 +83,7 @@ class MainSystem:
|
||||
logger.info(f"""
|
||||
--------------------------------
|
||||
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
|
||||
--------------------------------
|
||||
如果想要自定义{global_config.bot.nickname}的功能,请查阅:https://docs.mai-mai.org/manual/usage/
|
||||
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
|
||||
--------------------------------
|
||||
如果你想要编写或了解插件相关内容,请访问开发文档https://docs.mai-mai.org/develop/
|
||||
--------------------------------
|
||||
如果你需要查阅模型的消耗以及麦麦的统计数据,请访问根目录的maibot_statistics.html文件
|
||||
""")
|
||||
--------------------------------""")
|
||||
|
||||
async def _init_components(self):
|
||||
"""初始化其他组件"""
|
||||
@@ -87,6 +105,10 @@ class MainSystem:
|
||||
# 加载所有actions,包括默认的和插件的
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
# 启动插件热重载系统
|
||||
|
||||
hot_reload_manager.start()
|
||||
|
||||
# 初始化表情管理器
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
0
src/person_info/fix_session.py
Normal file
@@ -2,17 +2,17 @@ import hashlib
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union
|
||||
|
||||
from typing import Any, Callable, Dict, Union, Optional
|
||||
from sqlalchemy import select
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
session = get_session()
|
||||
|
||||
logger = get_logger("person_info")
|
||||
|
||||
@@ -380,36 +380,282 @@ class Person:
|
||||
|
||||
return relation_info
|
||||
|
||||
# 统一的会话管理函数
|
||||
def with_session(func):
|
||||
"""装饰器:为函数自动注入session参数"""
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
|
||||
return await func(session, *args, **kwargs)
|
||||
return async_wrapper
|
||||
else:
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
|
||||
return func(session, *args, **kwargs)
|
||||
return sync_wrapper
|
||||
|
||||
# 全局会话获取函数,用于替换所有裸露的session使用
|
||||
def _get_session():
|
||||
"""获取数据库会话的统一函数"""
|
||||
return get_session()
|
||||
|
||||
|
||||
class PersonInfoManager:
|
||||
def __init__(self):
|
||||
|
||||
"""初始化PersonInfoManager"""
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 设置连接池参数
|
||||
# 设置连接池参数(仅对SQLite有效)
|
||||
if hasattr(db, "execute_sql"):
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
# 检查数据库类型,只对SQLite执行PRAGMA语句
|
||||
if global_config.database.database_type == "sqlite":
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
db.create_tables([PersonInfo], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有person_name
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_null(False)
|
||||
):
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
# 在这里获取会话
|
||||
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_not(None)
|
||||
)).fetchall():
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
|
||||
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
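
get_person_id derives a stable identifier by hashing "platform_userid" with MD5, after stripping a prefix from platform names that contain a hyphen. The sketch below is a standalone copy of that derivation so the normalisation can be checked. The name make_person_id and the sample platform strings are assumptions, not identifiers from the commit.

```python
import hashlib
from typing import Union


def make_person_id(platform: str, user_id: Union[int, str]) -> str:
    """Hash "<platform>_<user_id>" into a 32-character hex identifier."""
    # A name such as "adapter-qq" is reduced to the part after the first hyphen.
    if "-" in platform:
        platform = platform.split("-")[1]
    key = "_".join([platform, str(user_id)])
    return hashlib.md5(key.encode()).hexdigest()
```

Because the user ID is converted with str(), an integer and its string form map to the same identifier. MD5 is used here only as a compact fingerprint, not for security.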
|
||||
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_check_known_sync(p_id: str):
|
||||
# 在需要时获取会话
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_check_known_sync, person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||
"""创建一个项"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||||
|
||||
def _db_create_sync(p_data: dict):
|
||||
try:
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
|
||||
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
|
||||
"""安全地创建用户信息,处理竞态条件"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
def _db_safe_create_sync(p_data: dict):
|
||||
try:
|
||||
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
|
||||
if existing:
|
||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||
return True
|
||||
|
||||
# 尝试创建
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||||
return True # 其他协程已创建,视为成功
|
||||
else:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_safe_create_sync, final_data)
|
||||
|
||||
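`_safe_create_person_info` recognises a concurrent insert by matching the SQLite message text `"UNIQUE constraint failed"`. A possibly more portable alternative is to catch the driver's integrity exception type instead. The sketch below illustrates the idea with the standard-library `sqlite3` module; the `person_info` table layout and the `safe_create` helper are hypothetical, and in SQLAlchemy the corresponding exception would be `sqlalchemy.exc.IntegrityError`.

```python
import sqlite3


def safe_create(conn: sqlite3.Connection, person_id: str, person_name: str) -> bool:
    """Insert a row; report a duplicate primary key as 'already exists'."""
    try:
        # The connection context manager commits on success and rolls back on error.
        with conn:
            conn.execute(
                "INSERT INTO person_info (person_id, person_name) VALUES (?, ?)",
                (person_id, person_name),
            )
        return True  # this call created the row
    except sqlite3.IntegrityError:
        return False  # another caller created it first


# Hypothetical two-column table, used only for this illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE person_info (person_id TEXT PRIMARY KEY, person_name TEXT)")

first = safe_create(conn, "abc", "Alice")
second = safe_create(conn, "abc", "Alice again")
row_count = conn.execute("SELECT COUNT(*) FROM person_info").fetchone()[0]
```

Catching the exception type avoids depending on backend-specific error wording, which matters if the project is configured for a database other than SQLite.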
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义的字段。")
|
||||
return
|
||||
|
||||
processed_value = value
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(value, (list, dict)):
|
||||
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = json.dumps([], ensure_ascii=False, indent=None)
|
||||
|
||||
def _db_update_sync(p_id: str, f_name: str, val_to_set):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
query_time = time.time()
|
||||
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
session.commit()
|
||||
save_time = time.time()
|
||||
|
||||
total_time = save_time - start_time
|
||||
if total_time > 0.5: # 如果超过500ms就记录日志
|
||||
logger.warning(
|
||||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
|
||||
)
|
||||
|
||||
return True, False # Found and updated, no creation needed
|
||||
else:
|
||||
total_time = time.time() - start_time
|
||||
if total_time > 0.5:
|
||||
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
|
||||
return False, True # Not found, needs creation
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||||
raise
|
||||
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
|
||||
|
||||
if needs_creation:
|
||||
logger.info(f"{person_id} 不存在,将新建。")
|
||||
creation_data = data if data is not None else {}
|
||||
# Ensure platform and user_id are present for context if available from 'data'
|
||||
# but primarily, set the field that triggered the update.
|
||||
# The create_person_info will handle defaults and serialization.
|
||||
creation_data[field_name] = value # Pass original value to create_person_info
|
||||
|
||||
# Ensure platform and user_id are in creation_data if available,
|
||||
# otherwise create_person_info will use defaults.
|
||||
if data and "platform" in data:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
# 使用安全的创建方法,处理竞态条件
|
||||
await self._safe_create_person_info(person_id, creation_data)
|
||||
|
||||
@staticmethod
|
||||
async def has_one_field(person_id: str, field_name: str):
|
||||
"""判断是否存在某一个字段"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
|
||||
return False
|
||||
|
||||
def _db_has_field_sync(p_id: str, f_name: str):
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
return bool(record)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
@@ -513,12 +759,13 @@ class PersonInfoManager:
|
||||
else:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
is_duplicate = True
|
||||
current_name_set.add(generated_nickname)
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
person.person_name = generated_nickname
|
||||
person.name_reason = result.get("reason", "未提供理由")
|
||||
@@ -547,4 +794,304 @@ class PersonInfoManager:
|
||||
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
|
||||
|
||||
|
||||
person_info_manager = PersonInfoManager()
|
||||
@staticmethod
|
||||
async def del_one_document(person_id: str):
|
||||
"""删除指定 person_id 的文档"""
|
||||
if not person_id:
|
||||
logger.debug("删除失败:person_id 不能为空")
|
||||
return
|
||||
|
||||
def _db_delete_sync(p_id: str):
|
||||
try:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
session.delete(record)
|
||||
session.commit()
|
||||
return 1
|
||||
return 0
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
|
||||
return 0
|
||||
|
||||
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"删除成功:person_id={person_id} (SQLAlchemy)")
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (SQLAlchemy)")
|
||||
|
||||
@staticmethod
|
||||
async def get_value(person_id: str, field_name: str):
|
||||
"""获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
|
||||
def _db_get_value_sync(p_id: str, f_name: str):
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
val = getattr(record, f_name, None)
|
||||
if f_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
# If val is already a list/dict (e.g. if somehow set without serialization)
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val
|
||||
return None # Record not found
|
||||
|
||||
try:
|
||||
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||
if value_from_db is not None:
|
||||
return value_from_db
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None # Ultimate fallback
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
|
||||
# Fallback to default in case of any error during DB access
|
||||
return default_value_for_field if field_name in person_info_default else None
|
||||
|
||||
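The write path (`create_person_info`, `update_one_field`) and the read path (`get_value`) apply mirrored rules to JSON-backed columns: lists and dicts are dumped to strings, `None` is stored as `"[]"`, and unreadable stored text falls back to an empty list. A minimal sketch of that round trip follows; the `JSON_FIELDS` set and both helper names are assumptions made for illustration, since the real `JSON_SERIALIZED_FIELDS` is defined elsewhere in the module.

```python
import json
from typing import Any

# Assumption: illustrative field names standing in for JSON_SERIALIZED_FIELDS.
JSON_FIELDS = {"points", "forgotten_points"}


def encode_field(name: str, value: Any) -> Any:
    """Prepare a value for storage, serialising JSON-backed fields."""
    if name in JSON_FIELDS:
        if isinstance(value, (list, dict)):
            return json.dumps(value, ensure_ascii=False)
        if value is None:
            return json.dumps([])  # None is stored as "[]"
    return value


def decode_field(name: str, stored: Any) -> Any:
    """Turn a stored value back into a Python object."""
    if name in JSON_FIELDS:
        if stored is None:
            return []
        if isinstance(stored, str):
            try:
                return json.loads(stored)
            except json.JSONDecodeError:
                return []  # invalid JSON falls back to an empty list
    return stored


round_trip = decode_field("points", encode_field("points", ["a", 1]))
```

Keeping both directions in one pair of helpers would be one way to avoid repeating the same branches in each method.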
@staticmethod
|
||||
def get_value_sync(person_id: str, field_name: str):
|
||||
"""同步获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = []
|
||||
|
||||
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
|
||||
val = getattr(record, field_name, None)
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return []
|
||||
elif val is None:
|
||||
return []
|
||||
return val
|
||||
return val
|
||||
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_values(person_id: str, field_names: list) -> dict:
|
||||
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not person_id:
|
||||
logger.debug("get_values获取失败:person_id不能为空")
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
|
||||
def _db_get_record_sync(p_id: str):
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
for field_name in field_names:
|
||||
if field_name not in model_fields:
|
||||
if field_name in person_info_default:
|
||||
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||||
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。")
|
||||
else:
|
||||
logger.debug(f"get_values查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
|
||||
result[field_name] = None
|
||||
continue
|
||||
|
||||
if record:
|
||||
value = getattr(record, field_name)
|
||||
if value is not None:
|
||||
result[field_name] = value
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_specific_value_list(
|
||||
field_name: str,
|
||||
way: Callable[[Any], bool],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取满足条件的字段值字典
|
||||
"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模型中定义")
|
||||
return {}
|
||||
|
||||
def _db_get_specific_sync(f_name: str):
|
||||
found_results = {}
|
||||
try:
|
||||
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
|
||||
value = getattr(record, f_name)
|
||||
if way(value):
|
||||
found_results[record.person_id] = value
|
||||
except Exception as e_query:
|
||||
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
|
||||
return found_results
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_get_specific_sync, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
async def get_or_create_person(
|
||||
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
根据 platform 和 user_id 获取 person_id。
|
||||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||||
使用try-except处理竞态条件,避免重复创建错误。
|
||||
"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_get_or_create_sync(p_id: str, init_data: dict):
|
||||
"""原子性的获取或创建操作"""
|
||||
# 首先尝试获取现有记录
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
return record, False # 记录存在,未创建
|
||||
|
||||
# 记录不存在,尝试创建
|
||||
try:
|
||||
new_person = PersonInfo(**init_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
unique_nickname = await self._generate_unique_person_name(nickname)
|
||||
initial_data = {
|
||||
"person_id": person_id,
|
||||
"platform": platform,
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"person_name": unique_nickname, # 使用群昵称作为person_name
|
||||
"name_reason": "从群昵称获取",
|
||||
"know_times": 0,
|
||||
"know_since": int(datetime.datetime.now().timestamp()),
|
||||
"last_know": int(datetime.datetime.now().timestamp()),
|
||||
"impression": None,
|
||||
"points": [],
|
||||
"forgotten_points": [],
|
||||
}
|
||||
|
||||
# 序列化JSON字段
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in initial_data:
|
||||
if isinstance(initial_data[key], (list, dict)):
|
||||
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
|
||||
elif initial_data[key] is None:
|
||||
initial_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
|
||||
|
||||
if was_created:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (SQLAlchemy)。")
|
||||
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||
else:
|
||||
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
|
||||
|
||||
return person_id
|
||||
|
||||
async def get_person_info_by_name(self, person_name: str) -> dict | None:
|
||||
"""根据 person_name 查找用户并返回基本信息 (如果找到)"""
|
||||
if not person_name:
|
||||
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
||||
return None
|
||||
|
||||
found_person_id = None
|
||||
for pid, name_in_cache in self.person_name_list.items():
|
||||
if name_in_cache == person_name:
|
||||
found_person_id = pid
|
||||
break
|
||||
|
||||
if not found_person_id:
|
||||
|
||||
def _db_find_by_name_sync(p_name_to_find: str):
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||
if record:
|
||||
found_person_id = record.person_id
|
||||
if (
|
||||
found_person_id not in self.person_name_list
|
||||
or self.person_name_list[found_person_id] != person_name
|
||||
):
|
||||
self.person_name_list[found_person_id] = person_name
|
||||
else:
|
||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (SQLAlchemy)")
|
||||
return None
|
||||
|
||||
if found_person_id:
|
||||
required_fields = [
|
||||
"person_id",
|
||||
"platform",
|
||||
"user_id",
|
||||
"nickname",
|
||||
"user_cardname",
|
||||
"user_avatar",
|
||||
"person_name",
|
||||
"name_reason",
|
||||
]
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
valid_fields_to_get = [
|
||||
f
|
||||
for f in required_fields
|
||||
if f in model_fields or f in person_info_default
|
||||
]
|
||||
|
||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||
|
||||
if person_data:
|
||||
final_result = {key: person_data.get(key) for key in required_fields}
|
||||
return final_result
|
||||
else:
|
||||
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (SQLAlchemy)")
|
||||
return None
|
||||
|
||||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (SQLAlchemy)")
|
||||
return None
|
||||
|
||||
|
||||
person_info_manager = None
|
||||
|
||||
|
||||
def get_person_info_manager():
|
||||
global person_info_manager
|
||||
if person_info_manager is None:
|
||||
person_info_manager = PersonInfoManager()
|
||||
return person_info_manager
|
||||
|
||||
@@ -5,385 +5,25 @@
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from src.common.logger import get_logger
|
||||
from peewee import Model, DoesNotExist
|
||||
from src.common.database.sqlalchemy_database_api import (
|
||||
db_query,
|
||||
db_save,
|
||||
db_get,
|
||||
store_action_info,
|
||||
get_model_class,
|
||||
MODEL_MAPPING
|
||||
)
|
||||
|
||||
logger = get_logger("database_api")
|
||||
|
||||
# =============================================================================
|
||||
# 通用数据库查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Model],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
"""
|
||||
"""
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
|
||||
query_type="create",
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_done": True},
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
return None if query_type == "get" and single_result else []
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await database_api.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if existing_records := list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
):
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Model],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||
|
||||
Returns:
|
||||
如果single_result为True,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await database_api.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await database_api.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by="-time",
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if single_result else []
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象,包含聊天相关信息
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存的记录数据
|
||||
None: 如果保存失败
|
||||
|
||||
示例:
|
||||
record = await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display="执行了回复动作",
|
||||
action_done=True,
|
||||
thinking_id="thinking_123",
|
||||
action_data={"content": "Hello"},
|
||||
action_name="reply_action"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
import json
|
||||
from src.common.database.database_model import ActionRecords
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 如果没有chat_stream,设置默认值
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 使用已有的db_save函数保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
# Kept for backward compatibility
|
||||
__all__ = [
|
||||
'db_query',
|
||||
'db_save',
|
||||
'db_get',
|
||||
'store_action_info',
|
||||
'get_model_class',
|
||||
'MODEL_MAPPING'
|
||||
]
|
||||
|
||||
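The record assembly in `store_action_info` can be sketched in isolation as follows. This is a minimal sketch, not the project's API: `build_action_record` and its `stream_id` / `platform` parameters are hypothetical names introduced here, and the database write (`db_save`) is left out on purpose.

```python
import json
import time
from typing import Any, Dict, Optional


def build_action_record(
    action_name: str,
    action_data: Optional[Dict[str, Any]] = None,
    thinking_id: str = "",
    action_done: bool = True,
    action_build_into_prompt: bool = False,
    action_prompt_display: str = "",
    stream_id: str = "",
    platform: str = "",
) -> Dict[str, Any]:
    """Assemble the flat record dict (hypothetical helper, not part of database_api)."""
    return {
        # Reuse thinking_id when given; otherwise fall back to a microsecond timestamp.
        "action_id": thinking_id or str(int(time.time() * 1_000_000)),
        "time": time.time(),
        "action_name": action_name,
        # ensure_ascii=False keeps non-ASCII text readable in the stored JSON.
        "action_data": json.dumps(action_data or {}, ensure_ascii=False),
        "action_done": action_done,
        "action_build_into_prompt": action_build_into_prompt,
        "action_prompt_display": action_prompt_display,
        # Empty strings stand in for a missing chat stream.
        "chat_id": stream_id,
        "chat_info_stream_id": stream_id,
        "chat_info_platform": platform,
    }
```

Note that a timestamp-based fallback ID is only as unique as the clock resolution allows; two calls in the same microsecond would collide.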
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"hot_reload_manager",
|
||||
]
|
||||
|
||||
@@ -237,35 +237,55 @@ class ComponentRegistry:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
return False
|
||||
try:
|
||||
# 根据组件类型进行特定的清理操作
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._action_registry.pop(component_name)
|
||||
self._default_actions.pop(component_name)
|
||||
# 移除Action注册
|
||||
self._action_registry.pop(component_name, None)
|
||||
self._default_actions.pop(component_name, None)
|
||||
logger.debug(f"已移除Action组件: {component_name}")
|
||||
|
||||
case ComponentType.COMMAND:
|
||||
self._command_registry.pop(component_name)
|
||||
# 移除Command注册和模式
|
||||
self._command_registry.pop(component_name, None)
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key)
|
||||
self._command_patterns.pop(key, None)
|
||||
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
|
||||
|
||||
case ComponentType.TOOL:
|
||||
self._tool_registry.pop(component_name)
|
||||
self._llm_available_tools.pop(component_name)
|
||||
# 移除Tool注册
|
||||
self._tool_registry.pop(component_name, None)
|
||||
self._llm_available_tools.pop(component_name, None)
|
||||
logger.debug(f"已移除Tool组件: {component_name}")
|
||||
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
# 移除EventHandler注册和事件订阅
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
self._event_handler_registry.pop(component_name)
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._event_handler_registry.pop(component_name, None)
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
try:
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
|
||||
case _:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return False
|
||||
|
||||
# Remove the generic registration entries
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components.pop(namespaced_name)
|
||||
self._components_by_type[component_type].pop(component_name)
|
||||
self._components_classes.pop(namespaced_name)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
self._components.pop(namespaced_name, None)
|
||||
self._components_by_type[component_type].pop(component_name, None)
|
||||
self._components_classes.pop(namespaced_name, None)
|
||||
|
||||
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
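The hunk above swaps `dict.pop(key)` for `dict.pop(key, None)` so that removing an already-missing entry no longer raises `KeyError`. A minimal sketch of that pattern, assuming the pattern table maps a pattern string to the owning command name (the real key type of `_command_patterns` is not shown in this diff, and `remove_command` is a hypothetical helper):

```python
from typing import Dict, List


def remove_command(registry: Dict[str, type], patterns: Dict[str, str], name: str) -> int:
    """Remove a command and every pattern that maps to it; safe to call twice."""
    registry.pop(name, None)  # no KeyError when the entry is already gone
    # Collect the keys first so the dict is not mutated while it is being iterated.
    stale: List[str] = [key for key, owner in patterns.items() if owner == name]
    for key in stale:
        patterns.pop(key, None)
    return len(stale)
```

Calling the function a second time with the same name is a no-op that returns 0, which is what makes a repeated unload harmless.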
@@ -615,5 +635,54 @@ class ComponentRegistry:
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def unregister_plugin(self, plugin_name: str) -> bool:
"""Unload a plugin together with all of its components.

Args:
plugin_name: Name of the plugin

Returns:
bool: Whether the unload succeeded
"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
if not plugin_info:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
|
||||
return False
|
||||
|
||||
logger.info(f"开始卸载插件: {plugin_name}")
|
||||
|
||||
# 记录卸载失败的组件
|
||||
failed_components = []
|
||||
|
||||
# 逐个移除插件的所有组件
|
||||
for component_info in plugin_info.components:
|
||||
try:
|
||||
success = await self.remove_component(
|
||||
component_info.name,
|
||||
component_info.component_type,
|
||||
plugin_name,
|
||||
)
|
||||
if not success:
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
|
||||
# 移除插件注册信息
|
||||
plugin_removed = self.remove_plugin_registry(plugin_name)
|
||||
|
||||
if failed_components:
|
||||
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
|
||||
return False
|
||||
elif not plugin_removed:
|
||||
logger.error(f"插件 {plugin_name} 注册信息移除失败")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"插件 {plugin_name} 卸载成功")
|
||||
return True
|
||||
|
||||
|
||||
# 创建全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
|
||||
@@ -33,7 +33,7 @@ class EventsManager:
|
||||
|
||||
if handler_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
return True
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
|
||||
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
@@ -0,0 +1,242 @@
"""
Plugin hot-reload module

Uses Watchdog to watch the plugin directory for changes and reload plugins automatically
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Dict, Set
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .plugin_manager import plugin_manager
|
||||
|
||||
logger = get_logger("plugin_hot_reload")
|
||||
|
||||
|
||||
class PluginFileHandler(FileSystemEventHandler):
|
||||
"""插件文件变化处理器"""
|
||||
|
||||
def __init__(self, hot_reload_manager):
|
||||
super().__init__()
|
||||
self.hot_reload_manager = hot_reload_manager
|
||||
self.pending_reloads: Set[str] = set() # 待重载的插件名称
|
||||
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
|
||||
self.debounce_delay = 1.0 # 防抖延迟(秒)
|
||||
|
||||
def on_modified(self, event):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
"""处理文件变化"""
|
||||
try:
|
||||
# 获取插件名称
|
||||
plugin_name = self._get_plugin_name_from_path(file_path)
|
||||
if not plugin_name:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
last_time = self.last_reload_time.get(plugin_name, 0)
|
||||
|
||||
# Debounce: skip events that arrive too soon after the last reload
|
||||
if current_time - last_time < self.debounce_delay:
|
||||
return
|
||||
|
||||
file_name = Path(file_path).name
|
||||
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
|
||||
|
||||
# 如果是删除事件,处理关键文件删除
|
||||
if change_type == "deleted":
|
||||
if file_name == "plugin.py":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
elif file_name == "manifest.toml":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
# 添加到待重载列表
|
||||
self.pending_reloads.add(plugin_name)
|
||||
self.last_reload_time[plugin_name] = current_time
|
||||
|
||||
# Delay the reload so it does not run while the file is still being written
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name,),
|
||||
daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理文件变化时发生错误: {e}")
|
||||
|
||||
def _delayed_reload(self, plugin_name: str):
|
||||
"""延迟重载插件"""
|
||||
try:
|
||||
time.sleep(self.debounce_delay)
|
||||
|
||||
if plugin_name in self.pending_reloads:
|
||||
self.pending_reloads.remove(plugin_name)
|
||||
self.hot_reload_manager._reload_plugin(plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _get_plugin_name_from_path(self, file_path: str) -> str:
|
||||
"""从文件路径获取插件名称"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
|
||||
# 检查是否在监听的插件目录中
|
||||
plugin_root = Path(self.hot_reload_manager.watch_directory)
|
||||
if not path.is_relative_to(plugin_root):
|
||||
return ""
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# Confirm this is a valid plugin directory (it must contain plugin.py or manifest.toml)
|
||||
plugin_dir = plugin_root / plugin_name
|
||||
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
|
||||
return plugin_name
|
||||
|
||||
return ""
|
||||
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
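The path-to-plugin mapping in `_get_plugin_name_from_path` can be sketched without touching the disk. This is a simplified sketch under stated assumptions: `plugin_name_from_path` is a hypothetical helper, it skips the check for `plugin.py` / `manifest.toml`, and it uses `relative_to` with a `ValueError` guard instead of `is_relative_to`, which requires Python 3.9 or later.

```python
from pathlib import Path


def plugin_name_from_path(file_path: str, plugin_root: str) -> str:
    """Return the first directory under plugin_root that contains file_path, or ""."""
    path = Path(file_path)
    root = Path(plugin_root)
    try:
        relative = path.relative_to(root)  # raises ValueError when outside root
    except ValueError:
        return ""
    # A file sitting directly in the root has no plugin directory above it.
    if len(relative.parts) < 2:
        return ""
    return relative.parts[0]
```

Unlike the method above, this version returns an empty string for a file placed directly in the watched root rather than treating the file name as a plugin name.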
|
||||
class PluginHotReloadManager:
|
||||
"""插件热重载管理器"""
|
||||
|
||||
def __init__(self, watch_directory: str = None):
|
||||
self.watch_directory = watch_directory or os.path.join(os.getcwd(), "plugins")
|
||||
self.observer = None
|
||||
self.file_handler = None
|
||||
self.is_running = False
|
||||
|
||||
# 确保监听目录存在
|
||||
if not os.path.exists(self.watch_directory):
|
||||
os.makedirs(self.watch_directory, exist_ok=True)
|
||||
logger.info(f"创建插件监听目录: {self.watch_directory}")
|
||||
|
||||
def start(self):
|
||||
"""启动热重载监听"""
|
||||
if self.is_running:
|
||||
logger.warning("插件热重载已经在运行中")
|
||||
return
|
||||
|
||||
try:
|
||||
self.observer = Observer()
|
||||
self.file_handler = PluginFileHandler(self)
|
||||
|
||||
self.observer.schedule(
|
||||
self.file_handler,
|
||||
self.watch_directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
self.observer.start()
|
||||
self.is_running = True
|
||||
|
||||
logger.info(f"🚀 插件热重载已启动,监听目录: {self.watch_directory}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动插件热重载失败: {e}")
|
||||
self.is_running = False
|
||||
|
||||
def stop(self):
|
||||
"""停止热重载监听"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
def _reload_plugin(self, plugin_name: str):
|
||||
"""重载指定插件"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _unload_plugin(self, plugin_name: str):
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def reload_all_plugins(self):
|
||||
"""重载所有插件"""
|
||||
try:
|
||||
logger.info("🔄 开始重载所有插件...")
|
||||
|
||||
# 获取当前已加载的插件列表
|
||||
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for plugin_name in loaded_plugins:
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载所有插件时发生错误: {e}")
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""获取热重载状态"""
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directory": self.watch_directory,
|
||||
"loaded_plugins": len(plugin_manager.loaded_plugins),
|
||||
"failed_plugins": len(plugin_manager.failed_plugins),
|
||||
}
|
||||
|
||||
|
||||
# 全局热重载管理器实例
|
||||
hot_reload_manager = PluginHotReloadManager()
|
||||
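The debounce check in `PluginFileHandler._handle_file_change` can be isolated into a small helper. This is a sketch, not the file's implementation: `Debouncer` is a hypothetical class, and it swaps `time.time()` for an injectable clock (defaulting to `time.monotonic`) so the behaviour can be exercised without sleeping.

```python
import time
from typing import Callable, Dict


class Debouncer:
    """Drop events for a key that arrive within `delay` seconds of the last accepted one."""

    def __init__(self, delay: float = 1.0, clock: Callable[[], float] = time.monotonic):
        self.delay = delay
        self.clock = clock
        self._last_accepted: Dict[str, float] = {}

    def accept(self, key: str) -> bool:
        now = self.clock()
        last = self._last_accepted.get(key)
        # The first event for a key is always accepted.
        if last is not None and now - last < self.delay:
            return False
        self._last_accepted[key] = now
        return True
```

Keys are tracked independently, so a burst of saves in one plugin directory does not suppress a change in another.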
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
@@ -488,6 +489,105 @@ class PluginManager:
|
||||
else:
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||
|
||||
# === 插件卸载和重载管理 ===
|
||||
|
||||
def unload_plugin(self, plugin_name: str) -> bool:
"""Unload the given plugin.

Args:
plugin_name: Name of the plugin

Returns:
bool: Whether the unload succeeded
"""
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 获取插件实例
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
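# NOTE: ComponentRegistry.unregister_plugin is declared `async def` in this commit, so calling it
# without awaiting or scheduling it appears to only create a coroutine object and remove nothing.
component_registry.unregister_plugin(plugin_name)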
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
|
||||
# 从失败列表中移除(如果存在)
|
||||
if plugin_name in self.failed_plugins:
|
||||
del self.failed_plugins[plugin_name]
|
||||
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
"""Reload the given plugin.

Args:
plugin_name: Name of the plugin

Returns:
bool: Whether the reload succeeded
"""
|
||||
try:
|
||||
# 先卸载插件
|
||||
if plugin_name in self.loaded_plugins:
|
||||
self.unload_plugin(plugin_name)
|
||||
|
||||
# Clear the cached Python modules
|
||||
plugin_path = self.plugin_paths.get(plugin_name)
|
||||
if plugin_path:
|
||||
plugin_file = os.path.join(plugin_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
# 从sys.modules中移除相关模块
|
||||
modules_to_remove = []
|
||||
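# NOTE: if plugin_path is absolute, parts[0] is the filesystem root (e.g. "/"), so this prefix may not match any key in sys.modules
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)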
|
||||
|
||||
for module_name in sys.modules:
|
||||
if module_name.startswith(plugin_module_prefix):
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
for module_name in modules_to_remove:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# 从插件类注册表中移除
|
||||
if plugin_name in self.plugin_classes:
|
||||
del self.plugin_classes[plugin_name]
|
||||
|
||||
# 重新加载插件模块
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
# 重新加载插件实例
|
||||
success, _ = self.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"🔄 插件重载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
|
||||
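The module-cache step in `reload_plugin` removes every `sys.modules` entry that belongs to the plugin before re-importing it. A minimal sketch of that step, assuming the plugin's dotted package name is already known; `purge_modules` and the `demo_plugin_pkg` names are hypothetical:

```python
import sys
import types
from typing import List


def purge_modules(prefix: str) -> List[str]:
    """Remove `prefix` and its submodules from sys.modules; return removed names."""
    # Build the list first: deleting while iterating a dict raises RuntimeError.
    doomed = [
        name
        for name in list(sys.modules)
        if name == prefix or name.startswith(prefix + ".")
    ]
    for name in doomed:
        del sys.modules[name]
    return doomed


# Usage with throwaway module names (hypothetical, for illustration only).
for _name in ("demo_plugin_pkg", "demo_plugin_pkg.actions", "demo_plugin_pkg_extra"):
    sys.modules[_name] = types.ModuleType(_name)
removed = purge_modules("demo_plugin_pkg")
```

Matching on `prefix + "."` rather than a bare `startswith(prefix)` keeps an unrelated package whose name merely begins with the same characters (here `demo_plugin_pkg_extra`) in the cache.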
@@ -1,149 +1,149 @@
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
|
||||
class TTSAction(BaseAction):
|
||||
"""TTS语音转换动作处理类"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "tts_action"
|
||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"当需要发送语音信息时使用",
|
||||
"当用户明确要求使用语音功能时使用",
|
||||
"当表达内容更适合用语音而不是文字传达时使用",
|
||||
"当用户想听到语音回答而非阅读文本时使用",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
# 获取要转换的文本
|
||||
text = self.action_data.get("text")
|
||||
|
||||
if not text:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
|
||||
return False, "执行TTS动作失败:未提供文本内容"
|
||||
|
||||
# 确保文本适合TTS使用
|
||||
processed_text = self._process_text_for_tts(text)
|
||||
|
||||
try:
|
||||
# 发送TTS消息
|
||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
||||
|
||||
# 记录动作信息
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
||||
return True, "TTS动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
|
||||
return False, f"执行TTS动作时出错: {e}"
|
||||
|
||||
def _process_text_for_tts(self, text: str) -> str:
|
||||
"""
|
||||
处理文本使其更适合TTS使用
|
||||
- 移除不必要的特殊字符和表情符号
|
||||
- 修正标点符号以提高语音质量
|
||||
- 优化文本结构使语音更流畅
|
||||
"""
|
||||
# 这里可以添加文本处理逻辑
|
||||
# 例如:移除多余的标点、表情符号,优化语句结构等
|
||||
|
||||
# 简单示例实现
|
||||
processed_text = text
|
||||
|
||||
# 移除多余的标点符号
|
||||
import re
|
||||
|
||||
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
|
||||
|
||||
# 确保句子结尾有合适的标点
|
||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||
processed_text = f"{processed_text}。"
|
||||
|
||||
return processed_text
|
||||
|
||||
|
||||
@register_plugin
|
||||
class TTSPlugin(BasePlugin):
|
||||
"""TTS插件
|
||||
- 这是文字转语音插件
|
||||
- Normal模式下依靠关键词触发
|
||||
- Focus模式下由LLM判断触发
|
||||
- 具有一定的文本预处理能力
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"components": "组件启用控制",
|
||||
"logging": "日志记录相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
|
||||
},
|
||||
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
|
||||
"logging": {
|
||||
"level": ConfigField(
|
||||
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
),
|
||||
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
enable_tts = self.get_config("components.enable_tts", True)
|
||||
components = [] # 添加Action组件
|
||||
if enable_tts:
|
||||
components.append((TTSAction.get_action_info(), TTSAction))
|
||||
|
||||
return components
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
|
||||
class TTSAction(BaseAction):
|
||||
"""TTS语音转换动作处理类"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "tts_action"
|
||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"当需要发送语音信息时使用",
|
||||
"当用户明确要求使用语音功能时使用",
|
||||
"当表达内容更适合用语音而不是文字传达时使用",
|
||||
"当用户想听到语音回答而非阅读文本时使用",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
# 获取要转换的文本
|
||||
text = self.action_data.get("text")
|
||||
|
||||
if not text:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
|
||||
return False, "执行TTS动作失败:未提供文本内容"
|
||||
|
||||
# 确保文本适合TTS使用
|
||||
processed_text = self._process_text_for_tts(text)
|
||||
|
||||
try:
|
||||
# 发送TTS消息
|
||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
||||
|
||||
# 记录动作信息
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
||||
return True, "TTS动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
|
||||
return False, f"执行TTS动作时出错: {e}"
|
||||
|
||||
def _process_text_for_tts(self, text: str) -> str:
|
||||
"""
|
||||
处理文本使其更适合TTS使用
|
||||
- 移除不必要的特殊字符和表情符号
|
||||
- 修正标点符号以提高语音质量
|
||||
- 优化文本结构使语音更流畅
|
||||
"""
|
||||
# 这里可以添加文本处理逻辑
|
||||
# 例如:移除多余的标点、表情符号,优化语句结构等
|
||||
|
||||
# 简单示例实现
|
||||
processed_text = text
|
||||
|
||||
# 移除多余的标点符号
|
||||
import re
|
||||
|
||||
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
|
||||
|
||||
# 确保句子结尾有合适的标点
|
||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||
processed_text = f"{processed_text}。"
|
||||
|
||||
return processed_text
|
||||
|
||||
|
||||
@register_plugin
|
||||
class TTSPlugin(BasePlugin):
|
||||
"""TTS插件
|
||||
- 这是文字转语音插件
|
||||
- Normal模式下依靠关键词触发
|
||||
- Focus模式下由LLM判断触发
|
||||
- 具有一定的文本预处理能力
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"components": "组件启用控制",
|
||||
"logging": "日志记录相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
|
||||
},
|
||||
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
|
||||
"logging": {
|
||||
"level": ConfigField(
|
||||
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
),
|
||||
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
enable_tts = self.get_config("components.enable_tts", True)
|
||||
components = [] # 添加Action组件
|
||||
if enable_tts:
|
||||
components.append((TTSAction.get_action_info(), TTSAction))
|
||||
|
||||
return components
|
||||
|
||||
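The text clean-up in `TTSAction._process_text_for_tts` can be tried on its own. The sketch below reuses the same regular expression and sentence-ending list as the plugin; the module-level function name and the precompiled constants are hypothetical conveniences, not part of the plugin.

```python
import re

# A punctuation mark followed by one or more copies of itself.
_REPEATED_PUNCT = re.compile(r"([!?,.;:。!?,、;:])\1+")
_SENTENCE_ENDINGS = (".", "?", "!", "。", "!", "?")


def process_text_for_tts(text: str) -> str:
    """Collapse runs of one punctuation mark and ensure a sentence-final mark."""
    processed = _REPEATED_PUNCT.sub(r"\1", text)
    # str.endswith accepts a tuple, which replaces the any(...) loop.
    if not processed.endswith(_SENTENCE_ENDINGS):
        processed = f"{processed}。"
    return processed
```

Only runs of the same character are collapsed, so a mixed sequence such as `?!` is left untouched, while an ASCII ellipsis `...` becomes a single full stop.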
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "6.4.0"
|
||||
version = "6.2.3"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -11,8 +11,38 @@ version = "6.4.0"
|
||||
# 修订号:配置文件内容小更新
|
||||
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
|
||||
[database]
# Database settings
database_type = "sqlite" # Database type: "sqlite" or "mysql"

# SQLite settings (used when database_type = "sqlite")
sqlite_path = "data/MaiBot.db" # Path to the SQLite database file

# MySQL settings (used when database_type = "mysql")
mysql_host = "localhost" # MySQL server address
mysql_port = 3306 # MySQL server port
mysql_database = "maibot" # MySQL database name
mysql_user = "root" # MySQL user name
mysql_password = "" # MySQL password
mysql_charset = "utf8mb4" # MySQL character set
mysql_unix_socket = "" # Path to the MySQL Unix socket (optional; for local connections, takes priority over host/port)

# MySQL SSL settings
mysql_ssl_mode = "DISABLED" # SSL mode: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY
mysql_ssl_ca = "" # Path to the SSL CA certificate
mysql_ssl_cert = "" # Path to the SSL client certificate
mysql_ssl_key = "" # Path to the SSL client key

# Advanced MySQL settings
mysql_autocommit = true # Autocommit transactions
mysql_sql_mode = "TRADITIONAL" # SQL mode

# Connection pool settings
connection_pool_size = 10 # Connection pool size (MySQL only)
connection_timeout = 10 # Connection timeout in seconds
|
||||
|
||||
[bot]
|
||||
platform = "qq"
|
||||
platform = "qq"
|
||||
qq_account = 1145141919810 # 麦麦的QQ账号
|
||||
nickname = "麦麦" # 麦麦的昵称
|
||||
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
|
||||
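How the `[database]` keys might be interpreted can be sketched as below. This is an illustration only: `describe_database_target` is a hypothetical helper, the input is assumed to be the already-parsed `[database]` table as a plain dict, and the project's actual connection code is not shown in this diff.

```python
from typing import Any, Dict


def describe_database_target(db: Dict[str, Any]) -> str:
    """Return a short human-readable description of the configured database."""
    db_type = db.get("database_type", "sqlite")
    if db_type == "sqlite":
        return f"sqlite file {db.get('sqlite_path', 'data/MaiBot.db')}"
    if db_type == "mysql":
        database = db.get("mysql_database", "maibot")
        socket = db.get("mysql_unix_socket", "")
        if socket:
            # Per the config comment, a Unix socket takes priority over host/port.
            return f"mysql socket {socket} db {database}"
        host = db.get("mysql_host", "localhost")
        port = db.get("mysql_port", 3306)
        return f"mysql {host}:{port} db {database}"
    raise ValueError(f"unsupported database_type: {db_type!r}")
```

On Python 3.11 or later the dict could come from `tomllib.load` applied to the config file, reading the `"database"` key of the result.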
@@ -53,11 +83,11 @@ expression_groups = [
|
||||
]
|
||||
|
||||
|
||||
[chat] #麦麦的聊天设置
|
||||
talk_frequency = 0.5
|
||||
# 麦麦活跃度,越高,麦麦回复越多,范围0-1
|
||||
focus_value = 0.5
|
||||
# 麦麦的专注度,越高越容易持续连续对话,可能消耗更多token, 范围0-1
|
||||
|
||||
[chat] #麦麦的聊天通用设置
|
||||
focus_value = 1
|
||||
# 麦麦的专注思考能力,越高越容易专注,可能消耗更多token
|
||||
# 专注时能更好把握发言时机,能够进行持久的连续对话
|
||||
|
||||
max_context_size = 20 # 上下文长度
|
||||
|
||||
@@ -120,6 +150,7 @@ mood_update_threshold = 1 # 情绪更新阈值,越高,更新越慢
|
||||
|
||||
[emoji]
|
||||
emoji_chance = 0.6 # 麦麦激活表情包动作的概率
|
||||
emoji_activate_type = "llm" # 表情包激活类型,可选:random,llm ; random下,表情包动作随机启用,llm下,表情包动作根据llm判断是否启用
|
||||
|
||||
max_reg_num = 60 # 表情包最大注册数量
|
||||
do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
|
||||
|
||||