Initialize

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
committed by Windpicker-owo
parent ef7a3aee23
commit 23ee3767ef
77 changed files with 10000 additions and 7525 deletions

.github/prompts/chat.prompt.md

@@ -0,0 +1,4 @@
---
mode: agent
---
Remember to activate the virtual environment before running; the shell here is PowerShell, whose syntax differs from Linux shells.


@@ -1,121 +0,0 @@
# Contributor Covenant Code of Conduct

## Our Pledge

As members, contributors, and maintainers, we pledge to provide a friendly, safe, and welcoming environment for everyone, regardless of age, body size, physical or mental disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual orientation.

We pledge to act and interact in ways that contribute to an open, friendly, diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our community include:

* Showing empathy and kindness toward other people
* Respecting differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility, apologizing to those affected by our mistakes, and learning from the experience
* Focusing on what is best not just for us as individuals, but for the community as a whole
* Using friendly and inclusive language
* Discussing technical issues professionally, without personal attacks

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
* Deliberately spreading misinformation or misleading content
* Maliciously disrupting project resources or community discussions

## Enforcement Responsibilities

Community maintainers are responsible for clarifying and enforcing our standards of acceptable behavior, and will take appropriate and fair corrective action in response to any behavior they deem inappropriate, threatening, offensive, or harmful.

Community maintainers have the right to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned with this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, including but not limited to:

* The GitHub repository and its discussion areas
* Issue and Pull Request discussions
* Project-related online forums, chat rooms, and social media
* Official project events and meetings
* Any other occasion where someone represents the project or community

This Code of Conduct also applies in public spaces when an individual is representing the project or its community. Examples of representation include using an official email address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

## MaiBot-Specific Guidelines

### Technical Discussions

* Keep technical discussions professional and constructive
* Check the existing documentation and issues before asking a question
* Provide clear, detailed bug reports and feature requests
* Respect different technical choices and implementation approaches

### AI/LLM Content

* Discuss AI technology responsibly and ethically
* Do not share or discuss AI applications that could cause harm
* Respect data privacy and user rights
* Comply with applicable laws, regulations, and platform policies

### Multilingual Support

* Chinese is the primary language of communication, but contributors working in other languages are welcome
* Be patient and kind toward non-native Chinese speakers
* Offer translation help when necessary

## Reporting

If you experience or witness behavior that violates this Code of Conduct, please report it through one of the following channels:

1. **GitHub Issues**: public violations can be pointed out directly in the relevant issue
2. **Private contact**: reach the project maintainers via GitHub direct message
3. **Email**: [provide a project email address here, if one exists]

All reports will be handled promptly and fairly. We are committed to protecting the privacy and safety of reporters.

## Enforcement Guidelines

Community maintainers will follow these Community Impact Guidelines in determining the consequences of any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community maintainers, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of actions.

**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

The Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).

For answers to common questions about this Code of Conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

## Contact

If you have questions or suggestions about this Code of Conduct, please contact us by:

* Opening an issue on GitHub for discussion
* Contacting the project maintainers

---

**Thank you for helping us build a friendly, inclusive open source community!**

*Last updated: June 21, 2025*


@@ -0,0 +1,21 @@
{
    "name": "MaiBot-Napcat-Adapter-DevContainer",
    "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye",
    "features": {
        "ghcr.io/rocker-org/devcontainer-features/apt-packages:1": {
            "packages": [
                "tmux"
            ]
        },
        "ghcr.io/devcontainers/features/github-cli:1": {}
    },
    "forwardPorts": [
        8095
    ],
    "postCreateCommand": "pip3 install --user -r requirements.txt",
    "customizations": {
        "jetbrains": {
            "backend": "PyCharm"
        }
    }
}


@@ -0,0 +1,54 @@
name: Docker Image CI

on:
  push:
    branches: [ "main", "dev" ]
  workflow_dispatch: # allow triggering the workflow manually

jobs:
  build:
    runs-on: ubuntu-latest
    env:
      DOCKERHUB_USER: ${{ secrets.DOCKERHUB_USERNAME }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Clone maim_message
        run: git clone https://github.com/MaiM-with-u/maim_message maim_message

      - name: Determine Image Tags
        id: tags
        run: |
          # `env:` values are literal strings, so $(date ...) has to be evaluated
          # inside a step and exported as an output.
          echo "date_tag=$(date -u +'%Y-%m-%dT%H-%M-%S')" >> $GITHUB_OUTPUT
          if [ "${{ github.ref_name }}" == "main" ]; then
            echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:latest,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:main-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT
          elif [ "${{ github.ref_name }}" == "dev" ]; then
            echo "tags=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev,${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:dev-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT
          fi

      - name: Build and Push Docker Image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          platforms: linux/amd64,linux/arm64
          tags: ${{ steps.tags.outputs.tags }}
          push: true
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }}
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maimbot-adapter:buildcache-${{ github.ref_name }},mode=max
          labels: |
            org.opencontainers.image.created=${{ steps.tags.outputs.date_tag }}
            org.opencontainers.image.revision=${{ github.sha }}

MaiBot-Napcat-Adapter-dev/.gitignore

@@ -0,0 +1,277 @@
log/
logs/
out/
.env
.env.*
.cursor
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
llm_statistics.txt
mongodb
napcat
run_dev.bat
elua.confirmed
# C extensions
*.so
/results
config_backup/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
# jieba
jieba.cache
# .vscode
!.vscode/settings.json
# direnv
/.direnv
# JetBrains
.idea
*.iml
*.ipr
# PyEnv
# If using PyEnv and configured to use a specific Python version locally
# a .local-version file will be created in the root of the project to specify the version.
.python-version
OtherRes.txt
/eula.confirmed
/privacy.confirmed
logs
.ruff_cache
.vscode
/config/*
config/old/bot_config_20250405_212257.toml
temp/
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
config.toml
config.toml.back
test
data/NapcatAdapter.db
data/NapcatAdapter.db-shm
data/NapcatAdapter.db-wal


@@ -0,0 +1,20 @@
FROM python:3.13.5-slim
LABEL authors="infinitycat233"
# Copy uv and maim_message
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
COPY maim_message /maim_message
COPY requirements.txt /requirements.txt
# Install requirements
RUN uv pip install --system --upgrade pip
RUN uv pip install --system -e /maim_message
RUN uv pip install --system -r /requirements.txt
WORKDIR /adapters
COPY . .
EXPOSE 8095
ENTRYPOINT ["python", "main.py"]


@@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.


@@ -0,0 +1,83 @@
# MaiBot Adapter for Napcat

Runs either standalone or inside the MaiBot main program as a plugin.

# Usage

Please refer to the [official documentation](https://docs.mai-mai.org/manual/adapters/napcat.html).

# Message Flow

```mermaid
sequenceDiagram
    participant Napcat as Napcat client
    participant Adapter as MaiBot-Napcat adapter
    participant Queue as Message queue
    participant Handler as Message handler
    participant MaiBot as MaiBot service

    Note over Napcat,MaiBot: Initialization
    Napcat->>Adapter: WebSocket connection (ws://localhost:8095)
    Adapter->>MaiBot: WebSocket connection (ws://localhost:8000)

    Note over Napcat,MaiBot: Heartbeat
    loop every 30 seconds
        Napcat->>Adapter: heartbeat packet
        Adapter->>Napcat: heartbeat response
    end

    Note over Napcat,MaiBot: Message handling
    Napcat->>Adapter: send message
    Adapter->>Queue: enqueue (message_queue)
    Queue->>Handler: dequeue for processing
    Handler->>Handler: parse message type
    alt text message
        Handler->>MaiBot: forward text message
    else image message
        Handler->>MaiBot: forward image message
    else mixed message
        Handler->>MaiBot: forward mixed message
    else forwarded message
        Handler->>MaiBot: forward forwarded message
    end
    MaiBot-->>Adapter: response
    Adapter-->>Napcat: response

    Note over Napcat,MaiBot: Graceful shutdown
    Adapter->>MaiBot: close connection
    Adapter->>Queue: drain message queue
    Adapter->>Napcat: close connection
```
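
The TO DO list below notes that the adapter uses an echo plus a UUID to keep request/response pairs matched. Here is a minimal, hypothetical sketch of that correlation pattern; the real logic lives in `src/response_pool.py`, the `action`/`params`/`echo` frame shape follows the OneBot convention that Napcat implements, and all names in this sketch are invented:

```python
import asyncio
import json
import uuid

# Pending API calls, keyed by the UUID we put in each request's "echo" field.
pending: dict[str, asyncio.Future] = {}

async def call_napcat_api(ws, action: str, params: dict, timeout: float = 10.0) -> dict:
    """Send one Napcat API request and wait for the response carrying the same echo."""
    echo = str(uuid.uuid4())
    pending[echo] = asyncio.get_running_loop().create_future()
    await ws.send(json.dumps({"action": action, "params": params, "echo": echo}))
    try:
        return await asyncio.wait_for(pending[echo], timeout)
    finally:
        pending.pop(echo, None)

def resolve_response(raw: dict) -> None:
    """Call this for frames without a post_type: they are API responses."""
    future = pending.get(raw.get("echo", ""))
    if future is not None and not future.done():
        future.set_result(raw)
```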
# TO DO List
- [x] Use automatic heartbeats to test the connection
- [x] Parse incoming messages
  - [x] Text
  - [x] Images
  - [x] Mixed text and images
  - [x] Forwarded messages (including dynamic image parsing)
  - [ ] Group announcements
  - [x] Replies
  - [ ] Group temporary messages (may not be implemented)
  - [ ] Links
  - [x] Pokes
    - [x] Read custom poke text
  - [ ] Voice (?)
- [ ] All notice types
  - [x] Message recall (command added)
- [x] Send messages
  - [x] Send text
  - [x] Send images
  - [x] Send stickers
  - [x] Quoted replies (done but untested)
  - [ ] Poke back (?)
  - [x] Send voice
- [x] Use echo and UUIDs to keep message order (see the sketch above)
- [x] Support selected admin features
  - [x] Mute members
  - [x] Whole-group mute
  - [x] Kick group members
# Acknowledgements
Special thanks to [@Maple127667](https://github.com/Maple127667) for supporting the ideas behind this project's code,
and to [@墨梓柒](https://github.com/DrSmoothl) for contributing ideas for parts of the code.


@@ -0,0 +1,60 @@
# Command Arguments

```python
Seg.type = "command"
```

## Group mute

```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_BAN",
    "args": {
        "qq_id": "QQ ID of the target user",
        "duration": "mute duration in seconds",
    },
}
```

The group ID is taken automatically from `Group_Info.group_id`.

**A `duration` of 0 lifts the mute.**

## Whole-group mute

```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_WHOLE_BAN",
    "args": {
        "enable": "whether to enable the whole-group mute (True/False)",
    },
}
```

The group ID is taken automatically from `Group_Info.group_id`.

`enable` must be a boolean: `True` turns the whole-group mute on, `False` turns it off.

## Group kick

```python
Seg.data: Dict[str, Any] = {
    "name": "GROUP_KICK",
    "args": {
        "qq_id": "QQ ID of the target user",
    },
}
```

The group ID is taken automatically from `Group_Info.group_id`.

## Poke

```python
Seg.data: Dict[str, Any] = {
    "name": "SEND_POKE",
    "args": {
        "qq_id": "QQ ID of the target user",
    },
}
```

## Recall a message

```python
Seg.data: Dict[str, Any] = {
    "name": "DELETE_MSG",
    "args": {
        "message_id": "the message's corresponding message_id",
    },
}
```

Here `message_id` is the message's actual QQ-side ID; in newer versions of mmc it can be fetched from the database (assuming that works correctly).
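
For illustration, a minimal, hypothetical sketch of constructing one of these command segments with `maim_message` (the `Seg(type=..., data=...)` constructor is assumed from the fields documented above):

```python
from maim_message import Seg

# Mute QQ user 12345678 for ten minutes; the adapter resolves the
# group ID from Group_Info.group_id on the enclosing message.
mute_command = Seg(
    type="command",
    data={
        "name": "GROUP_BAN",
        "args": {"qq_id": "12345678", "duration": 600},
    },
)
```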


@@ -0,0 +1,90 @@
import asyncio
import sys
import json

import websockets as Server

from src.logger import logger
from src.recv_handler.message_handler import message_handler
from src.recv_handler.meta_event_handler import meta_event_handler
from src.recv_handler.notice_handler import notice_handler
from src.recv_handler.message_sending import message_send_instance
from src.send_handler import send_handler
from src.config import global_config
from src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
from src.response_pool import put_response, check_timeout_response

message_queue = asyncio.Queue()


async def message_recv(server_connection: Server.ServerConnection):
    await message_handler.set_server_connection(server_connection)
    asyncio.create_task(notice_handler.set_server_connection(server_connection))
    await send_handler.set_server_connection(server_connection)
    async for raw_message in server_connection:
        logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
        decoded_raw_message: dict = json.loads(raw_message)
        post_type = decoded_raw_message.get("post_type")
        if post_type in ["meta_event", "message", "notice"]:
            await message_queue.put(decoded_raw_message)
        elif post_type is None:
            # frames without a post_type are API responses; hand them to the response pool
            await put_response(decoded_raw_message)


async def message_process():
    while True:
        message = await message_queue.get()
        post_type = message.get("post_type")
        if post_type == "message":
            await message_handler.handle_raw_message(message)
        elif post_type == "meta_event":
            await meta_event_handler.handle_meta_event(message)
        elif post_type == "notice":
            await notice_handler.handle_notice(message)
        else:
            logger.warning(f"Unknown post_type: {post_type}")
        message_queue.task_done()
        await asyncio.sleep(0.05)


async def main():
    message_send_instance.maibot_router = router
    _ = await asyncio.gather(napcat_server(), mmc_start_com(), message_process(), check_timeout_response())


async def napcat_server():
    logger.info("Starting the adapter...")
    async with Server.serve(
        message_recv, global_config.napcat_server.host, global_config.napcat_server.port, max_size=2**26
    ) as server:
        logger.info(
            f"Adapter started, listening on: ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
        )
        await server.serve_forever()


async def graceful_shutdown():
    try:
        logger.info("Shutting down the adapter...")
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), 15)
        await mmc_stop_com()  # do this last to avoid spurious exceptions
        logger.info("Adapter shut down successfully")
    except Exception as e:
        logger.error(f"Error while shutting down the adapter: {e}")


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        logger.warning("Interrupt received, shutting down gracefully...")
        loop.run_until_complete(graceful_shutdown())
    except Exception as e:
        logger.exception(f"Unhandled exception in main: {str(e)}")
        sys.exit(1)
    finally:
        if loop and not loop.is_closed():
            loop.close()
    sys.exit(0)


@@ -0,0 +1,44 @@
# Notify Args

```python
Seg.type = "notify"
```

## Group member muted

```python
Seg.data: Dict[str, Any] = {
    "sub_type": "ban",
    "duration": "the mute duration, in seconds",
    "banned_user_info": "info of the muted user, as a dict converted from a standard UserInfo",
}
```

Here `MessageBase.UserInfo` (the message's `UserInfo`) carries the operator's info.

**Note: `banned_user_info` must be converted back to a standard UserInfo object by calling `UserInfo.from_dict()` yourself.**

## Whole-group mute enabled

```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_ban",
    "duration": -1,
    "banned_user_info": None,
}
```

Here `MessageBase.UserInfo` (the message's `UserInfo`) carries the operator's info.

## Group member unmuted

```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_lift_ban",
    "lifted_user_info": "info of the unmuted user, as a standard UserInfo",
}
```

**When a mute expires naturally, `MessageBase.UserInfo` is `None`.**

When the mute is lifted manually, `MessageBase.UserInfo` (the message's `UserInfo`) carries the operator's info.

**Note: `lifted_user_info` must be converted back to a standard UserInfo object by calling `UserInfo.from_dict()` yourself.**

## Whole-group mute disabled

```python
Seg.data: Dict[str, Any] = {
    "sub_type": "whole_lift_ban",
    "lifted_user_info": None,
}
```

Here `MessageBase.UserInfo` (the message's `UserInfo`) carries the operator's info.
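
And a minimal, hypothetical consumer sketch for these notify segments (`UserInfo.from_dict` is documented above; the `user_id` field printed here is an assumption):

```python
from typing import Any, Dict

from maim_message import UserInfo

def handle_notify(seg_data: Dict[str, Any]) -> None:
    """React to one notify Seg.data dict of the shapes documented above."""
    sub_type = seg_data.get("sub_type")
    if sub_type == "ban" and seg_data.get("banned_user_info"):
        # banned_user_info arrives as a plain dict and must be converted manually
        banned = UserInfo.from_dict(seg_data["banned_user_info"])
        print(f"user {banned.user_id} muted for {seg_data['duration']}s")
    elif sub_type == "whole_lift_ban" and seg_data.get("lifted_user_info"):
        lifted = UserInfo.from_dict(seg_data["lifted_user_info"])
        print(f"user {lifted.user_id} unmuted")
```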


@@ -0,0 +1,44 @@
[project]
name = "MaiBotNapcatAdapter"
version = "0.4.7"
description = "A MaiBot adapter for Napcat"

[tool.ruff]
include = ["*.py"]
# line length
line-length = 120

[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
# enabled rule sets
select = [
    "E", # pycodestyle errors
    "F", # pyflakes
    "B", # flake8-bugbear
]
ignore = ["E711", "E501"]

[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
# use double quotes for strings
quote-style = "double"
# respect magic trailing commas, e.g.:
# items = [
#     "apple",
#     "banana",
#     "cherry",
# ]
skip-magic-trailing-comma = false
# auto-detect the appropriate line ending
line-ending = "auto"


@@ -0,0 +1,10 @@
websockets
aiohttp
requests
maim_message
loguru
pillow
tomlkit
rich
sqlmodel


@@ -0,0 +1,24 @@
from enum import Enum
import os

import tomlkit

from .logger import logger


class CommandType(Enum):
    """Command types"""

    GROUP_BAN = "set_group_ban"  # mute a user
    GROUP_WHOLE_BAN = "set_group_whole_ban"  # mute the whole group
    GROUP_KICK = "set_group_kick"  # kick a user from the group
    SEND_POKE = "send_poke"  # poke
    DELETE_MSG = "delete_msg"  # recall a message
    AI_VOICE_SEND = "send_group_ai_record"  # send a group AI voice message

    def __str__(self) -> str:
        return self.value


pyproject_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "pyproject.toml")
with open(pyproject_path, "r", encoding="utf-8") as f:
    toml_data = tomlkit.parse(f.read())
version = toml_data["project"]["version"]
logger.info(f"Version\n\nMaiBot-Napcat-Adapter version: {version}\nIf you like it, leave a star, meow~\n")
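
# Usage sketch (illustrative, not part of the original file): the enum maps the
# command names documented for Seg.data["name"] (e.g. "GROUP_BAN") to Napcat
# API action strings via member-by-name lookup:
#
#   action = str(CommandType["GROUP_BAN"])  # -> "set_group_ban"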


@@ -0,0 +1,5 @@
from .config import global_config

__all__ = [
    "global_config",
]


@@ -0,0 +1,146 @@
import os
import shutil
from dataclasses import dataclass
from datetime import datetime

import tomlkit
from tomlkit import TOMLDocument
from tomlkit.items import Table
from rich.traceback import install

from ..logger import logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
    ChatConfig,
    DebugConfig,
    MaiBotServerConfig,
    NapcatServerConfig,
    NicknameConfig,
    VoiceConfig,
)

install(extra_lines=3)

TEMPLATE_DIR = "template"


def update_config():
    # file paths
    template_path = f"{TEMPLATE_DIR}/template_config.toml"
    old_config_path = "config.toml"
    new_config_path = "config.toml"

    # if no config exists yet, create one from the template
    if not os.path.exists(old_config_path):
        logger.info("Configuration file not found, creating one from the template")
        shutil.copy2(template_path, old_config_path)  # copy the template file
        logger.info(f"New configuration file created; fill it in and run again: {old_config_path}")
        # nothing to migrate for a freshly created config
        quit()

    # read the old config and the template
    with open(old_config_path, "r", encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(template_path, "r", encoding="utf-8") as f:
        new_config = tomlkit.load(f)

    # compare version numbers
    if old_config and "inner" in old_config and "inner" in new_config:
        old_version = old_config["inner"].get("version")
        new_version = new_config["inner"].get("version")
        if old_version and new_version and old_version == new_version:
            logger.info(f"Configuration versions match (v{old_version}), skipping update")
            return
        else:
            logger.info(f"Version change detected: old v{old_version} -> new v{new_version}")
    else:
        logger.info("Existing configuration has no version number; it is probably old and will be updated")

    # back up the old config
    backup_dir = "config_backup"
    os.makedirs(backup_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    old_backup_path = os.path.join(backup_dir, f"config.toml.bak.{timestamp}")
    shutil.copy2(old_config_path, old_backup_path)
    logger.info(f"Old configuration backed up to: {old_backup_path}")

    # copy the template over the config
    shutil.copy2(template_path, new_config_path)
    logger.info(f"New configuration file created: {new_config_path}")

    def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
        """Merge values from source into target wherever target has the same key"""
        for key, value in source.items():
            # never migrate the version field
            if key == "version":
                continue
            if key in target:
                if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                    update_dict(target[key], value)
                else:
                    try:
                        # arrays need special handling
                        if isinstance(value, list):
                            # keep empty arrays empty
                            target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                        else:
                            # for everything else, build a fresh item
                            target[key] = tomlkit.item(value)
                    except (TypeError, ValueError):
                        # fall back to direct assignment if conversion fails
                        target[key] = value

    # merge the old values into the new config
    logger.info("Merging old configuration into the new one...")
    update_dict(new_config, old_config)

    # write the merged config back, preserving comments and formatting
    with open(new_config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
    logger.info("Configuration updated; please review the new file so nothing important is lost")
    quit()


@dataclass
class Config(ConfigBase):
    """Top-level configuration"""

    nickname: NicknameConfig
    napcat_server: NapcatServerConfig
    maibot_server: MaiBotServerConfig
    chat: ChatConfig
    voice: VoiceConfig
    debug: DebugConfig


def load_config(config_path: str) -> Config:
    """
    Load the configuration file
    :param config_path: path to the configuration file
    :return: a Config object
    """
    # read the configuration file
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = tomlkit.load(f)
    # build the Config object
    try:
        return Config.from_dict(config_data)
    except Exception as e:
        logger.critical("Failed to parse the configuration file")
        raise e


# update, then load the configuration
update_config()
logger.info("Tasting the configuration file...")
global_config = load_config(config_path="config.toml")
logger.info("Very fresh, very tasty!")
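
# Usage sketch (illustrative, not part of the original file): downstream modules
# import global_config and read nested, typed fields, as main.py does:
#
#   from src.config import global_config
#   url = f"ws://{global_config.napcat_server.host}:{global_config.napcat_server.port}"
#   # e.g. ws://localhost:8095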


@@ -0,0 +1,136 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: Dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
field_type = f.type
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
try:
init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e:
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
except Exception as e:
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
return field_type.from_dict(value)
field_origin_type = get_origin(field_type)
field_args_type = get_args(field_type)
# 处理泛型集合类型list, set, tuple
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
return [cls._convert_field(item, field_args_type[0]) for item in value]
if field_origin_type is set:
return {cls._convert_field(item, field_args_type[0]) for item in value}
if field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_args_type):
raise TypeError(
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_args_type) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_args_type
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
        # 处理Optional/Union类型
        if field_origin_type is Union:  # assert get_origin(Optional[Any]) is Union
            if value is None:
                return None
            # 如果有数据,检查实际类型
            if type(value) not in field_args_type:
                raise TypeError(f"Expected {field_args_type} for {field_type}, got {type(value).__name__}")
            # 按值的实际类型选择对应的联合分支转换避免Union[int, str]这类总被当作第一个分支处理
            return cls._convert_field(value, type(value))
        # 处理Any类型需在基础类型分支之前因为isinstance(value, Any)会抛出TypeError
        if field_type is Any:
            return value
        # 处理int, str, float, bool等基础类型
        if field_origin_type is None:
            if isinstance(value, field_type):
                return field_type(value)
            else:
                raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
# 处理Literal类型
if field_origin_type is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
# 其他类型直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -0,0 +1,80 @@
from dataclasses import dataclass, field
from typing import Literal
from src.config.config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
ADAPTER_PLATFORM = "qq"
@dataclass
class NicknameConfig(ConfigBase):
nickname: str
"""机器人昵称"""
@dataclass
class NapcatServerConfig(ConfigBase):
host: str = "localhost"
"""Napcat服务端的主机地址"""
port: int = 8095
"""Napcat服务端的端口号"""
heartbeat_interval: int = 30
"""Napcat心跳间隔时间单位为秒"""
@dataclass
class MaiBotServerConfig(ConfigBase):
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
"""平台名称“qq”"""
host: str = "localhost"
"""MaiMCore的主机地址"""
port: int = 8000
"""MaiMCore的端口号"""
@dataclass
class ChatConfig(ConfigBase):
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""群聊列表类型 白名单/黑名单"""
    group_list: list[int] = field(default_factory=list)
"""群聊列表"""
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""私聊列表类型 白名单/黑名单"""
    private_list: list[int] = field(default_factory=list)
"""私聊列表"""
    ban_user_id: list[int] = field(default_factory=list)
"""被封禁的用户ID列表封禁后将无法与其进行交互"""
ban_qq_bot: bool = False
"""是否屏蔽QQ官方机器人若为True则所有QQ官方机器人将无法与MaiMCore进行交互"""
enable_poke: bool = True
"""是否启用戳一戳功能"""
@dataclass
class VoiceConfig(ConfigBase):
use_tts: bool = False
"""是否启用TTS功能"""
@dataclass
class DebugConfig(ConfigBase):
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
"""日志级别默认为INFO"""

View File

@@ -0,0 +1,162 @@
import os
from typing import Optional, List
from dataclasses import dataclass
from sqlmodel import Field, Session, SQLModel, create_engine, select
from src.logger import logger
"""
表记录的方式:
| group_id | user_id | lift_time |
|----------|---------|-----------|
其中使用 user_id == 0 表示群全体禁言
"""
@dataclass
class BanUser:
"""
程序处理使用的实例
"""
user_id: int
group_id: int
    lift_time: Optional[int] = -1  # 普通dataclass应使用字面量默认值sqlmodel.Field仅用于下方的表模型
class DB_BanUser(SQLModel, table=True):
"""
表示数据库中的用户禁言记录。
使用双重主键
"""
user_id: int = Field(index=True, primary_key=True) # 被禁言用户的用户 ID
group_id: int = Field(index=True, primary_key=True) # 用户被禁言的群组 ID
lift_time: Optional[int] # 禁言解除的时间(时间戳)
def is_identical(obj1: BanUser, obj2: BanUser) -> bool:
"""
检查两个 BanUser 对象是否相同。
"""
return obj1.user_id == obj2.user_id and obj1.group_id == obj2.group_id
class DatabaseManager:
"""
数据库管理类,负责与数据库交互。
"""
def __init__(self):
os.makedirs(os.path.join(os.path.dirname(__file__), "..", "data"), exist_ok=True) # 确保数据目录存在
DATABASE_FILE = os.path.join(os.path.dirname(__file__), "..", "data", "NapcatAdapter.db")
self.sqlite_url = f"sqlite:///{DATABASE_FILE}" # SQLite 数据库 URL
self.engine = create_engine(self.sqlite_url, echo=False) # 创建数据库引擎
self._ensure_database() # 确保数据库和表已创建
def _ensure_database(self) -> None:
"""
确保数据库和表已创建。
"""
logger.info("确保数据库文件和表已创建...")
SQLModel.metadata.create_all(self.engine)
logger.success("数据库和表已创建或已存在")
def update_ban_record(self, ban_list: List[BanUser]) -> None:
# sourcery skip: class-extract-method
"""
更新禁言列表到数据库。
支持在不存在时创建新记录,对于多余的项目自动删除。
"""
with Session(self.engine) as session:
all_records = session.exec(select(DB_BanUser)).all()
for ban_user in ban_list:
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_user.user_id, DB_BanUser.group_id == ban_user.group_id
)
if existing_record := session.exec(statement).first():
if existing_record.lift_time == ban_user.lift_time:
logger.debug(f"禁言记录未变更: {existing_record}")
continue
# 更新现有记录的 lift_time
existing_record.lift_time = ban_user.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {existing_record}")
else:
# 创建新记录
db_record = DB_BanUser(
user_id=ban_user.user_id, group_id=ban_user.group_id, lift_time=ban_user.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_user}")
# 删除不在 ban_list 中的记录
for db_record in all_records:
record = BanUser(user_id=db_record.user_id, group_id=db_record.group_id, lift_time=db_record.lift_time)
if not any(is_identical(record, ban_user) for ban_user in ban_list):
statement = select(DB_BanUser).where(
DB_BanUser.user_id == record.user_id, DB_BanUser.group_id == record.group_id
)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
session.commit()
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: {ban_record}")
session.commit()
logger.info("禁言记录已更新")
def get_ban_records(self) -> List[BanUser]:
"""
读取所有禁言记录。
"""
with Session(self.engine) as session:
statement = select(DB_BanUser)
records = session.exec(statement).all()
return [BanUser(user_id=item.user_id, group_id=item.group_id, lift_time=item.lift_time) for item in records]
def create_ban_record(self, ban_record: BanUser) -> None:
"""
为特定群组中的用户创建禁言记录。
一个简化版本的添加方式,防止 update_ban_record 方法的复杂性。
其同时还是简化版的更新方式。
"""
with Session(self.engine) as session:
# 检查记录是否已存在
statement = select(DB_BanUser).where(
DB_BanUser.user_id == ban_record.user_id, DB_BanUser.group_id == ban_record.group_id
)
existing_record = session.exec(statement).first()
if existing_record:
# 如果记录已存在,更新 lift_time
existing_record.lift_time = ban_record.lift_time
session.add(existing_record)
logger.debug(f"更新禁言记录: {ban_record}")
else:
# 如果记录不存在,创建新记录
db_record = DB_BanUser(
user_id=ban_record.user_id, group_id=ban_record.group_id, lift_time=ban_record.lift_time
)
session.add(db_record)
logger.debug(f"创建新禁言记录: {ban_record}")
session.commit()
def delete_ban_record(self, ban_record: BanUser):
"""
删除特定用户在特定群组中的禁言记录。
一个简化版本的删除方式,防止 update_ban_record 方法的复杂性。
"""
user_id = ban_record.user_id
group_id = ban_record.group_id
with Session(self.engine) as session:
statement = select(DB_BanUser).where(DB_BanUser.user_id == user_id, DB_BanUser.group_id == group_id)
if ban_record := session.exec(statement).first():
session.delete(ban_record)
session.commit()
logger.debug(f"删除禁言记录: {ban_record}")
else:
logger.info(f"未找到禁言记录: user_id: {user_id}, group_id: {group_id}")
db_manager = DatabaseManager()
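# 用法示例(演示禁言记录的增、查、删;直接运行本模块时执行,会写入 data/NapcatAdapter.db
if __name__ == "__main__":
    import time

    demo_record = BanUser(user_id=10001, group_id=20002, lift_time=int(time.time()) + 60)
    db_manager.create_ban_record(demo_record)  # 不存在则创建存在则更新lift_time
    print(db_manager.get_ban_records())  # 读取全部禁言记录
    db_manager.delete_ban_record(demo_record)  # 删除该记录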

View File

@@ -0,0 +1,21 @@
from loguru import logger
from .config import global_config
import sys
# 默认 logger
logger.remove()
logger.add(
sys.stderr,
level=global_config.debug.level,
format="<blue>{time:YYYY-MM-DD HH:mm:ss}</blue> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
filter=lambda record: "name" not in record["extra"] or record["extra"].get("name") != "maim_message",
)
logger.add(
sys.stderr,
level="INFO",
format="<red>{time:YYYY-MM-DD HH:mm:ss}</red> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
filter=lambda record: record["extra"].get("name") == "maim_message",
)
# 创建样式不同的 logger
custom_logger = logger.bind(name="maim_message")
logger = logger.bind(name="MaiBot-Napcat-Adapter")

View File

@@ -0,0 +1,24 @@
from maim_message import Router, RouteConfig, TargetConfig
from .config import global_config
from .logger import logger, custom_logger
from .send_handler import send_handler
route_config = RouteConfig(
route_config={
global_config.maibot_server.platform_name: TargetConfig(
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
token=None,
)
}
)
router = Router(route_config, custom_logger)
async def mmc_start_com():
logger.info("正在连接MaiBot")
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
await router.stop()
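# 用法示例(在程序入口作为后台任务启动与停止,假设已处于 asyncio 事件循环中,仅作示意):
# import asyncio
# async def main():
#     comm_task = asyncio.create_task(mmc_start_com())  # 连接MaiBot并持续收发
#     ...  # 其他初始化与主逻辑
#     await mmc_stop_com()  # 退出前关闭路由连接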

View File

@@ -0,0 +1,87 @@
from enum import Enum
class MetaEventType:
lifecycle = "lifecycle" # 生命周期
class Lifecycle:
connect = "connect" # 生命周期 - WebSocket 连接成功
heartbeat = "heartbeat" # 心跳
class MessageType: # 接受消息大类
private = "private" # 私聊消息
class Private:
friend = "friend" # 私聊消息 - 好友
group = "group" # 私聊消息 - 群临时
group_self = "group_self" # 私聊消息 - 群中自身发送
other = "other" # 私聊消息 - 其他
group = "group" # 群聊消息
class Group:
normal = "normal" # 群聊消息 - 普通
anonymous = "anonymous" # 群聊消息 - 匿名消息
notice = "notice" # 群聊消息 - 系统提示
class NoticeType: # 通知事件
friend_recall = "friend_recall" # 私聊消息撤回
group_recall = "group_recall" # 群聊消息撤回
notify = "notify"
group_ban = "group_ban" # 群禁言
class Notify:
poke = "poke" # 戳一戳
class GroupBan:
ban = "ban" # 禁言
lift_ban = "lift_ban" # 解除禁言
class RealMessageType: # 实际消息分类
text = "text" # 纯文本
face = "face" # qq表情
image = "image" # 图片
record = "record" # 语音
video = "video" # 视频
at = "at" # @某人
rps = "rps" # 猜拳魔法表情
dice = "dice" # 骰子
shake = "shake" # 私聊窗口抖动(只收)
poke = "poke" # 群聊戳一戳
share = "share" # 链接分享json形式
reply = "reply" # 回复消息
forward = "forward" # 转发消息
node = "node" # 转发消息节点
class MessageSentType:
private = "private"
class Private:
friend = "friend"
group = "group"
group = "group"
class Group:
normal = "normal"
class CommandType(Enum):
"""命令类型"""
GROUP_BAN = "set_group_ban" # 禁言用户
GROUP_WHOLE_BAN = "set_group_whole_ban" # 群全体禁言
GROUP_KICK = "set_group_kick" # 踢出群聊
SEND_POKE = "send_poke" # 戳一戳
DELETE_MSG = "delete_msg" # 撤回消息
def __str__(self) -> str:
return self.value
ACCEPT_FORMAT = ["text", "image", "emoji", "reply", "voice", "command", "voiceurl", "music", "videourl", "file"]

View File

@@ -0,0 +1,669 @@
from src.logger import logger
from src.config import global_config
from src.utils import (
get_group_info,
get_member_info,
get_image_base64,
get_record_detail,
get_self_info,
get_message_detail,
)
from .qq_emoji_list import qq_face
from .message_sending import message_send_instance
from . import RealMessageType, MessageType, ACCEPT_FORMAT
import time
import json
import websockets as Server
from typing import List, Tuple, Optional, Dict, Any
import uuid
from maim_message import (
UserInfo,
GroupInfo,
Seg,
BaseMessageInfo,
MessageBase,
TemplateInfo,
FormatInfo,
)
from src.response_pool import get_response
class MessageHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
self.bot_id_list: Dict[int, bool] = {}
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
async def check_allow_to_chat(
self,
user_id: int,
group_id: Optional[int] = None,
ignore_bot: Optional[bool] = False,
ignore_global_list: Optional[bool] = False,
) -> bool:
# sourcery skip: hoist-statement-from-if, merge-else-if-into-elif
"""
检查是否允许聊天
Parameters:
user_id: int: 用户ID
group_id: int: 群ID
ignore_bot: bool: 是否忽略机器人检查
ignore_global_list: bool: 是否忽略全局黑名单检查
Returns:
bool: 是否允许聊天
"""
logger.debug(f"群聊id: {group_id}, 用户id: {user_id}")
logger.debug("开始检查聊天白名单/黑名单")
if group_id:
if global_config.chat.group_list_type == "whitelist" and group_id not in global_config.chat.group_list:
logger.warning("群聊不在聊天白名单中,消息被丢弃")
return False
elif global_config.chat.group_list_type == "blacklist" and group_id in global_config.chat.group_list:
logger.warning("群聊在聊天黑名单中,消息被丢弃")
return False
else:
if global_config.chat.private_list_type == "whitelist" and user_id not in global_config.chat.private_list:
logger.warning("私聊不在聊天白名单中,消息被丢弃")
return False
elif global_config.chat.private_list_type == "blacklist" and user_id in global_config.chat.private_list:
logger.warning("私聊在聊天黑名单中,消息被丢弃")
return False
if user_id in global_config.chat.ban_user_id and not ignore_global_list:
logger.warning("用户在全局黑名单中,消息被丢弃")
return False
        if global_config.chat.ban_qq_bot and group_id and not ignore_bot:
            logger.debug("开始判断是否为机器人")
            if user_id in self.bot_id_list:
                # 命中缓存,避免重复请求成员信息
                if self.bot_id_list[user_id]:
                    logger.warning("QQ官方机器人消息拦截已启用消息被丢弃该用户已在拦截名单中")
                    return False
            else:
                member_info = await get_member_info(self.server_connection, group_id, user_id)
                if member_info:
                    is_bot = member_info.get("is_robot")
                    if is_bot is None:
                        logger.warning("无法获取用户是否为机器人,默认为不是但是不进行更新")
                    elif is_bot:
                        logger.warning("QQ官方机器人消息拦截已启用消息被丢弃新机器人加入拦截名单")
                        self.bot_id_list[user_id] = True
                        return False
                    else:
                        self.bot_id_list[user_id] = False
return True
async def handle_raw_message(self, raw_message: dict) -> None:
# sourcery skip: low-code-quality, remove-unreachable-code
"""
从Napcat接受的原始消息处理
Parameters:
raw_message: dict: 原始消息
"""
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
template_info: TemplateInfo = None # 模板信息,暂时为空,等待启用
format_info: FormatInfo = FormatInfo(
content_format=["text", "image", "emoji", "voice"],
accept_format=ACCEPT_FORMAT,
) # 格式化信息
if message_type == MessageType.private:
sub_type = raw_message.get("sub_type")
if sub_type == MessageType.Private.friend:
sender_info: dict = raw_message.get("sender")
if not await self.check_allow_to_chat(sender_info.get("user_id"), None):
return None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
)
# 不存在群信息
group_info: GroupInfo = None
elif sub_type == MessageType.Private.group:
"""
本部分暂时不做支持,先放着
"""
logger.warning("群临时消息类型不支持")
return None
sender_info: dict = raw_message.get("sender")
# 由于临时会话中Napcat默认不发送成员昵称所以需要单独获取
fetched_member_info: dict = await get_member_info(
self.server_connection,
raw_message.get("group_id"),
sender_info.get("user_id"),
)
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=nickname,
user_cardname=None,
)
# -------------------这里需要群信息吗?-------------------
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info: dict = await get_group_info(self.server_connection, raw_message.get("group_id"))
group_name = ""
if fetched_group_info.get("group_name"):
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
else:
logger.warning(f"私聊消息类型 {sub_type} 不支持")
return None
elif message_type == MessageType.group:
sub_type = raw_message.get("sub_type")
if sub_type == MessageType.Group.normal:
sender_info: dict = raw_message.get("sender")
if not await self.check_allow_to_chat(sender_info.get("user_id"), raw_message.get("group_id")):
return None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
)
# 获取群聊相关信息在此单独处理group_name因为默认发送的消息中没有
fetched_group_info = await get_group_info(self.server_connection, raw_message.get("group_id"))
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=raw_message.get("group_id"),
group_name=group_name,
)
else:
logger.warning(f"群聊消息类型 {sub_type} 不支持")
return None
additional_config: dict = {}
if global_config.voice.use_tts:
additional_config["allow_tts"] = True
# 消息信息
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
message_id=message_id,
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=template_info,
format_info=format_info,
additional_config=additional_config,
)
# 处理实际信息
if not raw_message.get("message"):
logger.warning("原始消息内容为空")
return None
# 获取Seg列表
seg_message: List[Seg] = await self.handle_real_message(raw_message)
if not seg_message:
logger.warning("处理后消息内容为空")
return None
submit_seg: Seg = Seg(
type="seglist",
data=seg_message,
)
# MessageBase创建
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=submit_seg,
raw_message=raw_message.get("raw_message"),
)
logger.info("发送到Maibot处理信息")
await message_send_instance.message_send(message_base)
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
# sourcery skip: low-code-quality
"""
处理实际消息
Parameters:
            raw_message: dict: 原始消息
            in_reply: bool: 是否在回复解析中(回复内不再递归解析回复)
Returns:
seg_message: list[Seg]: 处理后的消息段列表
"""
real_message: list = raw_message.get("message")
if not real_message:
return None
seg_message: List[Seg] = []
for sub_message in real_message:
sub_message: dict
sub_message_type = sub_message.get("type")
match sub_message_type:
case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("text处理失败")
case RealMessageType.face:
ret_seg = await self.handle_face_message(sub_message)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("face处理失败或不支持")
case RealMessageType.reply:
if not in_reply:
ret_seg = await self.handle_reply_message(sub_message)
if ret_seg:
seg_message += ret_seg
else:
logger.warning("reply处理失败")
case RealMessageType.image:
ret_seg = await self.handle_image_message(sub_message)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("image处理失败")
case RealMessageType.record:
ret_seg = await self.handle_record_message(sub_message)
if ret_seg:
seg_message.clear()
seg_message.append(ret_seg)
break # 使得消息只有record消息
else:
logger.warning("record处理失败或不支持")
case RealMessageType.video:
logger.warning("不支持视频解析")
case RealMessageType.at:
ret_seg = await self.handle_at_message(
sub_message,
raw_message.get("self_id"),
raw_message.get("group_id"),
)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("at处理失败")
case RealMessageType.rps:
logger.warning("暂时不支持猜拳魔法表情解析")
case RealMessageType.dice:
logger.warning("暂时不支持骰子表情解析")
case RealMessageType.shake:
# 预计等价于戳一戳
logger.warning("暂时不支持窗口抖动解析")
case RealMessageType.share:
logger.warning("暂时不支持链接解析")
case RealMessageType.forward:
messages = await self._get_forward_message(sub_message)
if not messages:
logger.warning("转发消息内容为空或获取失败")
return None
ret_seg = await self.handle_forward_message(messages)
if ret_seg:
seg_message.append(ret_seg)
else:
logger.warning("转发消息处理失败")
case RealMessageType.node:
logger.warning("不支持转发消息节点解析")
case _:
logger.warning(f"未知消息类型: {sub_message_type}")
return seg_message
async def handle_text_message(self, raw_message: dict) -> Seg:
"""
处理纯文本信息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
plain_text: str = message_data.get("text")
return Seg(type="text", data=plain_text)
async def handle_face_message(self, raw_message: dict) -> Seg | None:
"""
处理表情消息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
face_raw_id: str = str(message_data.get("id"))
if face_raw_id in qq_face:
face_content: str = qq_face.get(face_raw_id)
return Seg(type="text", data=face_content)
else:
logger.warning(f"不支持的表情:{face_raw_id}")
return None
async def handle_image_message(self, raw_message: dict) -> Seg | None:
"""
处理图片消息与表情包消息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
image_sub_type = message_data.get("sub_type")
try:
image_base64 = await get_image_base64(message_data.get("url"))
except Exception as e:
logger.error(f"图片消息处理失败: {str(e)}")
return None
if image_sub_type == 0:
"""这部分认为是图片"""
return Seg(type="image", data=image_base64)
elif image_sub_type not in [4, 9]:
"""这部分认为是表情包"""
return Seg(type="emoji", data=image_base64)
else:
logger.warning(f"不支持的图片子类型:{image_sub_type}")
return None
async def handle_at_message(self, raw_message: dict, self_id: int, group_id: int) -> Seg | None:
# sourcery skip: use-named-expression
"""
处理at消息
Parameters:
raw_message: dict: 原始消息
self_id: int: 机器人QQ号
group_id: int: 群号
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
if message_data:
qq_id = message_data.get("qq")
if str(self_id) == str(qq_id):
logger.debug("机器人被at")
self_info: dict = await get_self_info(self.server_connection)
if self_info:
return Seg(type="text", data=f"@<{self_info.get('nickname')}:{self_info.get('user_id')}>")
else:
return None
else:
member_info: dict = await get_member_info(self.server_connection, group_id=group_id, user_id=qq_id)
if member_info:
return Seg(type="text", data=f"@<{member_info.get('nickname')}:{member_info.get('user_id')}>")
else:
return None
async def handle_record_message(self, raw_message: dict) -> Seg | None:
"""
处理语音消息
Parameters:
raw_message: dict: 原始消息
Returns:
seg_data: Seg: 处理后的消息段
"""
message_data: dict = raw_message.get("data")
file: str = message_data.get("file")
if not file:
logger.warning("语音消息缺少文件信息")
return None
try:
record_detail = await get_record_detail(self.server_connection, file)
if not record_detail:
logger.warning("获取语音消息详情失败")
return None
audio_base64: str = record_detail.get("base64")
except Exception as e:
logger.error(f"语音消息处理失败: {str(e)}")
return None
if not audio_base64:
logger.error("语音消息处理失败,未获取到音频数据")
return None
return Seg(type="voice", data=audio_base64)
async def handle_reply_message(self, raw_message: dict) -> List[Seg] | None:
# sourcery skip: move-assign-in-block, use-named-expression
"""
处理回复消息
"""
raw_message_data: dict = raw_message.get("data")
message_id: int = None
if raw_message_data:
message_id = raw_message_data.get("id")
else:
return None
message_detail: dict = await get_message_detail(self.server_connection, message_id)
if not message_detail:
logger.warning("获取被引用的消息详情失败")
return None
reply_message = await self.handle_real_message(message_detail, in_reply=True)
        if reply_message is None:
            reply_message = [Seg(type="text", data="(获取发言内容失败)")]  # 保持列表类型,避免字符串被逐字符拼入消息段
sender_info: dict = message_detail.get("sender")
sender_nickname: str = sender_info.get("nickname")
sender_id: str = sender_info.get("user_id")
seg_message: List[Seg] = []
if not sender_nickname:
logger.warning("无法获取被引用的人的昵称,返回默认值")
seg_message.append(Seg(type="text", data="[回复 未知用户:"))
else:
seg_message.append(Seg(type="text", data=f"[回复<{sender_nickname}:{sender_id}>"))
seg_message += reply_message
seg_message.append(Seg(type="text", data="],说:"))
return seg_message
async def handle_forward_message(self, message_list: list) -> Seg | None:
"""
递归处理转发消息,并按照动态方式确定图片处理方式
Parameters:
message_list: list: 转发消息列表
"""
handled_message, image_count = await self._handle_forward_message(message_list, 0)
handled_message: Seg
image_count: int
if not handled_message:
return None
        if 0 < image_count < 5:
# 处理图片数量小于5的情况此时解析图片为base64
logger.trace("图片数量小于5开始解析图片为base64")
return await self._recursive_parse_image_seg(handled_message, True)
elif image_count > 0:
logger.trace("图片数量大于等于5开始解析图片为占位符")
# 处理图片数量大于等于5的情况此时解析图片为占位符
return await self._recursive_parse_image_seg(handled_message, False)
else:
# 处理没有图片的情况,此时直接返回
logger.trace("没有图片,直接返回")
return handled_message
async def _recursive_parse_image_seg(self, seg_data: Seg, to_image: bool) -> Seg:
# sourcery skip: merge-else-if-into-elif
if to_image:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
image_url = seg_data.data
try:
encoded_image = await get_image_base64(image_url)
except Exception as e:
logger.error(f"图片处理失败: {str(e)}")
return Seg(type="text", data="[图片]")
return Seg(type="image", data=encoded_image)
elif seg_data.type == "emoji":
image_url = seg_data.data
try:
encoded_image = await get_image_base64(image_url)
except Exception as e:
logger.error(f"图片处理失败: {str(e)}")
return Seg(type="text", data="[表情包]")
return Seg(type="emoji", data=encoded_image)
else:
logger.trace(f"不处理类型: {seg_data.type}")
return seg_data
else:
if seg_data.type == "seglist":
new_seg_list = []
for i_seg in seg_data.data:
parsed_seg = await self._recursive_parse_image_seg(i_seg, to_image)
new_seg_list.append(parsed_seg)
return Seg(type="seglist", data=new_seg_list)
elif seg_data.type == "image":
return Seg(type="text", data="[图片]")
elif seg_data.type == "emoji":
return Seg(type="text", data="[动画表情]")
else:
logger.trace(f"不处理类型: {seg_data.type}")
return seg_data
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
# sourcery skip: low-code-quality
"""
递归处理实际转发消息
Parameters:
message_list: list: 转发消息列表首层对应messages字段后面对应content字段
layer: int: 当前层级
Returns:
seg_data: Seg: 处理后的消息段
image_count: int: 图片数量
"""
seg_list: List[Seg] = []
image_count = 0
if message_list is None:
return None, 0
for sub_message in message_list:
sub_message: dict
sender_info: dict = sub_message.get("sender")
user_nickname: str = sender_info.get("nickname", "QQ用户")
user_nickname_str = f"{user_nickname}】:"
break_seg = Seg(type="text", data="\n")
message_of_sub_message_list: List[Dict[str, Any]] = sub_message.get("message")
if not message_of_sub_message_list:
logger.warning("转发消息内容为空")
continue
message_of_sub_message = message_of_sub_message_list[0]
if message_of_sub_message.get("type") == RealMessageType.forward:
if layer >= 3:
full_seg_data = Seg(
type="text",
data=("--" * layer) + f"{user_nickname}】:【转发消息】\n",
)
else:
sub_message_data = message_of_sub_message.get("data")
if not sub_message_data:
continue
contents = sub_message_data.get("content")
seg_data, count = await self._handle_forward_message(contents, layer + 1)
image_count += count
head_tip = Seg(
type="text",
data=("--" * layer) + f"{user_nickname}】: 合并转发消息内容:\n",
)
full_seg_data = Seg(type="seglist", data=[head_tip, seg_data])
seg_list.append(full_seg_data)
elif message_of_sub_message.get("type") == RealMessageType.text:
sub_message_data = message_of_sub_message.get("data")
if not sub_message_data:
continue
text_message = sub_message_data.get("text")
seg_data = Seg(type="text", data=text_message)
data_list: List[Any] = []
if layer > 0:
data_list = [
Seg(type="text", data=("--" * layer) + user_nickname_str),
seg_data,
break_seg,
]
else:
data_list = [
Seg(type="text", data=user_nickname_str),
seg_data,
break_seg,
]
seg_list.append(Seg(type="seglist", data=data_list))
elif message_of_sub_message.get("type") == RealMessageType.image:
image_count += 1
image_data = message_of_sub_message.get("data")
sub_type = image_data.get("sub_type")
image_url = image_data.get("url")
data_list: List[Any] = []
if sub_type == 0:
seg_data = Seg(type="image", data=image_url)
else:
seg_data = Seg(type="emoji", data=image_url)
if layer > 0:
data_list = [
Seg(type="text", data=("--" * layer) + user_nickname_str),
seg_data,
break_seg,
]
else:
data_list = [
Seg(type="text", data=user_nickname_str),
seg_data,
break_seg,
]
full_seg_data = Seg(type="seglist", data=data_list)
seg_list.append(full_seg_data)
return Seg(type="seglist", data=seg_list), image_count
async def _get_forward_message(self, raw_message: dict) -> Dict[str, Any] | None:
forward_message_data: Dict = raw_message.get("data")
if not forward_message_data:
logger.warning("转发消息内容为空")
return None
forward_message_id = forward_message_data.get("id")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
{
"action": "get_forward_msg",
"params": {"message_id": forward_message_id},
"echo": request_uuid,
}
)
try:
await self.server_connection.send(payload)
response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error("获取转发消息超时")
return None
except Exception as e:
logger.error(f"获取转发消息失败: {str(e)}")
return None
logger.debug(
f"转发消息原始格式:{json.dumps(response)[:80]}..."
if len(json.dumps(response)) > 80
else json.dumps(response)
)
response_data: Dict = response.get("data")
if not response_data:
logger.warning("转发消息内容为空或获取失败")
return None
return response_data.get("messages")
message_handler = MessageHandler()
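# 参考handle_raw_message 预期的群聊原始消息结构OneBot v11 风格,
# 字段以 Napcat 实际返回为准,以下仅为示意):
# {
#     "post_type": "message",
#     "message_type": "group",
#     "sub_type": "normal",
#     "message_id": 123456,
#     "group_id": 114514,
#     "self_id": 10000,
#     "sender": {"user_id": 10001, "nickname": "示例用户", "card": ""},
#     "message": [{"type": "text", "data": {"text": "你好"}}],
#     "raw_message": "你好",
# }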

View File

@@ -0,0 +1,31 @@
from src.logger import logger
from maim_message import MessageBase, Router
class MessageSending:
"""
负责把消息发送到麦麦
"""
maibot_router: Router = None
def __init__(self):
pass
async def message_send(self, message_base: MessageBase) -> bool:
"""
发送消息
Parameters:
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
"""
try:
send_status = await self.maibot_router.send_message(message_base)
if not send_status:
raise RuntimeError("可能是路由未正确配置或连接异常")
return send_status
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")
message_send_instance = MessageSending()

View File

@@ -0,0 +1,49 @@
from src.logger import logger
from src.config import global_config
import time
import asyncio
from . import MetaEventType
class MetaEventHandler:
"""
处理Meta事件
"""
    def __init__(self):
        self.interval = global_config.napcat_server.heartbeat_interval
        self.last_heart_beat = time.time()  # 预先初始化,避免检查任务在首个心跳前访问未定义属性
        self._interval_checking = False
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")
if event_type == MetaEventType.lifecycle:
sub_type = message.get("sub_type")
if sub_type == MetaEventType.Lifecycle.connect:
self_id = message.get("self_id")
self.last_heart_beat = time.time()
logger.success(f"Bot {self_id} 连接成功")
asyncio.create_task(self.check_heartbeat(self_id))
        elif event_type == MetaEventType.heartbeat:
            self_id = message.get("self_id")
            if message["status"].get("online") and message["status"].get("good"):
                if not self._interval_checking:
                    asyncio.create_task(self.check_heartbeat(self_id))
                self.last_heart_beat = time.time()
                self.interval = message.get("interval") / 1000
            else:
                logger.warning(f"Bot {self_id} Napcat 端异常!")
    async def check_heartbeat(self, id: int) -> None:
        self._interval_checking = True
        while True:
            now_time = time.time()
            if now_time - self.last_heart_beat > self.interval * 2:
                logger.error(f"Bot {id} 可能发生了连接断开、被下线或者Napcat卡死")
                self._interval_checking = False  # 复位标志,允许重连后重新启动检查任务
                break
            else:
                logger.debug("心跳正常")
            await asyncio.sleep(self.interval)
meta_event_handler = MetaEventHandler()
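# 参考handle_meta_event 预期的心跳事件结构OneBot v11 风格,字段以 Napcat 实际返回为准):
# {
#     "post_type": "meta_event",
#     "meta_event_type": "heartbeat",
#     "self_id": 10000,
#     "status": {"online": True, "good": True},
#     "interval": 30000,  # 毫秒上方代码会除以1000转换为秒
# }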

View File

@@ -0,0 +1,516 @@
import time
import json
import asyncio
import websockets as Server
from typing import Tuple, Optional
from src.logger import logger
from src.config import global_config
from src.database import BanUser, db_manager, is_identical
from . import NoticeType, ACCEPT_FORMAT
from .message_sending import message_send_instance
from .message_handler import message_handler
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
from src.utils import (
get_group_info,
get_member_info,
get_self_info,
get_stranger_info,
read_ban_list,
)
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
class NoticeHandler:
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
lifted_list: list[BanUser] = [] # 已经自然解除禁言
def __init__(self):
self.server_connection: Server.ServerConnection = None
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
while self.server_connection.state != Server.State.OPEN:
await asyncio.sleep(0.5)
self.banned_list, self.lifted_list = await read_ban_list(self.server_connection)
asyncio.create_task(self.auto_lift_detect())
asyncio.create_task(self.send_notice())
asyncio.create_task(self.handle_natural_lift())
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
"""
将用户禁言记录添加到self.banned_list中
如果是全体禁言则user_id为0
"""
if user_id is None:
user_id = 0 # 使用0表示全体禁言
lift_time = -1
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
for record in self.banned_list:
if is_identical(record, ban_record):
self.banned_list.remove(record)
self.banned_list.append(ban_record)
db_manager.create_ban_record(ban_record) # 作为更新
return
self.banned_list.append(ban_record)
db_manager.create_ban_record(ban_record) # 添加到数据库
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
"""
        记录解除禁言将记录加入self.lifted_list并删除数据库中对应的禁言记录如果是全体禁言则user_id为0
"""
if user_id is None:
user_id = 0 # 使用0表示全体禁言
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
self.lifted_list.append(ban_record)
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
async def handle_notice(self, raw_message: dict) -> None:
notice_type = raw_message.get("notice_type")
# message_time: int = raw_message.get("time")
message_time: float = time.time() # 应可乐要求现在是float了
group_id = raw_message.get("group_id")
user_id = raw_message.get("user_id")
target_id = raw_message.get("target_id")
handled_message: Seg = None
user_info: UserInfo = None
system_notice: bool = False
match notice_type:
case NoticeType.friend_recall:
logger.info("好友撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.group_recall:
logger.info("群内用户撤回一条消息")
logger.info(f"撤回消息ID{raw_message.get('message_id')}, 撤回时间:{raw_message.get('time')}")
logger.warning("暂时不支持撤回消息处理")
case NoticeType.notify:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if global_config.chat.enable_poke and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
logger.info("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_ban:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.GroupBan.ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理群禁言")
handled_message, user_info = await self.handle_ban_notify(raw_message, group_id)
system_notice = True
case NoticeType.GroupBan.lift_ban:
if not await message_handler.check_allow_to_chat(user_id, group_id, True, False):
return None
logger.info("处理解除群禁言")
handled_message, user_info = await self.handle_lift_ban_notify(raw_message, group_id)
system_notice = True
case _:
logger.warning(f"不支持的group_ban类型: {notice_type}.{sub_type}")
case _:
logger.warning(f"不支持的notice类型: {notice_type}")
return None
if not handled_message or not user_info:
logger.warning("notice处理失败或不支持")
return None
group_info: GroupInfo = None
if group_id:
fetched_group_info = await get_group_info(self.server_connection, group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
message_id="notice",
time=message_time,
user_info=user_info,
group_info=group_info,
template_info=None,
format_info=FormatInfo(
content_format=["text", "notify"],
accept_format=ACCEPT_FORMAT,
),
additional_config={"target_id": target_id}, # 在这里塞了一个target_id方便mmc那边知道被戳的人是谁
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=handled_message,
raw_message=json.dumps(raw_message),
)
if system_notice:
await self.put_notice(message_base)
else:
logger.info("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
async def handle_poke_notify(
self, raw_message: dict, group_id: int, user_id: int
) -> Tuple[Seg | None, UserInfo | None]:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, remove-redundant-if, remove-unnecessary-else, swap-if-else-branches
self_info: dict = await get_self_info(self.server_connection)
if not self_info:
logger.error("自身信息获取失败")
return None, None
self_id = raw_message.get("self_id")
target_id = raw_message.get("target_id")
target_name: str = None
raw_info: list = raw_message.get("raw_info")
if group_id:
user_qq_info: dict = await get_member_info(self.server_connection, group_id, user_id)
else:
user_qq_info: dict = await get_stranger_info(self.server_connection, user_id)
if user_qq_info:
user_name = user_qq_info.get("nickname")
user_cardname = user_qq_info.get("card")
else:
user_name = "QQ用户"
user_cardname = "QQ用户"
logger.info("无法获取戳一戳对方的用户昵称")
# 计算Seg
if self_id == target_id:
display_name = ""
target_name = self_info.get("nickname")
elif self_id == user_id:
# 让ada不发送麦麦戳别人的消息
return None, None
else:
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
if group_id:
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, target_id)
if fetched_member_info:
target_name = fetched_member_info.get("nickname")
else:
target_name = "QQ用户"
logger.info("无法获取被戳一戳方的用户昵称")
display_name = user_name
else:
return None, None
first_txt: str = "戳了戳"
second_txt: str = ""
try:
first_txt = raw_info[2].get("txt", "戳了戳")
second_txt = raw_info[4].get("txt", "")
except Exception as e:
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
)
seg_data: Seg = Seg(
type="text",
data=f"{display_name}{first_txt}{target_name}{second_txt}这是QQ的一个功能用于提及某人但没那么明显",
)
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.server_connection, group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
user_id = raw_message.get("user_id")
banned_user_info: UserInfo = None
user_nickname: str = "QQ用户"
user_cardname: str = None
sub_type: str = None
duration = raw_message.get("duration")
if duration is None:
logger.error("禁言时长不能为空,无法处理禁言通知")
return None, None
if user_id == 0: # 为全体禁言
sub_type: str = "whole_ban"
self._ban_operation(group_id)
else: # 为单人禁言
# 获取被禁言人的信息
sub_type: str = "ban"
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
banned_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._ban_operation(group_id, user_id, int(time.time() + duration))
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"duration": duration,
"banned_user_info": banned_user_info.to_dict() if banned_user_info else None,
},
)
return seg_data, operator_info
async def handle_lift_ban_notify(
self, raw_message: dict, group_id: int
) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理解除禁言通知")
return None, None
# 计算user_info
operator_id = raw_message.get("operator_id")
operator_nickname: str = None
operator_cardname: str = None
member_info: dict = await get_member_info(self.server_connection, group_id, operator_id)
if member_info:
operator_nickname = member_info.get("nickname")
operator_cardname = member_info.get("card")
else:
logger.warning("无法获取解除禁言执行者的昵称,消息可能会无效")
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
)
# 计算Seg
sub_type: str = None
user_nickname: str = "QQ用户"
user_cardname: str = None
lifted_user_info: UserInfo = None
user_id = raw_message.get("user_id")
if user_id == 0: # 全体禁言解除
sub_type = "whole_lift_ban"
self._lift_operation(group_id)
else: # 单人禁言解除
sub_type = "lift_ban"
# 获取被解除禁言人的信息
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
else:
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
self._lift_operation(group_id, user_id)
seg_data: Seg = Seg(
type="notify",
data={
"sub_type": sub_type,
"lifted_user_info": lifted_user_info.to_dict() if lifted_user_info else None,
},
)
return seg_data, operator_info
async def put_notice(self, message_base: MessageBase) -> None:
"""
将处理后的通知消息放入通知队列
"""
if notice_queue.full() or unsuccessful_notice_queue.full():
logger.warning("通知队列已满,可能是多次发送失败,消息丢弃")
else:
await notice_queue.put(message_base)
async def handle_natural_lift(self) -> None:
while True:
if len(self.lifted_list) != 0:
lift_record = self.lifted_list.pop()
group_id = lift_record.group_id
user_id = lift_record.user_id
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
seg_message: Seg = await self.natural_lift(group_id, user_id)
fetched_group_info = await get_group_info(self.server_connection, group_id)
group_name: str = None
if fetched_group_info:
group_name = fetched_group_info.get("group_name")
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
message_id="notice",
time=time.time(),
user_info=None, # 自然解除禁言没有操作者
group_info=group_info,
template_info=None,
format_info=None,
)
message_base: MessageBase = MessageBase(
message_info=message_info,
message_segment=seg_message,
raw_message=json.dumps(
{
"post_type": "notice",
"notice_type": "group_ban",
"sub_type": "lift_ban",
"group_id": group_id,
"user_id": user_id,
"operator_id": None, # 自然解除禁言没有操作者
}
),
)
await self.put_notice(message_base)
await asyncio.sleep(0.5) # 确保队列处理间隔
else:
await asyncio.sleep(5) # 每5秒检查一次
async def natural_lift(self, group_id: int, user_id: int) -> Seg | None:
if not group_id:
logger.error("群ID不能为空无法处理解除禁言通知")
return None
if user_id == 0: # 理论上永远不会触发
return Seg(
type="notify",
data={
"sub_type": "whole_lift_ban",
"lifted_user_info": None,
},
)
user_nickname: str = "QQ用户"
user_cardname: str = None
fetched_member_info: dict = await get_member_info(self.server_connection, group_id, user_id)
if fetched_member_info:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
)
return Seg(
type="notify",
data={
"sub_type": "lift_ban",
"lifted_user_info": lifted_user_info.to_dict(),
},
)
async def auto_lift_detect(self) -> None:
while True:
if len(self.banned_list) == 0:
await asyncio.sleep(5)
continue
            for ban_record in list(self.banned_list):  # 遍历副本,避免在迭代中修改原列表导致漏检
if ban_record.user_id == 0 or ban_record.lift_time == -1:
continue
if ban_record.lift_time <= int(time.time()):
# 触发自然解除禁言
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
self.lifted_list.append(ban_record)
self.banned_list.remove(ban_record)
await asyncio.sleep(5)
async def send_notice(self) -> None:
"""
发送通知消息到Napcat
"""
while True:
if not unsuccessful_notice_queue.empty():
to_be_send: MessageBase = await unsuccessful_notice_queue.get()
try:
send_status = await message_send_instance.message_send(to_be_send)
if send_status:
unsuccessful_notice_queue.task_done()
else:
await unsuccessful_notice_queue.put(to_be_send)
except Exception as e:
logger.error(f"发送通知消息失败: {str(e)}")
await unsuccessful_notice_queue.put(to_be_send)
await asyncio.sleep(1)
continue
to_be_send: MessageBase = await notice_queue.get()
try:
send_status = await message_send_instance.message_send(to_be_send)
if send_status:
notice_queue.task_done()
else:
await unsuccessful_notice_queue.put(to_be_send)
except Exception as e:
logger.error(f"发送通知消息失败: {str(e)}")
await unsuccessful_notice_queue.put(to_be_send)
await asyncio.sleep(1)
notice_handler = NoticeHandler()
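# 参考handle_notice 预期的群禁言通知结构OneBot v11 风格,字段以 Napcat 实际返回为准):
# {
#     "post_type": "notice",
#     "notice_type": "group_ban",
#     "sub_type": "ban",     # 解除禁言时为 "lift_ban"
#     "group_id": 114514,
#     "operator_id": 10001,  # 执行操作的管理员
#     "user_id": 10002,      # 被禁言用户user_id为0表示全体禁言
#     "duration": 600,       # 禁言时长(秒)
# }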

View File

@@ -0,0 +1,250 @@
qq_face: dict = {
"0": "[表情:惊讶]",
"1": "[表情:撇嘴]",
"2": "[表情:色]",
"3": "[表情:发呆]",
"4": "[表情:得意]",
"5": "[表情:流泪]",
"6": "[表情:害羞]",
"7": "[表情:闭嘴]",
"8": "[表情:睡]",
"9": "[表情:大哭]",
"10": "[表情:尴尬]",
"11": "[表情:发怒]",
"12": "[表情:调皮]",
"13": "[表情:呲牙]",
"14": "[表情:微笑]",
"15": "[表情:难过]",
"16": "[表情:酷]",
"18": "[表情:抓狂]",
"19": "[表情:吐]",
"20": "[表情:偷笑]",
"21": "[表情:可爱]",
"22": "[表情:白眼]",
"23": "[表情:傲慢]",
"24": "[表情:饥饿]",
"25": "[表情:困]",
"26": "[表情:惊恐]",
"27": "[表情:流汗]",
"28": "[表情:憨笑]",
"29": "[表情:悠闲]",
"30": "[表情:奋斗]",
"31": "[表情:咒骂]",
"32": "[表情:疑问]",
"33": "[表情: 嘘]",
"34": "[表情:晕]",
"35": "[表情:折磨]",
"36": "[表情:衰]",
"37": "[表情:骷髅]",
"38": "[表情:敲打]",
"39": "[表情:再见]",
"41": "[表情:发抖]",
"42": "[表情:爱情]",
"43": "[表情:跳跳]",
"46": "[表情:猪头]",
"49": "[表情:拥抱]",
"53": "[表情:蛋糕]",
"56": "[表情:刀]",
"59": "[表情:便便]",
"60": "[表情:咖啡]",
"63": "[表情:玫瑰]",
"64": "[表情:凋谢]",
"66": "[表情:爱心]",
"67": "[表情:心碎]",
"74": "[表情:太阳]",
"75": "[表情:月亮]",
"76": "[表情:赞]",
"77": "[表情:踩]",
"78": "[表情:握手]",
"79": "[表情:胜利]",
"85": "[表情:飞吻]",
"86": "[表情:怄火]",
"89": "[表情:西瓜]",
"96": "[表情:冷汗]",
"97": "[表情:擦汗]",
"98": "[表情:抠鼻]",
"99": "[表情:鼓掌]",
"100": "[表情:糗大了]",
"101": "[表情:坏笑]",
"102": "[表情:左哼哼]",
"103": "[表情:右哼哼]",
"104": "[表情:哈欠]",
"105": "[表情:鄙视]",
"106": "[表情:委屈]",
"107": "[表情:快哭了]",
"108": "[表情:阴险]",
"109": "[表情:左亲亲]",
"110": "[表情:吓]",
"111": "[表情:可怜]",
"112": "[表情:菜刀]",
"114": "[表情:篮球]",
"116": "[表情:示爱]",
"118": "[表情:抱拳]",
"119": "[表情:勾引]",
"120": "[表情:拳头]",
"121": "[表情:差劲]",
"123": "[表情NO]",
"124": "[表情OK]",
"125": "[表情:转圈]",
"129": "[表情:挥手]",
"137": "[表情:鞭炮]",
"144": "[表情:喝彩]",
"146": "[表情:爆筋]",
"147": "[表情:棒棒糖]",
"169": "[表情:手枪]",
"171": "[表情:茶]",
"172": "[表情:眨眼睛]",
"173": "[表情:泪奔]",
"174": "[表情:无奈]",
"175": "[表情:卖萌]",
"176": "[表情:小纠结]",
"177": "[表情:喷血]",
"178": "[表情:斜眼笑]",
"179": "[表情doge]",
"181": "[表情:戳一戳]",
"182": "[表情:笑哭]",
"183": "[表情:我最美]",
"185": "[表情:羊驼]",
"187": "[表情:幽灵]",
"201": "[表情:点赞]",
"212": "[表情:托腮]",
"262": "[表情:脑阔疼]",
"263": "[表情:沧桑]",
"264": "[表情:捂脸]",
"265": "[表情:辣眼睛]",
"266": "[表情:哦哟]",
"267": "[表情:头秃]",
"268": "[表情:问号脸]",
"269": "[表情:暗中观察]",
"270": "[表情emm]",
"271": "[表情:吃 瓜]",
"272": "[表情:呵呵哒]",
"273": "[表情:我酸了]",
"277": "[表情:汪汪]",
"281": "[表情:无眼笑]",
"282": "[表情:敬礼]",
"283": "[表情:狂笑]",
"284": "[表情:面无表情]",
"285": "[表情:摸鱼]",
"286": "[表情:魔鬼笑]",
"287": "[表情:哦]",
"289": "[表情:睁眼]",
"293": "[表情:摸锦鲤]",
"294": "[表情:期待]",
"295": "[表情:拿到红包]",
"297": "[表情:拜谢]",
"298": "[表情:元宝]",
"299": "[表情:牛啊]",
"300": "[表情:胖三斤]",
"302": "[表情:左拜年]",
"303": "[表情:右拜年]",
"305": "[表情:右亲亲]",
"306": "[表情:牛气冲天]",
"307": "[表情:喵喵]",
"311": "[表情打call]",
"312": "[表情:变形]",
"314": "[表情:仔细分析]",
"317": "[表情:菜汪]",
"318": "[表情:崇拜]",
"319": "[表情: 比心]",
"320": "[表情:庆祝]",
"323": "[表情:嫌弃]",
"324": "[表情:吃糖]",
"325": "[表情:惊吓]",
"326": "[表情:生气]",
"332": "[表情:举牌牌]",
"333": "[表情:烟花]",
"334": "[表情:虎虎生威]",
"336": "[表情:豹富]",
"337": "[表情:花朵脸]",
"338": "[表情:我想开了]",
"339": "[表情:舔屏]",
"341": "[表情:打招呼]",
"342": "[表情酸Q]",
"343": "[表情:我方了]",
"344": "[表情:大怨种]",
"345": "[表情:红包多多]",
"346": "[表情:你真棒棒]",
"347": "[表情:大展宏兔]",
"349": "[表情:坚强]",
"350": "[表情:贴贴]",
"351": "[表情:敲敲]",
"352": "[表情:咦]",
"353": "[表情:拜托]",
"354": "[表情:尊嘟假嘟]",
"355": "[表情:耶]",
"356": "[表情666]",
"357": "[表情:裂开]",
"392": "[表情:龙年 快乐]",
"393": "[表情:新年中龙]",
"394": "[表情:新年大龙]",
"395": "[表情:略略略]",
"😊": "[表情:嘿嘿]",
"😌": "[表情:羞涩]",
"😚": "[ 表情:亲亲]",
"😓": "[表情:汗]",
"😰": "[表情:紧张]",
"😝": "[表情:吐舌]",
"😁": "[表情:呲牙]",
"😜": "[表情:淘气]",
"": "[表情:可爱]",
"😍": "[表情:花痴]",
"😔": "[表情:失落]",
"😄": "[表情:高兴]",
"😏": "[表情:哼哼]",
"😒": "[表情:不屑]",
"😳": "[表情:瞪眼]",
"😘": "[表情:飞吻]",
"😭": "[表情:大哭]",
"😱": "[表情:害怕]",
"😂": "[表情:激动]",
"💪": "[表情:肌肉]",
"👊": "[表情:拳头]",
"👍": "[表情 :厉害]",
"👏": "[表情:鼓掌]",
"👎": "[表情:鄙视]",
"🙏": "[表情:合十]",
"👌": "[表情:好的]",
"👆": "[表情:向上]",
"👀": "[表情:眼睛]",
"🍜": "[表情:拉面]",
"🍧": "[表情:刨冰]",
"🍞": "[表情:面包]",
"🍺": "[表情:啤酒]",
"🍻": "[表情:干杯]",
"": "[表情:咖啡]",
"🍎": "[表情:苹果]",
"🍓": "[表情:草莓]",
"🍉": "[表情:西瓜]",
"🚬": "[表情:吸烟]",
"🌹": "[表情:玫瑰]",
"🎉": "[表情:庆祝]",
"💝": "[表情:礼物]",
"💣": "[表情:炸弹]",
"": "[表情:闪光]",
"💨": "[表情:吹气]",
"💦": "[表情:水]",
"🔥": "[表情:火]",
"💤": "[表情:睡觉]",
"💩": "[表情:便便]",
"💉": "[表情:打针]",
"📫": "[表情:邮箱]",
"🐎": "[表情:骑马]",
"👧": "[表情:女孩]",
"👦": "[表情:男孩]",
"🐵": "[表情:猴]",
"🐷": "[表情:猪]",
"🐮": "[表情:牛]",
"🐔": "[表情:公鸡]",
"🐸": "[表情:青蛙]",
"👻": "[表情:幽灵]",
"🐛": "[表情:虫]",
"🐶": "[表情:狗]",
"🐳": "[表情:鲸鱼]",
"👢": "[表情:靴子]",
"": "[表情:晴天]",
"": "[表情:问号]",
"🔫": "[表情:手枪]",
"💓": "[表情:爱 心]",
"🏪": "[表情:便利店]",
}
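# 用法示例message_handler 中以此将Napcat的表情数据映射为可读文本未收录的ID返回None
# 由调用方记录警告后跳过:
# >>> qq_face.get("14")
# '[表情:微笑]'
# >>> qq_face.get("99999") is None
# True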

View File

@@ -0,0 +1,44 @@
import asyncio
import time
from typing import Dict
from .config import global_config
from .logger import logger
response_dict: Dict = {}
response_time_dict: Dict = {}
async def get_response(request_id: str, timeout: int = 10) -> dict:
response = await asyncio.wait_for(_get_response(request_id), timeout)
    response_time_dict.pop(request_id, None)  # 记录可能已被超时清理提供默认值避免KeyError
logger.trace(f"响应信息id: {request_id} 已从响应字典中取出")
return response
async def _get_response(request_id: str) -> dict:
"""
内部使用的获取响应函数,主要用于在需要时获取响应
"""
while request_id not in response_dict:
await asyncio.sleep(0.2)
return response_dict.pop(request_id)
async def put_response(response: dict):
echo_id = response.get("echo")
now_time = time.time()
response_dict[echo_id] = response
response_time_dict[echo_id] = now_time
logger.trace(f"响应信息id: {echo_id} 已存入响应字典")
async def check_timeout_response() -> None:
while True:
cleaned_message_count: int = 0
now_time = time.time()
for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
cleaned_message_count += 1
response_dict.pop(echo_id)
response_time_dict.pop(echo_id)
logger.warning(f"响应消息 {echo_id} 超时,已删除")
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
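# 用法示例(发送带 echo 的请求后,凭同一 id 取回响应ws 为已建立的 Napcat 连接,仅作示意):
# import json, uuid
# request_uuid = str(uuid.uuid4())
# await ws.send(json.dumps({"action": "get_login_info", "echo": request_uuid}))
# response = await get_response(request_uuid)  # 默认10秒超时超时抛出 TimeoutError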

View File

@@ -0,0 +1,461 @@
import json
import websockets as Server
import uuid
from maim_message import (
UserInfo,
GroupInfo,
Seg,
BaseMessageInfo,
MessageBase,
)
from typing import Dict, Any, Tuple
from . import CommandType
from .config import global_config
from .response_pool import get_response
from .logger import logger
from .utils import get_image_format, convert_image_to_gif
from .recv_handler.message_sending import message_send_instance
class SendHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
self.server_connection = server_connection
async def handle_message(self, raw_message_base_dict: dict) -> None:
raw_message_base: MessageBase = MessageBase.from_dict(raw_message_base_dict)
message_segment: Seg = raw_message_base.message_segment
logger.info("接收到来自MaiBot的消息处理中")
if message_segment.type == "command":
return await self.send_command(raw_message_base)
else:
return await self.send_normal_message(raw_message_base)
async def send_normal_message(self, raw_message_base: MessageBase) -> None:
"""
处理普通消息发送
"""
logger.info("处理普通信息中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
group_info: GroupInfo = message_info.group_info
user_info: UserInfo = message_info.user_info
target_id: int = None
action: str = None
id_name: str = None
processed_message: list = []
try:
processed_message = await self.handle_seg_recursive(message_segment)
except Exception as e:
logger.error(f"处理消息时发生错误: {e}")
return
if not processed_message:
logger.critical("现在暂时不支持解析此回复!")
return None
if group_info and user_info:
logger.debug("发送群聊消息")
target_id = group_info.group_id
action = "send_group_msg"
id_name = "group_id"
elif user_info:
logger.debug("发送私聊消息")
target_id = user_info.user_id
action = "send_private_msg"
id_name = "user_id"
else:
logger.error("无法识别的消息类型")
return
logger.info("尝试发送到napcat")
response = await self.send_message_to_napcat(
action,
{
id_name: target_id,
"message": processed_message,
},
)
if response.get("status") == "ok":
logger.info("消息发送成功")
qq_message_id = response.get("data", {}).get("message_id")
await self.message_sent_back(raw_message_base, qq_message_id)
else:
logger.warning(f"消息发送失败napcat返回{str(response)}")
async def send_command(self, raw_message_base: MessageBase) -> None:
"""
处理命令类
"""
logger.info("处理命令中")
message_info: BaseMessageInfo = raw_message_base.message_info
message_segment: Seg = raw_message_base.message_segment
group_info: GroupInfo = message_info.group_info
seg_data: Dict[str, Any] = message_segment.data
command_name: str = seg_data.get("name")
try:
match command_name:
case CommandType.GROUP_BAN.name:
command, args_dict = self.handle_ban_command(seg_data.get("args"), group_info)
case CommandType.GROUP_WHOLE_BAN.name:
command, args_dict = self.handle_whole_ban_command(seg_data.get("args"), group_info)
case CommandType.GROUP_KICK.name:
command, args_dict = self.handle_kick_command(seg_data.get("args"), group_info)
case CommandType.SEND_POKE.name:
command, args_dict = self.handle_poke_command(seg_data.get("args"), group_info)
case CommandType.DELETE_MSG.name:
command, args_dict = self.delete_msg_command(seg_data.get("args"))
case CommandType.AI_VOICE_SEND.name:
command, args_dict = self.handle_ai_voice_send_command(seg_data.get("args"), group_info)
case _:
logger.error(f"未知命令: {command_name}")
return
except Exception as e:
logger.error(f"处理命令时发生错误: {e}")
return None
if not command or not args_dict:
logger.error("命令或参数缺失")
return None
response = await self.send_message_to_napcat(command, args_dict)
if response.get("status") == "ok":
logger.info(f"命令 {command_name} 执行成功")
else:
logger.warning(f"命令 {command_name} 执行失败napcat返回{str(response)}")
def get_level(self, seg_data: Seg) -> int:
if seg_data.type == "seglist":
            return 1 + max((self.get_level(seg) for seg in seg_data.data), default=0)  # default避免空seglist时max()报错
else:
return 1
async def handle_seg_recursive(self, seg_data: Seg) -> list:
payload: list = []
if seg_data.type == "seglist":
# level = self.get_level(seg_data) # 给以后可能的多层嵌套做准备,此处不使用
if not seg_data.data:
return []
for seg in seg_data.data:
payload = self.process_message_by_type(seg, payload)
else:
payload = self.process_message_by_type(seg_data, payload)
return payload
def process_message_by_type(self, seg: Seg, payload: list) -> list:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
new_payload = payload
if seg.type == "reply":
target_id = seg.data
if target_id == "notice":
return payload
new_payload = self.build_payload(payload, self.handle_reply_message(target_id), True)
elif seg.type == "text":
text = seg.data
if not text:
return payload
new_payload = self.build_payload(payload, self.handle_text_message(text), False)
elif seg.type == "face":
logger.warning("MaiBot 发送了qq原生表情暂时不支持")
elif seg.type == "image":
image = seg.data
new_payload = self.build_payload(payload, self.handle_image_message(image), False)
elif seg.type == "emoji":
emoji = seg.data
new_payload = self.build_payload(payload, self.handle_emoji_message(emoji), False)
elif seg.type == "voice":
voice = seg.data
new_payload = self.build_payload(payload, self.handle_voice_message(voice), False)
elif seg.type == "voiceurl":
voice_url = seg.data
new_payload = self.build_payload(payload, self.handle_voiceurl_message(voice_url), False)
elif seg.type == "music":
song_id = seg.data
new_payload = self.build_payload(payload, self.handle_music_message(song_id), False)
elif seg.type == "videourl":
video_url = seg.data
new_payload = self.build_payload(payload, self.handle_videourl_message(video_url), False)
elif seg.type == "file":
file_path = seg.data
new_payload = self.build_payload(payload, self.handle_file_message(file_path), False)
return new_payload
def build_payload(self, payload: list, addon: dict, is_reply: bool = False) -> list:
# sourcery skip: for-append-to-extend, merge-list-append, simplify-generator
"""构建发送的消息体"""
if is_reply:
temp_list = []
temp_list.append(addon)
for i in payload:
if i.get("type") == "reply":
logger.debug("检测到多个回复,使用最新的回复")
continue
temp_list.append(i)
return temp_list
else:
payload.append(addon)
return payload
    def handle_reply_message(self, message_id: str) -> dict:
        """处理回复消息"""
        return {"type": "reply", "data": {"id": message_id}}
def handle_text_message(self, message: str) -> dict:
"""处理文本消息"""
return {"type": "text", "data": {"text": message}}
def handle_image_message(self, encoded_image: str) -> dict:
"""处理图片消息"""
return {
"type": "image",
"data": {
"file": f"base64://{encoded_image}",
"subtype": 0,
},
} # base64 编码的图片
def handle_emoji_message(self, encoded_emoji: str) -> dict:
"""处理表情消息"""
encoded_image = encoded_emoji
image_format = get_image_format(encoded_emoji)
if image_format != "gif":
encoded_image = convert_image_to_gif(encoded_emoji)
return {
"type": "image",
"data": {
"file": f"base64://{encoded_image}",
"subtype": 1,
"summary": "[动画表情]",
},
}
def handle_voice_message(self, encoded_voice: str) -> dict:
"""处理语音消息"""
if not global_config.voice.use_tts:
logger.warning("未启用语音消息处理")
return {}
if not encoded_voice:
return {}
return {
"type": "record",
"data": {"file": f"base64://{encoded_voice}"},
}
def handle_voiceurl_message(self, voice_url: str) -> dict:
"""处理语音链接消息"""
return {
"type": "record",
"data": {"file": voice_url},
}
def handle_music_message(self, song_id: str) -> dict:
"""处理音乐消息"""
return {
"type": "music",
"data": {"type": "163", "id": song_id},
}
def handle_videourl_message(self, video_url: str) -> dict:
"""处理视频链接消息"""
return {
"type": "video",
"data": {"file": video_url},
}
def handle_file_message(self, file_path: str) -> dict:
"""处理文件消息"""
return {
"type": "file",
"data": {"file": f"file://{file_path}"},
}
def handle_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理封禁命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
            Tuple[str, Dict[str, Any]]: NapCat 动作名与参数字典
"""
duration: int = int(args["duration"])
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if duration < 0:
raise ValueError("封禁时间必须大于等于0")
if not user_id or not group_id:
raise ValueError("封禁命令缺少必要参数")
if duration > 2592000:
raise ValueError("封禁时间不能超过30天")
return (
CommandType.GROUP_BAN.value,
{
"group_id": group_id,
"user_id": user_id,
"duration": duration,
},
)
def handle_whole_ban_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理全体禁言命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
            Tuple[str, Dict[str, Any]]: NapCat 动作名与参数字典
"""
enable = args["enable"]
assert isinstance(enable, bool), "enable参数必须是布尔值"
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
return (
CommandType.GROUP_WHOLE_BAN.value,
{
"group_id": group_id,
"enable": enable,
},
)
def handle_kick_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理群成员踢出命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
            Tuple[str, Dict[str, Any]]: NapCat 动作名与参数字典
"""
user_id: int = int(args["qq_id"])
group_id: int = int(group_info.group_id)
if group_id <= 0:
raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.GROUP_KICK.value,
{
"group_id": group_id,
"user_id": user_id,
"reject_add_request": False, # 不拒绝加群请求
},
)
def handle_poke_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""处理戳一戳命令
Args:
args (Dict[str, Any]): 参数字典
group_info (GroupInfo): 群聊信息(对应目标群聊)
Returns:
            Tuple[str, Dict[str, Any]]: NapCat 动作名与参数字典
"""
user_id: int = int(args["qq_id"])
        if group_info is None:
            group_id = None
        else:
            group_id = int(group_info.group_id)
            if group_id <= 0:
                raise ValueError("群组ID无效")
if user_id <= 0:
raise ValueError("用户ID无效")
return (
CommandType.SEND_POKE.value,
{
"group_id": group_id,
"user_id": user_id,
},
)
def delete_msg_command(self, args: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""处理撤回消息命令
Args:
args (Dict[str, Any]): 参数字典
Returns:
            Tuple[str, Dict[str, Any]]: NapCat 动作名与参数字典
"""
try:
message_id = int(args["message_id"])
if message_id <= 0:
raise ValueError("消息ID无效")
except KeyError:
raise ValueError("缺少必需参数: message_id") from None
except (ValueError, TypeError) as e:
raise ValueError(f"消息ID无效: {args['message_id']} - {str(e)}") from None
return (
CommandType.DELETE_MSG.value,
{
"message_id": message_id,
},
)
def handle_ai_voice_send_command(self, args: Dict[str, Any], group_info: GroupInfo) -> Tuple[str, Dict[str, Any]]:
"""
处理AI语音发送命令的逻辑。
并返回 NapCat 兼容的 (action, params) 元组。
"""
if not group_info or not group_info.group_id:
raise ValueError("AI语音发送命令必须在群聊上下文中使用")
if not args:
raise ValueError("AI语音发送命令缺少参数")
group_id: int = int(group_info.group_id)
character_id = args.get("character")
text_content = args.get("text")
if not character_id or not text_content:
raise ValueError(f"AI语音发送命令参数不完整: character='{character_id}', text='{text_content}'")
return (
CommandType.AI_VOICE_SEND.value,
{
"group_id": group_id,
"text": text_content,
"character": character_id,
},
)
async def send_message_to_napcat(self, action: str, params: dict) -> dict:
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
await self.server_connection.send(payload)
try:
response = await get_response(request_uuid)
except TimeoutError:
logger.error("发送消息超时,未收到响应")
return {"status": "error", "message": "timeout"}
except Exception as e:
logger.error(f"发送消息失败: {e}")
return {"status": "error", "message": str(e)}
return response
async def message_sent_back(self, message_base: MessageBase, qq_message_id: str) -> None:
# 修改 additional_config添加 echo 字段
if message_base.message_info.additional_config is None:
message_base.message_info.additional_config = {}
message_base.message_info.additional_config["echo"] = True
# 获取原始的 mmc_message_id
mmc_message_id = message_base.message_info.message_id
# 修改 message_segment 为 notify 类型
message_base.message_segment = Seg(
type="notify", data={"sub_type": "echo", "echo": mmc_message_id, "actual_id": qq_message_id}
)
await message_send_instance.message_send(message_base)
logger.debug("已回送消息ID")
return
send_handler = SendHandler()
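
下面给出一个假设性的使用示意(并非适配器源码):演示 handle_seg_recursive 如何把一个 seglist 扁平化为 NapCat 消息段列表,其中 reply 段会被 build_payload 固定置顶。Seg 的构造方式参照上文 message_sent_back 中的用法,消息 ID 为虚构示例。

import asyncio

async def _demo():
    seg = Seg(type="seglist", data=[
        Seg(type="text", data="你好"),
        Seg(type="reply", data="123456"),  # 虚构的被引用消息ID
    ])
    payload = await send_handler.handle_seg_recursive(seg)
    # 预期结果:reply 段被置顶
    # [{"type": "reply", "data": {"id": "123456"}},
    #  {"type": "text", "data": {"text": "你好"}}]
    print(payload)

asyncio.run(_demo())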


@@ -0,0 +1,310 @@
import websockets as Server
import json
import base64
import uuid
import urllib3
import ssl
import io
from src.database import BanUser, db_manager
from .logger import logger
from .response_pool import get_response
from PIL import Image
from typing import Union, List, Tuple, Optional
class SSLAdapter(urllib3.PoolManager):
    """自定义连接池:放宽 OpenSSL 安全等级以兼容部分旧服务器,并强制 TLS 1.2 及以上"""

    def __init__(self, *args, **kwargs):
context = ssl.create_default_context()
context.set_ciphers("DEFAULT@SECLEVEL=1")
context.minimum_version = ssl.TLSVersion.TLSv1_2
kwargs["ssl_context"] = context
super().__init__(*args, **kwargs)
async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
"""
获取群相关信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群聊信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取群信息超时,群号: {group_id}")
return None
except Exception as e:
logger.error(f"获取群信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_group_detail_info(websocket: Server.ServerConnection, group_id: int) -> dict | None:
"""
获取群详细信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群详细信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取群详细信息超时,群号: {group_id}")
return None
except Exception as e:
logger.error(f"获取群详细信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_member_info(websocket: Server.ServerConnection, group_id: int, user_id: int) -> dict | None:
"""
获取群成员信息
返回值需要处理可能为空的情况
"""
logger.debug("获取群成员信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
{
"action": "get_group_member_info",
"params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
"echo": request_uuid,
}
)
try:
await websocket.send(payload)
socket_response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取成员信息超时,群号: {group_id}, 用户ID: {user_id}")
return None
except Exception as e:
logger.error(f"获取成员信息失败: {e}")
return None
logger.debug(socket_response)
return socket_response.get("data")
async def get_image_base64(url: str) -> str:
# sourcery skip: raise-specific-error
"""获取图片/表情包的Base64"""
logger.debug(f"下载图片: {url}")
http = SSLAdapter()
try:
response = http.request("GET", url, timeout=10)
if response.status != 200:
raise Exception(f"HTTP Error: {response.status}")
image_bytes = response.data
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
logger.error(f"图片下载失败: {str(e)}")
raise
def convert_image_to_gif(image_base64: str) -> str:
# sourcery skip: extract-method
"""
将Base64编码的图片转换为GIF格式
Parameters:
image_base64: str: Base64编码的图片数据
Returns:
str: Base64编码的GIF图片数据
"""
logger.debug("转换图片为GIF格式")
try:
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
output_buffer = io.BytesIO()
image.save(output_buffer, format="GIF")
output_buffer.seek(0)
return base64.b64encode(output_buffer.read()).decode("utf-8")
except Exception as e:
logger.error(f"图片转换为GIF失败: {str(e)}")
return image_base64
async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
"""
获取自身信息
Parameters:
websocket: WebSocket连接对象
Returns:
data: dict: 返回的自身信息
"""
logger.debug("获取自身信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error("获取自身信息超时")
return None
except Exception as e:
logger.error(f"获取自身信息失败: {e}")
return None
logger.debug(response)
return response.get("data")
def get_image_format(raw_data: str) -> str:
"""
从Base64编码的数据中确定图片的格式。
Parameters:
raw_data: str: Base64编码的图片数据。
Returns:
format: str: 图片的格式(例如 'jpeg', 'png', 'gif')。
"""
    image_bytes = base64.b64decode(raw_data)
    image_format = Image.open(io.BytesIO(image_bytes)).format
    return image_format.lower() if image_format else "unknown"
async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) -> dict | None:
"""
获取陌生人信息
Parameters:
websocket: WebSocket连接对象
user_id: 用户ID
Returns:
dict: 返回的陌生人信息
"""
logger.debug("获取陌生人信息中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid)
except TimeoutError:
logger.error(f"获取陌生人信息超时用户ID: {user_id}")
return None
except Exception as e:
logger.error(f"获取陌生人信息失败: {e}")
return None
logger.debug(response)
return response.get("data")
async def get_message_detail(websocket: Server.ServerConnection, message_id: Union[str, int]) -> dict | None:
"""
获取消息详情,可能为空
Parameters:
websocket: WebSocket连接对象
message_id: 消息ID
Returns:
dict: 返回的消息详情
"""
logger.debug("获取消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
except TimeoutError:
logger.error(f"获取消息详情超时消息ID: {message_id}")
return None
except Exception as e:
logger.error(f"获取消息详情失败: {e}")
return None
logger.debug(response)
return response.get("data")
async def get_record_detail(
websocket: Server.ServerConnection, file: str, file_id: Optional[str] = None
) -> dict | None:
"""
获取语音消息内容
Parameters:
websocket: WebSocket连接对象
file: 文件名
file_id: 文件ID
Returns:
dict: 返回的语音消息详情
"""
logger.debug("获取语音消息详情中")
request_uuid = str(uuid.uuid4())
payload = json.dumps(
{
"action": "get_record",
"params": {"file": file, "file_id": file_id, "out_format": "wav"},
"echo": request_uuid,
}
)
try:
await websocket.send(payload)
response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
except TimeoutError:
logger.error(f"获取语音消息详情超时,文件: {file}, 文件ID: {file_id}")
return None
except Exception as e:
logger.error(f"获取语音消息详情失败: {e}")
return None
logger.debug(f"{str(response)[:200]}...") # 防止语音的超长base64编码导致日志过长
return response.get("data")
async def read_ban_list(
websocket: Server.ServerConnection,
) -> Tuple[List[BanUser], List[BanUser]]:
"""
    从数据库读取禁言记录,并根据群内当前状态自动清理已经失效的禁言。
    Returns:
        Tuple[
            仍然生效的禁言记录(含全体禁言,user_id 为 0)的BanUser列表,
            已经自然解除的禁言记录的BanUser列表,
        ]
"""
try:
ban_list = db_manager.get_ban_records()
lifted_list: List[BanUser] = []
logger.info("已经读取禁言列表")
        for ban_record in list(ban_list):  # 遍历副本,以便在循环中安全移除元素
if ban_record.user_id == 0:
fetched_group_info = await get_group_info(websocket, ban_record.group_id)
if fetched_group_info is None:
logger.warning(f"无法获取群信息,群号: {ban_record.group_id},默认禁言解除")
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
                group_all_shut: int = fetched_group_info.get("group_all_shut", 0)
if group_all_shut == 0:
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
else:
fetched_member_info = await get_member_info(websocket, ban_record.group_id, ban_record.user_id)
if fetched_member_info is None:
logger.warning(
f"无法获取群成员信息用户ID: {ban_record.user_id}, 群号: {ban_record.group_id},默认禁言解除"
)
lifted_list.append(ban_record)
ban_list.remove(ban_record)
continue
                lift_ban_time: int = fetched_member_info.get("shut_up_timestamp", 0)
if lift_ban_time == 0:
lifted_list.append(ban_record)
ban_list.remove(ban_record)
else:
ban_record.lift_time = lift_ban_time
db_manager.update_ban_record(ban_list)
return ban_list, lifted_list
except Exception as e:
logger.error(f"读取禁言列表失败: {e}")
return [], []
def save_ban_record(ban_list: List[BanUser]):
    return db_manager.update_ban_record(ban_list)
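
上面的各个查询函数都遵循同一种请求/响应关联模式:每个发出的帧都携带一个新生成的 echo UUID,get_response 等待带有相同 echo 的回包。下面是一个基于 asyncio Future 的最小示意,仅作说明:真实实现在 response_pool.py 中,可能与此不同;并假设运行在 Python 3.11+(此时 asyncio 的超时异常就是内置 TimeoutError)。

import asyncio
from typing import Dict

_pending: Dict[str, asyncio.Future] = {}

async def get_response_sketch(request_uuid: str, timeout: float = 10.0) -> dict:
    """等待携带相同 echo 的回包,超时抛出 TimeoutError"""
    fut = asyncio.get_running_loop().create_future()
    _pending[request_uuid] = fut
    try:
        return await asyncio.wait_for(fut, timeout)
    finally:
        _pending.pop(request_uuid, None)

def dispatch_response_sketch(frame: dict) -> None:
    """收到 NapCat 回包时调用:按 echo 字段唤醒等待中的请求"""
    fut = _pending.get(frame.get("echo", ""))
    if fut is not None and not fut.done():
        fut.set_result(frame)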


@@ -0,0 +1,34 @@
[inner]
version = "0.1.1" # 版本号
# 请勿修改版本号,除非你知道自己在做什么
[nickname] # 现在没用
nickname = ""
[napcat_server] # Napcat连接的ws服务设置
host = "localhost" # Napcat设定的主机地址
port = 8095 # Napcat设定的端口
heartbeat_interval = 30 # 与Napcat设置的心跳相同按秒计
[maibot_server] # 连接麦麦的ws服务设置
host = "localhost" # 麦麦在.env文件中设置的主机地址即HOST字段
port = 8000 # 麦麦在.env文件中设置的端口即PORT字段
[chat] # 黑白名单功能
group_list_type = "whitelist" # 群组名单类型可选为whitelist, blacklist
group_list = [] # 群组名单
# 当group_list_type为whitelist时只有群组名单中的群组可以聊天
# 当group_list_type为blacklist时群组名单中的任何群组无法聊天
private_list_type = "whitelist" # 私聊名单类型可选为whitelist, blacklist
private_list = [] # 私聊名单
# 当private_list_type为whitelist时只有私聊名单中的用户可以聊天
# 当private_list_type为blacklist时私聊名单中的任何用户无法聊天
ban_user_id = [] # 全局禁止名单(全局禁止名单中的用户无法进行任何聊天)
ban_qq_bot = false # 是否屏蔽QQ官方机器人
enable_poke = true # 是否启用戳一戳功能
[voice] # 发送语音设置
use_tts = false # 是否使用tts语音请确保你配置了tts并有对应的adapter
[debug]
level = "INFO" # 日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL

identifier.sqlite (new empty file)

@@ -8,9 +8,6 @@ from src.plugin_system import (
ComponentInfo,
ActionActivationType,
ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
ToolParamType
)
@@ -126,22 +123,23 @@ class TimeCommand(BaseCommand):
message = f"⏰ 当前时间:{time_str}"
await self.send_text(message)
return True, f"显示了当前时间: {time_str}", True
return True, f"显示了当前x时间: {time_str}", True
class PrintMessage(BaseEventHandler):
"""打印消息事件处理器 - 处理打印消息事件"""
event_type = EventType.ON_MESSAGE
handler_name = "print_message_handler"
handler_description = "打印接收到的消息"
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
"""执行打印消息事件处理"""
# 打印接收到的消息
if self.get_config("print_message.enabled", False):
print(f"接收到消息: {message.raw_message}")
return True, True, "消息已打印"
# class PrintMessage(BaseEventHandler):
# """打印消息事件处理器 - 处理打印消息事件"""
#
# event_type = EventType.ON_MESSAGE
# handler_name = "print_message_handler"
# handler_description = "打印接收到的消息"
#
# async def execute(self, message: MaiMessages) -> Tuple[bool, bool, str | None]:
# """执行打印消息事件处理"""
# # 打印接收到的消息
#
# if self.get_config("print_message.enabled", False):
# print(f"接收到消息: {message.raw_message}")
# return True, True, "消息已打印1"
# ===== 插件注册 =====
@@ -184,7 +182,7 @@ class HelloWorldPlugin(BasePlugin):
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
# (PrintMessage.get_handler_info(), PrintMessage),
]


@@ -5,6 +5,7 @@ description = "MaiCore 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.10"
dependencies = [
"aiohttp>=3.12.14",
"aiohttp-cors>=0.8.1",
"apscheduler>=3.11.0",
"colorama>=0.4.6",
"cryptography>=45.0.5",
@@ -12,6 +13,8 @@ dependencies = [
"dotenv>=0.9.9",
"faiss-cpu>=1.11.0",
"fastapi>=0.116.0",
"google>=3.0.0",
"google-genai>=1.29.0",
"jieba>=0.42.1",
"json-repair>=0.47.6",
"jsonlines>=4.0.0",
@@ -22,12 +25,12 @@ dependencies = [
"openai>=1.95.0",
"packaging>=25.0",
"pandas>=2.3.1",
"peewee>=3.18.2",
"pillow>=11.3.0",
"psutil>=7.0.0",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
"pymongo>=4.13.2",
"pymysql>=1.1.1",
"pypinyin>=0.54.0",
"python-dateutil>=2.9.0.post0",
"python-dotenv>=1.1.1",
@@ -41,6 +44,7 @@ dependencies = [
"scipy>=1.15.3",
"seaborn>=0.13.2",
"setuptools>=80.9.0",
"sqlalchemy>=2.0.42",
"strawberry-graphql[fastapi]>=0.275.5",
"structlog>=25.4.0",
"toml>=0.10.2",
@@ -50,9 +54,13 @@ dependencies = [
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"watchdog>=6.0.0",
"websockets>=15.0.1",
]
[[tool.uv.index]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
default = true
[tool.ruff]
@@ -96,3 +104,8 @@ skip-magic-trailing-comma = false
# 自动检测合适的换行符
line-ending = "auto"
[dependency-groups]
lint = [
"loguru>=0.7.3",
]


@@ -1,3 +1,4 @@
sqlalchemy
APScheduler
Pillow
aiohttp
@@ -47,3 +48,4 @@ reportportal-client
scikit-learn
seaborn
structlog
watchdog


@@ -1,208 +0,0 @@
import time
import sys
import os
from typing import Dict, List
# Add project root to Python path before importing from src
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression, ChatStreams  # noqa: E402
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
# 直接从数据库查询ChatStreams表
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
# 如果有群组信息,显示群组名称
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
# 如果是私聊,显示用户昵称
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def calculate_time_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
'0-1天': 0,
'1-3天': 0,
'3-7天': 0,
'7-14天': 0,
'14-30天': 0,
'30-60天': 0,
'60-90天': 0,
'90+天': 0
}
for expr in expressions:
diff_days = (now - expr.last_active_time) / (24*3600)
if diff_days < 1:
distribution['0-1天'] += 1
elif diff_days < 3:
distribution['1-3天'] += 1
elif diff_days < 7:
distribution['3-7天'] += 1
elif diff_days < 14:
distribution['7-14天'] += 1
elif diff_days < 30:
distribution['14-30天'] += 1
elif diff_days < 60:
distribution['30-60天'] += 1
elif diff_days < 90:
distribution['60-90天'] += 1
else:
distribution['90+天'] += 1
return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of count values"""
distribution = {
'0-1': 0,
'1-2': 0,
'2-3': 0,
'3-4': 0,
'4-5': 0,
'5-10': 0,
'10+': 0
}
for expr in expressions:
cnt = expr.count
if cnt < 1:
distribution['0-1'] += 1
elif cnt < 2:
distribution['1-2'] += 1
elif cnt < 3:
distribution['2-3'] += 1
elif cnt < 4:
distribution['3-4'] += 1
elif cnt < 5:
distribution['4-5'] += 1
elif cnt < 10:
distribution['5-10'] += 1
else:
distribution['10+'] += 1
return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
"""Get top N most used expressions for a specific chat_id"""
return (Expression.select()
.where(Expression.chat_id == chat_id)
.order_by(Expression.count.desc())
.limit(top_n))
def show_overall_statistics(expressions, total: int) -> None:
"""Show overall statistics"""
time_dist = calculate_time_distribution(expressions)
count_dist = calculate_count_distribution(expressions)
print("\n=== 总体统计 ===")
print(f"总表达式数量: {total}")
print("\n上次激活时间分布:")
for period, count in time_dist.items():
print(f"{period}: {count} ({count/total*100:.2f}%)")
print("\ncount分布:")
for range_, count in count_dist.items():
print(f"{range_}: {count} ({count/total*100:.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
"""Show statistics for a specific chat"""
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
chat_total = len(chat_exprs)
print(f"\n=== {chat_name} ===")
print(f"表达式数量: {chat_total}")
if chat_total == 0:
print("该聊天没有表达式数据")
return
# Time distribution for this chat
time_dist = calculate_time_distribution(chat_exprs)
print("\n上次激活时间分布:")
for period, count in time_dist.items():
if count > 0:
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
# Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:")
for range_, count in count_dist.items():
if count > 0:
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
# Top expressions
print("\nTop 10使用最多的表达式:")
top_exprs = get_top_expressions_by_chat(chat_id, 10)
for i, expr in enumerate(top_exprs, 1):
print(f"{i}. [{expr.type}] Count: {expr.count}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print()
def interactive_menu() -> None:
"""Interactive menu for expression statistics"""
# Get all expressions
expressions = list(Expression.select())
if not expressions:
print("数据库中没有找到表达式")
return
total = len(expressions)
# Get unique chat_ids and their names
chat_ids = list(set(expr.chat_id for expr in expressions))
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True:
print("\n" + "="*50)
print("表达式统计分析")
print("="*50)
print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
print(f"{i}. {chat_name} ({chat_count}个表达式)")
print("q. 退出")
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == 'q':
print("再见!")
break
try:
choice_num = int(choice)
if choice_num == 0:
show_overall_statistics(expressions, total)
elif 1 <= choice_num <= len(chat_info):
chat_id, chat_name = chat_info[choice_num - 1]
show_chat_statistics(chat_id, chat_name)
else:
print("无效的选择,请重新输入")
except ValueError:
print("请输入有效的数字")
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()


@@ -1,217 +0,0 @@
import json
import os
import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import datetime
# 添加项目根目录到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress  # 替换为 rich 进度条
from src.common.logger import get_logger
# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
from rich.progress import (
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("LPMM知识库-信息提取")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
if not os.path.exists(TEMP_DIR):
os.makedirs(TEMP_DIR)
logger.info(f"已创建临时目录: {TEMP_DIR}")
if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
if not os.path.exists(RAW_DATA_PATH):
os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
open_ie_doc_lock = Lock()
# 创建一个事件标志,用于控制程序终止
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
)
def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
# 使用文件锁检查和读取缓存文件
with file_lock:
if os.path.exists(temp_file_path):
try:
# 存在对应的提取结果
logger.info(f"找到缓存的提取结果:{pg_hash}")
with open(temp_file_path, "r", encoding="utf-8") as f:
return json.load(f), None
except json.JSONDecodeError:
# 如果JSON文件损坏删除它并重新处理
logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
os.remove(temp_file_path)
entity_list, rdf_triple_list = info_extract_from_str(
lpmm_entity_extract_llm,
lpmm_rdf_build_llm,
raw_data,
)
if entity_list is None or rdf_triple_list is None:
return None, pg_hash
doc_item = {
"idx": pg_hash,
"passage": raw_data,
"extracted_entities": entity_list,
"extracted_triples": rdf_triple_list,
}
# 保存临时提取结果
with file_lock:
try:
with open(temp_file_path, "w", encoding="utf-8") as f:
json.dump(doc_item, f, ensure_ascii=False, indent=4)
except Exception as e:
logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
# 如果保存失败,确保不会留下损坏的文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
            return None, pg_hash
return doc_item, None
def signal_handler(_signum, _frame):
"""处理Ctrl+C信号"""
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
sys.exit(0)
def main(): # sourcery skip: comprehension-to-generator, extract-method
# 设置信号处理器
signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # 确保目录存在
# 新增用户确认提示
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")
print("举例600万字全剧情提取选用deepseek v3 0324消耗约40元约3小时。")
print("建议使用硅基流动的非Pro模型")
print("或者使用可以用赠金抵扣的Pro模型")
print("请确保账户余额充足,并且在执行前确认无误。")
confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != "y":
logger.info("用户取消操作")
print("操作已取消")
sys.exit(1)
print("\n" + "=" * 40 + "\n")
logger.info("--------进行信息提取--------\n")
# 加载原始数据
logger.info("正在加载原始数据")
all_sha256_list, all_raw_datas = load_raw_data()
failed_sha256 = []
open_ie_doc = []
workers = global_config.lpmm_knowledge.info_extraction_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_hash = {
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
}
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("正在进行提取:", total=len(future_to_hash))
try:
for future in as_completed(future_to_hash):
if shutdown_event.is_set():
for f in future_to_hash:
if not f.done():
f.cancel()
break
doc_item, failed_hash = future.result()
if failed_hash:
failed_sha256.append(failed_hash)
logger.error(f"提取失败:{failed_hash}")
elif doc_item:
with open_ie_doc_lock:
open_ie_doc.append(doc_item)
progress.update(task, advance=1)
except KeyboardInterrupt:
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
shutdown_event.set()
for f in future_to_hash:
if not f.done():
f.cancel()
# 合并所有文件的提取结果并保存
if open_ie_doc:
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
)
# 输出文件名格式MM-DD-HH-ss-openie.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-openie.json")
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
f,
ensure_ascii=False,
indent=4,
)
logger.info(f"信息提取结果已保存到: {output_path}")
else:
logger.warning("没有可保存的信息提取结果")
logger.info("--------信息提取完成--------")
logger.info(f"提取失败的文段SHA256{failed_sha256}")
if __name__ == "__main__":
main()


@@ -1,287 +0,0 @@
import time
import sys
import os
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
"""Format timestamp to readable date string"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of interest_value"""
distribution = {
'0.000-0.010': 0,
'0.010-0.050': 0,
'0.050-0.100': 0,
'0.100-0.500': 0,
'0.500-1.000': 0,
'1.000-2.000': 0,
'2.000-5.000': 0,
'5.000-10.000': 0,
'10.000+': 0
}
for msg in messages:
if msg.interest_value is None or msg.interest_value == 0.0:
continue
value = float(msg.interest_value)
if value < 0.010:
distribution['0.000-0.010'] += 1
elif value < 0.050:
distribution['0.010-0.050'] += 1
elif value < 0.100:
distribution['0.050-0.100'] += 1
elif value < 0.500:
distribution['0.100-0.500'] += 1
elif value < 1.000:
distribution['0.500-1.000'] += 1
elif value < 2.000:
distribution['1.000-2.000'] += 1
elif value < 5.000:
distribution['2.000-5.000'] += 1
elif value < 10.000:
distribution['5.000-10.000'] += 1
else:
distribution['10.000+'] += 1
return distribution
def get_interest_value_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for interest_value"""
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
if not values:
return {
'count': 0,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
values.sort()
count = len(values)
return {
'count': count,
'min': min(values),
'max': max(values),
'avg': sum(values) / count,
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
}
def get_available_chats() -> List[Tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
).count()
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
except Exception as e:
print(f"获取聊天列表失败: {e}")
return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
elif choice == "2":
return now - 3*24*3600, now
elif choice == "3":
return now - 7*24*3600, now
elif choice == "4":
return now - 30*24*3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
return start_time, end_time
except ValueError:
print("时间格式错误,将不限制时间范围")
return None, None
else:
return None, None
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
"""Analyze interest values with optional filters"""
# 构建查询条件
query = Messages.select().where(
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
)
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_interest_value_distribution(messages)
stats = get_interest_value_stats(messages)
# 显示结果
print("\n=== Interest Value 分析结果 ===")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
print(f"时间范围: {format_timestamp(start_time)} 之后")
elif end_time:
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"有效消息数量: {stats['count']} (排除null和0值)")
print(f"最小值: {stats['min']:.3f}")
print(f"最大值: {stats['max']:.3f}")
print(f"平均值: {stats['avg']:.3f}")
print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:")
total = stats['count']
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name}: {count} ({percentage:.2f}%)")
def interactive_menu() -> None:
"""Interactive menu for interest value analysis"""
while True:
print("\n" + "="*50)
print("Interest Value 分析工具")
print("="*50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到有interest_value数据的聊天")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条有效消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_interest_values(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

File diff suppressed because it is too large

@@ -1,237 +0,0 @@
"""
插件Manifest管理命令行工具
提供插件manifest文件的创建、验证和管理功能
"""
import os
import sys
import argparse
import json
from pathlib import Path
from src.common.logger import get_logger
from src.plugin_system.utils.manifest_utils import (
ManifestValidator,
)
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
logger = get_logger("manifest_tool")
def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool:
"""创建最小化的manifest文件
Args:
plugin_dir: 插件目录
plugin_name: 插件名称
description: 插件描述
author: 插件作者
Returns:
bool: 是否创建成功
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if os.path.exists(manifest_path):
print(f"❌ Manifest文件已存在: {manifest_path}")
return False
# 创建最小化manifest
minimal_manifest = {
"manifest_version": 1,
"name": plugin_name,
"version": "1.0.0",
"description": description or f"{plugin_name}插件",
"author": {"name": author or "Unknown"},
}
try:
with open(manifest_path, "w", encoding="utf-8") as f:
json.dump(minimal_manifest, f, ensure_ascii=False, indent=2)
print(f"✅ 已创建最小化manifest文件: {manifest_path}")
return True
except Exception as e:
print(f"❌ 创建manifest文件失败: {e}")
return False
def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool:
"""创建完整的manifest模板文件
Args:
plugin_dir: 插件目录
plugin_name: 插件名称
Returns:
bool: 是否创建成功
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if os.path.exists(manifest_path):
print(f"❌ Manifest文件已存在: {manifest_path}")
return False
# 创建完整模板
complete_manifest = {
"manifest_version": 1,
"name": plugin_name,
"version": "1.0.0",
"description": f"{plugin_name}插件描述",
"author": {"name": "插件作者", "url": "https://github.com/your-username"},
"license": "MIT",
"host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
"homepage_url": "https://github.com/your-repo",
"repository_url": "https://github.com/your-repo",
"keywords": ["keyword1", "keyword2"],
"categories": ["Category1"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": False,
"plugin_type": "general",
"components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}],
},
}
try:
with open(manifest_path, "w", encoding="utf-8") as f:
json.dump(complete_manifest, f, ensure_ascii=False, indent=2)
print(f"✅ 已创建完整manifest模板: {manifest_path}")
print("💡 请根据实际情况修改manifest文件中的内容")
return True
except Exception as e:
print(f"❌ 创建manifest文件失败: {e}")
return False
def validate_manifest_file(plugin_dir: str) -> bool:
"""验证manifest文件
Args:
plugin_dir: 插件目录
Returns:
bool: 是否验证通过
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if not os.path.exists(manifest_path):
print(f"❌ 未找到manifest文件: {manifest_path}")
return False
try:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest_data = json.load(f)
validator = ManifestValidator()
is_valid = validator.validate_manifest(manifest_data)
# 显示验证结果
print("📋 Manifest验证结果:")
print(validator.get_validation_report())
if is_valid:
print("✅ Manifest文件验证通过")
else:
print("❌ Manifest文件验证失败")
return is_valid
except json.JSONDecodeError as e:
print(f"❌ Manifest文件格式错误: {e}")
return False
except Exception as e:
print(f"❌ 验证过程中发生错误: {e}")
return False
def scan_plugins_without_manifest(root_dir: str) -> None:
"""扫描缺少manifest文件的插件
Args:
root_dir: 扫描的根目录
"""
print(f"🔍 扫描目录: {root_dir}")
plugins_without_manifest = []
for root, dirs, files in os.walk(root_dir):
# 跳过隐藏目录和__pycache__
dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
# 检查是否包含plugin.py文件标识为插件目录
if "plugin.py" in files:
manifest_path = os.path.join(root, "_manifest.json")
if not os.path.exists(manifest_path):
plugins_without_manifest.append(root)
if plugins_without_manifest:
print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:")
for plugin_dir in plugins_without_manifest:
plugin_name = os.path.basename(plugin_dir)
print(f" - {plugin_name}: {plugin_dir}")
print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件")
else:
print("✅ 所有插件都有manifest文件")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="插件Manifest管理工具")
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 创建最小化manifest命令
create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件")
create_minimal_parser.add_argument("plugin_dir", help="插件目录路径")
create_minimal_parser.add_argument("--name", help="插件名称")
create_minimal_parser.add_argument("--description", help="插件描述")
create_minimal_parser.add_argument("--author", help="插件作者")
# 创建完整manifest命令
create_complete_parser = subparsers.add_parser("create-complete", help="创建完整manifest模板")
create_complete_parser.add_argument("plugin_dir", help="插件目录路径")
create_complete_parser.add_argument("--name", help="插件名称")
# 验证manifest命令
validate_parser = subparsers.add_parser("validate", help="验证manifest文件")
validate_parser.add_argument("plugin_dir", help="插件目录路径")
# 扫描插件命令
scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件")
scan_parser.add_argument("root_dir", help="扫描的根目录路径")
args = parser.parse_args()
if not args.command:
parser.print_help()
return
try:
if args.command == "create-minimal":
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "")
sys.exit(0 if success else 1)
elif args.command == "create-complete":
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
success = create_complete_manifest(args.plugin_dir, plugin_name)
sys.exit(0 if success else 1)
elif args.command == "validate":
success = validate_manifest_file(args.plugin_dir)
sys.exit(0 if success else 1)
elif args.command == "scan":
scan_plugins_without_manifest(args.root_dir)
except Exception as e:
print(f"❌ 执行命令时发生错误: {e}")
sys.exit(1)
if __name__ == "__main__":
main()


@@ -1,920 +0,0 @@
import os
import json
import sys # 新增系统模块导入
# import time
import pickle
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import Dict, Any, List, Optional, Type
from dataclasses import dataclass, field
from datetime import datetime
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from peewee import Model, Field, IntegrityError
# Rich 进度条和显示组件
from rich.console import Console
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
TimeElapsedColumn,
SpinnerColumn,
)
from rich.table import Table
from rich.panel import Panel
# from rich.text import Text
from src.common.database.database import db
from src.common.database.database_model import (
ChatStreams,
Emoji,
Messages,
Images,
ImageDescriptions,
PersonInfo,
Knowledges,
ThinkingLog,
GraphNodes,
GraphEdges,
)
from src.common.logger import get_logger
logger = get_logger("mongodb_to_sqlite")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
target_model: Type[Model]
field_mapping: Dict[str, str]
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
# 数据验证相关类已移除 - 用户要求不要数据验证
@dataclass
class MigrationCheckpoint:
"""迁移断点数据"""
collection_name: str
processed_count: int
last_processed_id: Any
timestamp: datetime
batch_errors: List[Dict[str, Any]] = field(default_factory=list)
@dataclass
class MigrationStats:
"""迁移统计信息"""
total_documents: int = 0
processed_count: int = 0
success_count: int = 0
error_count: int = 0
skipped_count: int = 0
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
errors: List[Dict[str, Any]] = field(default_factory=list)
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
)
self.error_count += 1
def add_validation_error(self, doc_id: Any, field: str, error: str):
"""添加验证错误"""
self.add_error(doc_id, f"验证失败 - {field}: {error}")
self.validation_errors += 1
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
self.mongo_client: Optional[MongoClient] = None
self.mongo_db = None
# 迁移配置
self.migration_configs = self._initialize_migration_configs()
# 进度条控制台
self.console = Console()
# 检查点目录
self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
self.checkpoint_dir.mkdir(exist_ok=True)
# 验证规则已禁用
self.validation_rules = self._initialize_validation_rules()
def _build_mongo_uri(self) -> str:
"""构建MongoDB连接URI"""
if mongo_uri := os.getenv("MONGODB_URI"):
return mongo_uri
user = os.getenv("MONGODB_USER")
password = os.getenv("MONGODB_PASS")
host = os.getenv("MONGODB_HOST", "localhost")
port = os.getenv("MONGODB_PORT", "27017")
auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
if user and password:
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
else:
return f"mongodb://{host}:{port}/{self.database_name}"
def _initialize_migration_configs(self) -> List[MigrationConfig]:
"""初始化迁移配置"""
return [ # 表情包迁移配置
MigrationConfig(
mongo_collection="emoji",
target_model=Emoji,
field_mapping={
"full_path": "full_path",
"format": "format",
"hash": "emoji_hash",
"description": "description",
"emotion": "emotion",
"usage_count": "usage_count",
"last_used_time": "last_used_time",
# record_time字段将在转换时自动设置为当前时间
},
enable_validation=False, # 禁用数据验证
unique_fields=["full_path", "emoji_hash"],
),
# 聊天流迁移配置
MigrationConfig(
mongo_collection="chat_streams",
target_model=ChatStreams,
field_mapping={
"stream_id": "stream_id",
"create_time": "create_time",
"group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null而新的数据库不允许为null所以私聊聊天流是没法迁移的等更新吧。
"group_info.group_id": "group_id", # 同上
"group_info.group_name": "group_name", # 同上
"last_active_time": "last_active_time",
"platform": "platform",
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname",
},
enable_validation=False, # 禁用数据验证
unique_fields=["stream_id"],
),
# 消息迁移配置
MigrationConfig(
mongo_collection="messages",
target_model=Messages,
field_mapping={
"message_id": "message_id",
"time": "time",
"chat_id": "chat_id",
"chat_info.stream_id": "chat_info_stream_id",
"chat_info.platform": "chat_info_platform",
"chat_info.user_info.platform": "chat_info_user_platform",
"chat_info.user_info.user_id": "chat_info_user_id",
"chat_info.user_info.user_nickname": "chat_info_user_nickname",
"chat_info.user_info.user_cardname": "chat_info_user_cardname",
"chat_info.group_info.platform": "chat_info_group_platform",
"chat_info.group_info.group_id": "chat_info_group_id",
"chat_info.group_info.group_name": "chat_info_group_name",
"chat_info.create_time": "chat_info_create_time",
"chat_info.last_active_time": "chat_info_last_active_time",
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname",
"processed_plain_text": "processed_plain_text",
"memorized_times": "memorized_times",
},
enable_validation=False, # 禁用数据验证
unique_fields=["message_id"],
),
# 图片迁移配置
MigrationConfig(
mongo_collection="images",
target_model=Images,
field_mapping={
"hash": "emoji_hash",
"description": "description",
"path": "path",
"timestamp": "timestamp",
"type": "type",
},
unique_fields=["path"],
),
# 图片描述迁移配置
MigrationConfig(
mongo_collection="image_descriptions",
target_model=ImageDescriptions,
field_mapping={
"type": "type",
"hash": "image_description_hash",
"description": "description",
"timestamp": "timestamp",
},
unique_fields=["image_description_hash", "type"],
),
# 个人信息迁移配置
MigrationConfig(
mongo_collection="person_info",
target_model=PersonInfo,
field_mapping={
"person_id": "person_id",
"person_name": "person_name",
"name_reason": "name_reason",
"platform": "platform",
"user_id": "user_id",
"nickname": "nickname",
"relationship_value": "relationship_value",
"konw_time": "know_time",
},
unique_fields=["person_id"],
),
# 知识库迁移配置
MigrationConfig(
mongo_collection="knowledges",
target_model=Knowledges,
field_mapping={"content": "content", "embedding": "embedding"},
unique_fields=["content"], # 假设内容唯一
),
# 思考日志迁移配置
MigrationConfig(
mongo_collection="thinking_log",
target_model=ThinkingLog,
field_mapping={
"chat_id": "chat_id",
"trigger_text": "trigger_text",
"response_text": "response_text",
"trigger_info": "trigger_info_json",
"response_info": "response_info_json",
"timing_results": "timing_results_json",
"chat_history": "chat_history_json",
"chat_history_in_thinking": "chat_history_in_thinking_json",
"chat_history_after_response": "chat_history_after_response_json",
"heartflow_data": "heartflow_data_json",
"reasoning_data": "reasoning_data_json",
},
unique_fields=["chat_id", "trigger_text"],
),
# 图节点迁移配置
MigrationConfig(
mongo_collection="graph_data.nodes",
target_model=GraphNodes,
field_mapping={
"concept": "concept",
"memory_items": "memory_items",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified",
},
unique_fields=["concept"],
),
# 图边迁移配置
MigrationConfig(
mongo_collection="graph_data.edges",
target_model=GraphEdges,
field_mapping={
"source": "source",
"target": "target",
"strength": "strength",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified",
},
unique_fields=["source", "target"], # 组合唯一性
),
]
def _initialize_validation_rules(self) -> Dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
def connect_mongodb(self) -> bool:
"""连接到MongoDB"""
try:
self.mongo_client = MongoClient(
self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
)
# 测试连接
self.mongo_client.admin.command("ping")
self.mongo_db = self.mongo_client[self.database_name]
logger.info(f"成功连接到MongoDB: {self.database_name}")
return True
except ConnectionFailure as e:
logger.error(f"MongoDB连接失败: {e}")
return False
except Exception as e:
logger.error(f"MongoDB连接异常: {e}")
return False
def disconnect_mongodb(self):
"""断开MongoDB连接"""
if self.mongo_client:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
parts = field_path.split(".")
value = document
for part in parts:
if isinstance(value, dict):
value = value.get(part)
else:
return None
if value is None:
break
return value
def _convert_field_value(self, value: Any, target_field: Field) -> Any:
"""根据目标字段类型转换值"""
if value is None:
return None
field_type = target_field.__class__.__name__
try:
if target_field.name == "record_time" and field_type == "DateTimeField":
return datetime.now()
if field_type in ["CharField", "TextField"]:
if isinstance(value, (list, dict)):
return json.dumps(value, ensure_ascii=False)
return str(value) if value is not None else ""
elif field_type == "IntegerField":
if isinstance(value, str):
# 处理字符串数字
clean_value = value.strip()
if clean_value.replace(".", "").replace("-", "").isdigit():
return int(float(clean_value))
return 0
return int(value) if value is not None else 0
elif field_type in ["FloatField", "DoubleField"]:
return float(value) if value is not None else 0.0
elif field_type == "BooleanField":
if isinstance(value, str):
return value.lower() in ("true", "1", "yes", "on")
return bool(value)
elif field_type == "DateTimeField":
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value)
elif isinstance(value, str):
try:
# 尝试解析ISO格式日期
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
try:
# 尝试解析时间戳字符串
return datetime.fromtimestamp(float(value))
except ValueError:
return datetime.now()
return datetime.now()
return value
except (ValueError, TypeError) as e:
logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
return self._get_default_value_for_field(target_field)
def _get_default_value_for_field(self, field: Field) -> Any:
"""获取字段的默认值"""
field_type = field.__class__.__name__
if hasattr(field, "default") and field.default is not None:
return field.default
if field.null:
return None
# 根据字段类型返回默认值
if field_type in ["CharField", "TextField"]:
return ""
elif field_type == "IntegerField":
return 0
elif field_type in ["FloatField", "DoubleField"]:
return 0.0
elif field_type == "BooleanField":
return False
elif field_type == "DateTimeField":
return datetime.now()
return None
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
"""保存迁移断点"""
checkpoint = MigrationCheckpoint(
collection_name=collection_name,
processed_count=processed_count,
last_processed_id=last_id,
timestamp=datetime.now(),
)
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
try:
with open(checkpoint_file, "wb") as f:
pickle.dump(checkpoint, f)
except Exception as e:
logger.warning(f"保存断点失败: {e}")
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
return None
try:
with open(checkpoint_file, "rb") as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载断点失败: {e}")
return None
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
success_count = 0
try:
with db.atomic():
# 分批插入避免SQL语句过长
batch_size = 100
for i in range(0, len(data_list), batch_size):
batch = data_list[i : i + batch_size]
model.insert_many(batch).execute()
success_count += len(batch)
except Exception as e:
logger.error(f"批量插入失败: {e}")
# 如果批量插入失败,尝试逐个插入
for data in data_list:
try:
model.create(**data)
success_count += 1
except Exception:
pass # 忽略单个插入失败
return success_count
def _check_duplicate_by_unique_fields(
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
return False
try:
query = model.select()
for field_name in unique_fields:
if field_name in data and data[field_name] is not None:
field_obj = getattr(model, field_name)
query = query.where(field_obj == data[field_name])
return query.exists()
except Exception as e:
logger.debug(f"重复检查失败: {e}")
return False
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
"""使用ORM创建模型实例"""
try:
# 过滤掉不存在的字段
valid_data = {}
for field_name, value in data.items():
if hasattr(model, field_name):
valid_data[field_name] = value
else:
logger.debug(f"跳过未知字段: {field_name}")
# 创建实例
instance = model.create(**valid_data)
return instance
except IntegrityError as e:
# 处理唯一约束冲突等完整性错误
logger.debug(f"完整性约束冲突: {e}")
return None
except Exception as e:
logger.error(f"创建模型实例失败: {e}")
return None
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
"""迁移单个集合 - 使用优化的批量插入和进度条"""
stats = MigrationStats()
stats.start_time = datetime.now()
# 检查是否有断点
checkpoint = self._load_checkpoint(config.mongo_collection)
start_from_id = checkpoint.last_processed_id if checkpoint else None
if checkpoint:
stats.processed_count = checkpoint.processed_count
logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")
logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
try:
# 获取MongoDB集合
mongo_collection = self.mongo_db[config.mongo_collection]
# 构建查询条件(用于断点恢复)
query = {}
if start_from_id:
query = {"_id": {"$gt": start_from_id}}
stats.total_documents = mongo_collection.count_documents(query)
if stats.total_documents == 0:
logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
return stats
logger.info(f"待迁移文档数量: {stats.total_documents}")
# 创建Rich进度条
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
# 批量处理数据
batch_data = []
batch_count = 0
last_processed_id = None
for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
try:
doc_id = mongo_doc.get("_id", "unknown")
last_processed_id = doc_id
# 构建目标数据
target_data = {}
for mongo_field, sqlite_field in config.field_mapping.items():
value = self._get_nested_value(mongo_doc, mongo_field)
# 获取目标字段对象并转换类型
if hasattr(config.target_model, sqlite_field):
field_obj = getattr(config.target_model, sqlite_field)
converted_value = self._convert_field_value(value, field_obj)
target_data[sqlite_field] = converted_value
# 数据验证已禁用
# if config.enable_validation:
# if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
# stats.skipped_count += 1
# continue
# 重复检查
if config.skip_duplicates and self._check_duplicate_by_unique_fields(
config.target_model, target_data, config.unique_fields
):
stats.duplicate_count += 1
stats.skipped_count += 1
logger.debug(f"跳过重复记录: {doc_id}")
continue
# 添加到批量数据
batch_data.append(target_data)
stats.processed_count += 1
# 执行批量插入
if len(batch_data) >= config.batch_size:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
# 保存断点
self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
batch_data.clear()
batch_count += 1
# 更新进度条
progress.update(task, advance=config.batch_size)
except Exception as e:
doc_id = mongo_doc.get("_id", "unknown")
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
# 处理剩余的批量数据
if batch_data:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
progress.update(task, advance=len(batch_data))
# 完成进度条
progress.update(task, completed=stats.total_documents)
stats.end_time = datetime.now()
duration = stats.end_time - stats.start_time
logger.info(
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
)
# 清理断点文件
checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
if checkpoint_file.exists():
checkpoint_file.unlink()
except Exception as e:
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
stats.add_error("collection_error", str(e))
return stats
def migrate_all(self) -> Dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
if not self.connect_mongodb():
logger.error("无法连接到MongoDB迁移终止")
return {}
all_stats = {}
try:
# 创建总体进度表格
total_collections = len(self.migration_configs)
self.console.print(
Panel(
f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
f"[yellow]总集合数: {total_collections}[/yellow]",
title="迁移开始",
expand=False,
)
)
for idx, config in enumerate(self.migration_configs, 1):
self.console.print(
f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
)
stats = self.migrate_collection(config)
all_stats[config.mongo_collection] = stats
# 显示单个集合的快速统计
if stats.processed_count > 0:
success_rate = stats.success_count / stats.processed_count * 100
if success_rate >= 95:
status_emoji = "✅"
status_color = "bright_green"
elif success_rate >= 80:
status_emoji = "⚠️"
status_color = "yellow"
else:
status_emoji = "❌"
status_color = "red"
self.console.print(
f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
)
# 错误率检查
if stats.processed_count > 0:
error_rate = stats.error_count / stats.processed_count
if error_rate > 0.1: # 错误率超过10%
self.console.print(
f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
f"({stats.error_count}/{stats.processed_count})[/red]"
)
finally:
self.disconnect_mongodb()
self._print_migration_summary(all_stats)
return all_stats
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
total_success = sum(stats.success_count for stats in all_stats.values())
total_errors = sum(stats.error_count for stats in all_stats.values())
total_skipped = sum(stats.skipped_count for stats in all_stats.values())
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
# 计算总耗时
total_duration_seconds = 0
for stats in all_stats.values():
if stats.start_time and stats.end_time:
duration = stats.end_time - stats.start_time
total_duration_seconds += duration.total_seconds()
# 创建详细统计表格
table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
table.add_column("集合名称", style="cyan", width=20)
table.add_column("文档总数", justify="right", style="blue")
table.add_column("处理数量", justify="right", style="green")
table.add_column("成功数量", justify="right", style="green")
table.add_column("错误数量", justify="right", style="red")
table.add_column("跳过数量", justify="right", style="yellow")
table.add_column("重复数量", justify="right", style="bright_yellow")
table.add_column("验证错误", justify="right", style="red")
table.add_column("批次数", justify="right", style="purple")
table.add_column("成功率", justify="right", style="bright_green")
table.add_column("耗时(秒)", justify="right", style="blue")
for collection_name, stats in all_stats.items():
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
duration = 0
if stats.start_time and stats.end_time:
duration = (stats.end_time - stats.start_time).total_seconds()
# 根据成功率设置颜色
if success_rate >= 95:
success_rate_style = "bright_green"
elif success_rate >= 80:
success_rate_style = "yellow"
else:
success_rate_style = "red"
table.add_row(
collection_name,
str(stats.total_documents),
str(stats.processed_count),
str(stats.success_count),
f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
str(stats.batch_insert_count),
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
f"{duration:.2f}",
)
# 添加总计行
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
if total_success_rate >= 95:
total_rate_style = "bright_green"
elif total_success_rate >= 80:
total_rate_style = "yellow"
else:
total_rate_style = "red"
table.add_section()
table.add_row(
"[bold]总计[/bold]",
f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]",
f"[bold]{total_processed}[/bold]",
f"[bold]{total_success}[/bold]",
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
if total_duplicates > 0
else "[bold]0[/bold]",
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
f"[bold]{total_batch_inserts}[/bold]",
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
f"[bold]{total_duration_seconds:.2f}[/bold]",
)
self.console.print(table)
# 创建状态面板
status_items = []
if total_errors > 0:
status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
if total_validation_errors > 0:
status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
if total_duplicates > 0:
status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
if total_success_rate >= 95:
status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
elif total_success_rate >= 80:
status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
else:
status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
if status_items:
status_panel = Panel(
"\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
)
self.console.print(status_panel)
# 性能统计面板
avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
performance_info = (
f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f}\n"
f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
)
performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
self.console.print(performance_panel)
def add_migration_config(self, config: MigrationConfig):
"""添加新的迁移配置"""
self.migration_configs.append(config)
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
logger.error(f"未找到集合 {collection_name} 的迁移配置")
return None
if not self.connect_mongodb():
logger.error("无法连接到MongoDB")
return None
try:
stats = self.migrate_collection(config)
self._print_migration_summary({collection_name: stats})
return stats
finally:
self.disconnect_mongodb()
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
"timestamp": datetime.now().isoformat(),
"summary": {
collection: {
"total": stats.total_documents,
"processed": stats.processed_count,
"success": stats.success_count,
"errors": stats.error_count,
"skipped": stats.skipped_count,
"duplicates": stats.duplicate_count,
}
for collection, stats in all_stats.items()
},
"errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
}
try:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(error_report, f, ensure_ascii=False, indent=2)
logger.info(f"错误报告已导出到: {filepath}")
except Exception as e:
logger.error(f"导出错误报告失败: {e}")
def main():
"""主程序入口"""
migrator = MongoToSQLiteMigrator()
# 执行迁移
migration_results = migrator.migrate_all()
# 导出错误报告(如果有错误)
if any(stats.error_count > 0 for stats in migration_results.values()):
error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
migrator.export_error_report(migration_results, error_report_path)
logger.info("数据迁移完成!")
if __name__ == "__main__":
main()
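# Usage sketch (the module name mongodb_to_sqlite is an assumption; adjust to
# the actual file name). Migrate one collection and export its errors:
#
# from mongodb_to_sqlite import MongoToSQLiteMigrator
# migrator = MongoToSQLiteMigrator()
# stats = migrator.migrate_single_collection("messages")
# if stats and stats.error_count > 0:
#     migrator.export_error_report({"messages": stats}, "messages_errors.json")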

View File

@@ -1,75 +0,0 @@
import os
from pathlib import Path
import sys # 新增系统模块导入
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.utils.hash import get_sha256
from src.common.logger import get_logger
logger = get_logger("lpmm")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
raw = f.read()
paragraphs = []
paragraph = ""
for line in raw.split("\n"):
if line.strip() == "":
if paragraph != "":
paragraphs.append(paragraph.strip())
paragraph = ""
else:
paragraph += line + "\n"
if paragraph != "":
paragraphs.append(paragraph.strip())
return paragraphs
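# Worked example: blank lines delimit paragraphs, single newlines do not, so a
# file containing "a\nb\n\nc\n" yields ["a\nb", "c"]:
#
# paragraphs = _process_text_file("data/lpmm_raw_data/demo.txt")  # hypothetical file
# assert paragraphs == ["a\nb", "c"]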
def _process_multi_files() -> list:
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
if not raw_files:
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
sys.exit(1)
# 处理所有文件
all_paragraphs = []
for file in raw_files:
logger.info(f"正在处理文件: {file.name}")
paragraphs = _process_text_file(file)
all_paragraphs.extend(paragraphs)
return all_paragraphs
def load_raw_data() -> tuple[list[str], list[str]]:
"""加载原始数据文件
读取原始数据文件,将原始数据加载到内存中,并按SHA256去重
Returns:
- sha256_list: 原始数据的SHA256列表
- raw_data: 去重后的原始数据列表
"""
paragraphs = _process_multi_files()
sha256_list = []
sha256_set = set()
raw_data = []
for item in paragraphs:
if not isinstance(item, str):
logger.warning(f"数据类型错误:{item}")
continue
pg_hash = get_sha256(item)
if pg_hash in sha256_set:
logger.warning(f"重复数据:{item}")
continue
sha256_set.add(pg_hash)
sha256_list.append(pg_hash)
raw_data.append(item)
logger.info(f"共读取到{len(raw_data)}条数据")
return sha256_list, raw_data

View File

@@ -1,556 +0,0 @@
#!/bin/bash
# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
# 请小心使用任何一键脚本!
INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8
# 如无法访问GitHub请修改此处镜像地址
GITHUB_REPO="https://ghfast.top/https://github.com"
# 颜色输出
GREEN="\e[32m"
RED="\e[31m"
RESET="\e[0m"
# 需要的基本软件包
declare -A REQUIRED_PACKAGES=(
["common"]="git sudo python3 curl gnupg"
["debian"]="python3-venv python3-pip build-essential"
["ubuntu"]="python3-venv python3-pip build-essential"
["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make"
["arch"]="python-virtualenv python-pip base-devel"
)
# 默认项目目录
DEFAULT_INSTALL_DIR="/opt/maicore"
# 服务名称
SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false
# 检查是否已安装
check_installed() {
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
}
# 加载安装信息
load_install_info() {
if [[ -f /etc/maicore_install.conf ]]; then
source /etc/maicore_install.conf
else
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
BRANCH="refactor"
fi
}
# 显示管理菜单
show_menu() {
while true; do
choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
"1" "启动MaiCore" \
"2" "停止MaiCore" \
"3" "重启MaiCore" \
"4" "启动NapCat Adapter" \
"5" "停止NapCat Adapter" \
"6" "重启NapCat Adapter" \
"7" "拉取最新MaiCore仓库" \
"8" "切换分支" \
"9" "退出" 3>&1 1>&2 2>&3)
[[ $? -ne 0 ]] && exit 0
case "$choice" in
1)
systemctl start ${SERVICE_NAME}
whiptail --msgbox "✅MaiCore已启动" 10 60
;;
2)
systemctl stop ${SERVICE_NAME}
whiptail --msgbox "🛑MaiCore已停止" 10 60
;;
3)
systemctl restart ${SERVICE_NAME}
whiptail --msgbox "🔄MaiCore已重启" 10 60
;;
4)
systemctl start ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "✅NapCat Adapter已启动" 10 60
;;
5)
systemctl stop ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
;;
6)
systemctl restart ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
;;
7)
update_dependencies
;;
8)
switch_branch
;;
9)
exit 0
;;
*)
whiptail --msgbox "无效选项!" 10 60
;;
esac
done
}
# 更新依赖
update_dependencies() {
whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
systemctl stop ${SERVICE_NAME}
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
if ! git pull origin "${BRANCH}"; then
whiptail --msgbox "🚫 代码更新失败!" 10 60
return 1
fi
source "${INSTALL_DIR}/venv/bin/activate"
if ! pip install -r requirements.txt; then
whiptail --msgbox "🚫 依赖安装失败!" 10 60
deactivate
return 1
fi
deactivate
whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
}
# 切换分支
switch_branch() {
new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
[[ -z "$new_branch" ]] && {
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
return 1
}
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
return 1
fi
if ! git checkout "${new_branch}"; then
whiptail --msgbox "🚫 分支切换失败!" 10 60
return 1
fi
if ! git pull origin "${new_branch}"; then
whiptail --msgbox "🚫 代码拉取失败!" 10 60
return 1
fi
systemctl stop ${SERVICE_NAME}
source "${INSTALL_DIR}/venv/bin/activate"
pip install -r requirements.txt
deactivate
sed -i "s/^BRANCH=.*/BRANCH=${new_branch}/" /etc/maicore_install.conf
BRANCH="${new_branch}"
check_eula
whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} " 10 60
}
check_eula() {
# 首先计算当前EULA的MD5值
current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}')
# 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}')
# 如果当前的md5值为空则直接返回
if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
exit 1
fi
# 检查eula.confirmed文件是否存在
if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then
# 如果存在则检查其中包含的md5与current_md5是否一致
confirmed_md5=$(cat ${INSTALL_DIR}/MaiBot/eula.confirmed)
else
confirmed_md5=""
fi
# 检查privacy.confirmed文件是否存在
if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then
# 如果存在则检查其中包含的md5与current_md5是否一致
confirmed_md5_privacy=$(cat ${INSTALL_DIR}/MaiBot/privacy.confirmed)
else
confirmed_md5_privacy=""
fi
# 如果EULA或隐私条款有更新提示用户重新确认
if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70
if [[ $? -eq 0 ]]; then
echo -n $current_md5 > ${INSTALL_DIR}/MaiBot/eula.confirmed
echo -n $current_md5_privacy > ${INSTALL_DIR}/MaiBot/privacy.confirmed
else
exit 1
fi
fi
}
# ----------- 主安装流程 -----------
run_installation() {
# 1/6: 检测是否安装 whiptail
if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
if command -v apt-get &>/dev/null; then
apt-get update && apt-get install -y whiptail
elif command -v pacman &>/dev/null; then
pacman -Syu --noconfirm whiptail
elif command -v yum &>/dev/null; then
yum install -y whiptail
else
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
exit 1
fi
fi
whiptail --title " 提示" --msgbox "如果您没有特殊需求请优先使用docker方式部署。" 10 60
# 协议确认
if ! (whiptail --title " [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议" 12 70); then
exit 1
fi
# 欢迎信息
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore将自动进入安装流程安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段代码可能随时更改\n文档未完善有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险请自行了解谨慎使用\n由于持续迭代可能存在一些已知或未知的bug\n由于开发中可能消耗较多token\n\n本脚本可能更新不及时如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
# 系统检查
check_system() {
if [[ "$(id -u)" -ne 0 ]]; then
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
exit 1
fi
if [[ -f /etc/os-release ]]; then
source /etc/os-release
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
return
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
return
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
return
elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
return
else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
exit 1
fi
else
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
exit 1
fi
}
check_system
# 设置包管理器
case "$ID" in
debian|ubuntu)
PKG_MANAGER="apt"
;;
centos)
PKG_MANAGER="yum"
;;
arch)
# 添加arch包管理器
PKG_MANAGER="pacman"
;;
esac
# 检查NapCat
check_napcat() {
if command -v napcat &>/dev/null; then
NAPCAT_INSTALLED=true
else
NAPCAT_INSTALLED=false
fi
}
check_napcat
# 安装必要软件包
install_packages() {
missing_packages=()
# 检查 common 及当前系统专属依赖
for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
case "$PKG_MANAGER" in
apt)
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
;;
yum)
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
;;
pacman)
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
;;
esac
done
if [[ ${#missing_packages[@]} -gt 0 ]]; then
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装" 10 60
if [[ $? -eq 0 ]]; then
IS_INSTALL_DEPENDENCIES=true
else
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续" 10 60 || exit 1
fi
fi
}
install_packages
# 安装NapCat
install_napcat() {
[[ $NAPCAT_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60 && {
IS_INSTALL_NAPCAT=true
}
}
# 仅在非Arch系统上安装NapCat
[[ "$ID" != "arch" ]] && install_napcat
# Python版本检查
check_python() {
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
exit 1
fi
}
# 如果没安装python则不检查python版本
if command -v python3 &>/dev/null; then
check_python
fi
# 选择分支
choose_branch() {
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
"main" "稳定版本(推荐)" ON \
"dev" "开发版(不知道什么意思就别选)" OFF \
"classical" "经典版0.6.0以前的版本)" OFF \
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 操作取消!" 10 60
exit 1
fi
if [[ "$BRANCH" == "custom" ]]; then
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 输入取消!" 10 60
exit 1
fi
if [[ -z "$BRANCH" ]]; then
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
exit 1
fi
fi
}
choose_branch
# 选择安装路径
choose_install_dir() {
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
[[ -z "$INSTALL_DIR" ]] && {
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
}
}
choose_install_dir
# 确认安装
confirm_install() {
local confirm_msg="请确认以下更改:\n\n"
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
confirm_msg+="🔀 分支: $BRANCH\n"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
}
confirm_install
# 开始安装
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
case "$PKG_MANAGER" in
apt)
apt update && apt install -y "${missing_packages[@]}"
;;
yum)
yum install -y "${missing_packages[@]}" --nobest
;;
pacman)
pacman -S --noconfirm "${missing_packages[@]}"
;;
esac
fi
if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
fi
echo -e "${GREEN}创建安装目录...${RESET}"
mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR" || exit 1
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
python3 -m venv venv
source venv/bin/activate
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
echo -e "${RED}克隆MaiCore仓库失败${RESET}"
exit 1
}
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}安装Python依赖...${RESET}"
pip install -r MaiBot/requirements.txt
cd MaiBot
pip install uv
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
echo -e "${GREEN}安装maim_message依赖...${RESET}"
cd maim_message
uv pip install -i https://mirrors.aliyun.com/pypi/simple -e .
cd ..
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
cd MaiBot-Napcat-Adapter
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
echo -e "${GREEN}同意协议...${RESET}"
# 首先计算当前EULA的MD5值
current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}')
# 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}')
echo -n $current_md5 > MaiBot/eula.confirmed
echo -n $current_md5_privacy > MaiBot/privacy.confirmed
echo -e "${GREEN}创建系统服务...${RESET}"
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
[Unit]
Description=MaiCore
After=network.target ${SERVICE_NAME_NBADAPTER}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
# [Unit]
# Description=MaiCore WebUI
# After=network.target ${SERVICE_NAME}.service
# [Service]
# Type=simple
# WorkingDirectory=${INSTALL_DIR}/MaiBot
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
# Restart=always
# RestartSec=10s
# [Install]
# WantedBy=multi-user.target
# EOF
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
[Unit]
Description=MaiBot Napcat Adapter
After=network.target mongod.service ${SERVICE_NAME}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
# 保存安装信息
echo "INSTALLER_VERSION=${INSTALLER_VERSION}" > /etc/maicore_install.conf
echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf
echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成\n已创建系统服务${SERVICE_NAME}${SERVICE_NAME_WEB}${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务\n启动服务systemctl start ${SERVICE_NAME}\n查看状态systemctl status ${SERVICE_NAME}" 14 60
}
# ----------- 主执行流程 -----------
# 检查root权限
[[ $(id -u) -ne 0 ]] && {
echo -e "${RED}请使用root用户运行此脚本${RESET}"
exit 1
}
# 如果已安装显示菜单,并检查协议是否更新
if check_installed; then
load_install_info
check_eula
show_menu
else
run_installation
# 安装完成后询问是否启动
if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务" 10 60; then
systemctl start ${SERVICE_NAME}
whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
fi
fi

View File

@@ -1,51 +0,0 @@
#!/bin/bash
# ==============================================
# Environment Initialization
# ==============================================
# Step 1: Locate project root directory
SCRIPTS_DIR="scripts"
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
PROJECT_ROOT=$(cd "$SCRIPT_DIR/.." && pwd)
# Step 2: Verify scripts directory exists
if [ ! -d "$PROJECT_ROOT/$SCRIPTS_DIR" ]; then
echo "❌ Error: scripts directory not found in project root" >&2
echo "Current path: $PROJECT_ROOT" >&2
exit 1
fi
# Step 3: Set up Python environment
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
cd "$PROJECT_ROOT" || {
echo "❌ Failed to cd to project root: $PROJECT_ROOT" >&2
exit 1
}
# Debug info
echo "============================"
echo "Project Root: $PROJECT_ROOT"
echo "Python Path: $PYTHONPATH"
echo "Working Dir: $(pwd)"
echo "============================"
# ==============================================
# Python Script Execution
# ==============================================
run_python_script() {
local script_name=$1
echo "🔄 Running $script_name"
if ! python3 "$SCRIPTS_DIR/$script_name"; then
echo "$script_name failed" >&2
exit 1
fi
}
# Execute scripts in order
run_python_script "raw_data_preprocessor.py"
run_python_script "info_extraction.py"
run_python_script "import_openie.py"
echo "✅ All scripts completed successfully"

View File

@@ -1,394 +0,0 @@
import time
import sys
import os
import re
from typing import Dict, List, Tuple, Optional
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams  # noqa
def contains_emoji_or_image_tags(text: str) -> bool:
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
if not text:
return False
# 检查是否包含 [表情包] 或 [图片] 标记
emoji_pattern = r'\[表情包[^\]]*\]'
image_pattern = r'\[图片[^\]]*\]'
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
def clean_reply_text(text: str) -> str:
"""Remove reply references like [回复 xxxx...] from text"""
if not text:
return text
# 匹配 [回复 xxxx...] 格式的内容
# 使用非贪婪匹配,匹配到第一个 ] 就停止
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
# 去除多余的空白字符
cleaned_text = cleaned_text.strip()
return cleaned_text
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
"""Format timestamp to readable date string"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def calculate_text_length_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of processed_plain_text length"""
distribution = {
'0': 0, # 空文本
'1-5': 0, # 极短文本
'6-10': 0, # 很短文本
'11-20': 0, # 短文本
'21-30': 0, # 较短文本
'31-50': 0, # 中短文本
'51-70': 0, # 中等文本
'71-100': 0, # 较长文本
'101-150': 0, # 长文本
'151-200': 0, # 很长文本
'201-300': 0, # 超长文本
'301-500': 0, # 极长文本
'501-1000': 0, # 巨长文本
'1000+': 0 # 超巨长文本
}
for msg in messages:
if msg.processed_plain_text is None:
continue
# 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text):
continue
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text)
if length == 0:
distribution['0'] += 1
elif length <= 5:
distribution['1-5'] += 1
elif length <= 10:
distribution['6-10'] += 1
elif length <= 20:
distribution['11-20'] += 1
elif length <= 30:
distribution['21-30'] += 1
elif length <= 50:
distribution['31-50'] += 1
elif length <= 70:
distribution['51-70'] += 1
elif length <= 100:
distribution['71-100'] += 1
elif length <= 150:
distribution['101-150'] += 1
elif length <= 200:
distribution['151-200'] += 1
elif length <= 300:
distribution['201-300'] += 1
elif length <= 500:
distribution['301-500'] += 1
elif length <= 1000:
distribution['501-1000'] += 1
else:
distribution['1000+'] += 1
return distribution
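# Example: a 25-character message (after tag filtering and reply cleanup) lands
# in the '21-30' bucket; an empty string lands in '0'.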
def get_text_length_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for processed_plain_text length"""
lengths = []
null_count = 0
excluded_count = 0 # 被排除的消息数量
for msg in messages:
if msg.processed_plain_text is None:
null_count += 1
elif contains_emoji_or_image_tags(msg.processed_plain_text):
# 排除包含表情包或图片标记的消息
excluded_count += 1
else:
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
lengths.append(len(cleaned_text))
if not lengths:
return {
'count': 0,
'null_count': null_count,
'excluded_count': excluded_count,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
lengths.sort()
count = len(lengths)
return {
'count': count,
'null_count': null_count,
'excluded_count': excluded_count,
'min': min(lengths),
'max': max(lengths),
'avg': sum(lengths) / count,
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
}
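# Median example: for sorted lengths [3, 7, 10, 42] (even count) the result is
# (7 + 10) / 2 = 8.5; for [3, 7, 10] (odd count) it is lengths[1] = 7.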
def get_available_chats() -> List[Tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id排除特殊类型消息
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
).count()
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
except Exception as e:
print(f"获取聊天列表失败: {e}")
return []
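# The loop above issues one COUNT query per chat_id. A single grouped query
# would also work (sketch; fn is peewee's SQL function helper, and Messages.id
# assumes the default integer primary key):
#
# from peewee import fn
# rows = (Messages
#         .select(Messages.chat_id, fn.COUNT(Messages.id).alias("cnt"))
#         .where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
#         .group_by(Messages.chat_id))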
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24*3600, now
elif choice == "2":
return now - 3*24*3600, now
elif choice == "3":
return now - 7*24*3600, now
elif choice == "4":
return now - 30*24*3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
return start_time, end_time
except ValueError:
print("时间格式错误,将不限制时间范围")
return None, None
else:
return None, None
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
"""Get top N longest messages"""
message_lengths = []
for msg in messages:
if msg.processed_plain_text is not None:
# 排除包含表情包或图片标记的消息
if contains_emoji_or_image_tags(msg.processed_plain_text):
continue
# 清理文本中的回复引用
cleaned_text = clean_reply_text(msg.processed_plain_text)
length = len(cleaned_text)
chat_name = get_chat_name(msg.chat_id)
time_str = format_timestamp(msg.time)
# 截取前100个字符作为预览
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
message_lengths.append((chat_name, length, time_str, preview))
# 按长度排序取前N个
message_lengths.sort(key=lambda x: x[1], reverse=True)
return message_lengths[:top_n]
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
"""Analyze processed_plain_text lengths with optional filters"""
# 构建查询条件,排除特殊类型的消息
query = Messages.select().where(
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
)
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_text_length_distribution(messages)
stats = get_text_length_stats(messages)
top_longest = get_top_longest_messages(messages, 10)
# 显示结果
print("\n=== Processed Plain Text 长度分析结果 ===")
print("(已排除表情、图片ID、命令类型消息已排除[表情包]和[图片]标记消息,已清理回复引用)")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
print(f"时间范围: {format_timestamp(start_time)} 之后")
elif end_time:
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"总消息数量: {len(messages)}")
print(f"有文本消息数量: {stats['count']}")
print(f"空文本消息数量: {stats['null_count']}")
print(f"被排除的消息数量: {stats['excluded_count']}")
if stats['count'] > 0:
print(f"最短长度: {stats['min']} 字符")
print(f"最长长度: {stats['max']} 字符")
print(f"平均长度: {stats['avg']:.2f} 字符")
print(f"中位数长度: {stats['median']:.2f} 字符")
print("\n文本长度分布:")
total = stats['count']
if total > 0:
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
# 显示最长的消息
if top_longest:
print(f"\n最长的 {len(top_longest)} 条消息:")
for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
print(f"{i}. [{chat_name}] {time_str}")
print(f" 长度: {length} 字符")
print(f" 预览: {preview}")
print()
def interactive_menu() -> None:
"""Interactive menu for text length analysis"""
while True:
print("\n" + "="*50)
print("Processed Plain Text 长度分析工具")
print("="*50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q':
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到聊天数据")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_text_lengths(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()
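# Non-interactive usage sketch (skips the menu): analyze the last 7 days across
# all chats.
#
# analyze_text_lengths(chat_id=None, start_time=time.time() - 7 * 24 * 3600, end_time=time.time())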

View File

@@ -12,9 +12,10 @@ import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from sqlalchemy import select
from src.common.database.database import db
from src.common.database.sqlalchemy_database_api import get_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
session = get_session()
"""
还没经过测试,有些地方数据库和内存数据同步可能不完全
@@ -151,7 +154,7 @@ class MaiEmoji:
# 准备数据库记录 for emoji collection
emotion_str = ",".join(self.emotion) if self.emotion else ""
Emoji.create(
emoji = Emoji(
emoji_hash=self.hash,
full_path=self.full_path,
format=self.format,
@@ -165,6 +168,8 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -200,7 +205,7 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
result = 0
if will_delete_emoji is not None:
session.delete(will_delete_emoji)
session.commit()
result = 1  # Number of rows deleted.
except Emoji.DoesNotExist: # type: ignore
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
@@ -393,12 +397,17 @@ class EmojiManager:
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
Emoji.create_table(safe=True) # Ensures table exists
self._initialized = True
try:
db.connect(reuse_if_open=True)
if db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
self._initialized = True # 标记为已初始化
logger.info("EmojiManager初始化成功")
except Exception as e:
logger.error(f"EmojiManager初始化失败: {e}")
self._initialized = False
raise
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
@@ -410,7 +419,7 @@ class EmojiManager:
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
session.commit()  # Persist changes to DB
@@ -644,10 +653,10 @@ class EmojiManager:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = session.execute(select(Emoji)).scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -675,15 +684,15 @@ class EmojiManager:
self._ensure_db()
if emoji_hash:
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
query = session.execute(select(Emoji)).scalars().all()
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -760,7 +769,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -921,9 +930,10 @@ class EmojiManager:
# 尝试从Images表获取已有的详细描述可能在收到表情包时已生成
existing_description = None
try:
from src.common.database.database_model import Images
# from src.common.database.database_model_compat import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
existing_image = session.execute(stmt).scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -7,7 +7,9 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@@ -20,7 +22,7 @@ DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
session = get_session()
def format_create_date(timestamp: float) -> str:
"""
@@ -168,30 +170,50 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# """
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
# """
# learnt_style_expressions = []
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
"""
获取指定chat_id的style与grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 直接从数据库查询
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
for expr in style_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "style",
"create_date": create_date,
}
)
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
for expr in grammar_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
return learnt_style_expressions, learnt_grammar_expressions
# # 直接从数据库查询
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
# for expr in style_query:
# # 确保create_date存在如果不存在则使用last_active_time
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
# learnt_style_expressions.append(
# {
# "situation": expr.situation,
# "style": expr.style,
# "count": expr.count,
# "last_active_time": expr.last_active_time,
# "source_id": self.chat_id,
# "type": "style",
# "create_date": create_date,
# }
# )
# return learnt_style_expressions
@@ -201,7 +223,7 @@ class ExpressionLearner:
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
all_expressions = session.execute(select(Expression)).scalars()
updated_count = 0
deleted_count = 0
@@ -217,18 +239,20 @@ class ExpressionLearner:
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
session.delete(expr)
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
session.commit()
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
session.rollback()
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
@@ -297,23 +321,22 @@ class ExpressionLearner:
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
@@ -322,16 +345,18 @@ class ExpressionLearner:
type="style",
create_date=current_time, # 手动设置创建日期
)
session.add(new_expression)
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
session.execute(select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())).scalars()
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
session.delete(expr)
session.commit()
return learnt_expressions
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
@@ -509,54 +534,35 @@ class ExpressionLearnerManager:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
# 查重同chat_id+type+situation+style
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)).scalar()
if query:
expr_obj = query
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
else:
new_expression = Expression(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
session.add(new_expression)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
@@ -581,18 +587,20 @@ class ExpressionLearnerManager:
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
session.commit()
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
session.rollback()
logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:

View File

@@ -9,8 +9,11 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session
session = get_session()
logger = get_logger("expression_selector")
@@ -131,9 +134,12 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_exprs = [
{
@@ -146,9 +152,24 @@ class ExpressionSelector:
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
for expr in style_query.scalars()
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in grammar_query.scalars()
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重
if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs]
@@ -174,19 +195,19 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
expr_obj.save()
session.commit()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
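# The count-weighted sampling mentioned above can be done with the standard
# library (a sketch; random.choices samples with replacement, so deduplicate
# the result if distinct expressions are required):
#
# import random
# picked = random.choices(style_exprs, weights=[e.get("count", 1) for e in style_exprs], k=style_num)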

View File

@@ -1,33 +0,0 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
from .utils.dyn_topk import dyn_select_top_k
class MemoryActiveManager:
def __init__(
self,
embed_manager: EmbeddingManager,
llm_client_embedding: LLMClient,
):
self.embed_manager = embed_manager
self.embedding_client = llm_client_embedding
def get_activation(self, question: str) -> float:
"""获取记忆激活度"""
# 生成问题的Embedding
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
# 查询关系库中的相似度
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
# 动态过滤阈值
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]:
# 未找到相关关系
return 0.0
# 计算激活度
activation = sum([item[2] for item in rel_scores]) * 10
return activation

View File

@@ -16,8 +16,10 @@ from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from sqlalchemy import select, insert, update, text, delete
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_session
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
@@ -37,7 +39,7 @@ def cosine_similarity(v1, v2):
install(extra_lines=3)
session = get_session()
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
@@ -731,13 +733,14 @@ class Hippocampus:
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(keywords)
all_words = memory_words | text_words
if all_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
logger.debug(f"节点包含 {len(memory_items)}记忆")
# 计算每条记忆与输入文本的相似度
memory_similarities = []
for memory in memory_items:
# 计算与输入文本的相似度
memory_words = set(jieba.cut(memory))
text_words = set(jieba.cut(text))
all_words = memory_words | text_words
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2)  # 计算但不使用,用 _ 占位
@@ -844,11 +847,6 @@ class Hippocampus:
else:
activate_map[node] = activation_value
# 输出激活映射
# logger.info("激活映射统计:")
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
# 计算激活节点数与总节点数的比值
total_activation = sum(activate_map.values())
# logger.debug(f"总激活值: {total_activation:.2f}")
@@ -942,10 +940,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
# 使用 Peewee 更新记录
Messages.update(memorized_times=current_memorized_times + 1).where(
Messages.message_id == message["message_id"]
).execute()
# 使用 SQLAlchemy 2.0 更新记录
session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -959,7 +960,7 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
db_nodes = {node.concept: node for node in GraphNodes.select()}
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -1025,22 +1026,27 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
session.execute(insert(GraphNodes), batch)
session.commit()
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
).execute()
session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
session.commit()
if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
session.commit()
# 处理边的信息
db_edges = list(GraphEdges.select())
db_edges = list(session.execute(select(GraphEdges)).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1092,20 +1098,29 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
).execute()
session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
)
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
)
session.commit()
if edges_to_delete:
for source, target in edges_to_delete:
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
session.commit()
end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1118,8 +1133,9 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
GraphNodes.delete().execute()
GraphEdges.delete().execute()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
session.commit()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1186,12 +1202,27 @@ class EntorhinalCortex:
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
continue
# 批量插入边
# 批量写入节点
node_start = time.time()
if nodes_data:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
session.commit()
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
# 批量写入边
edge_start = time.time()
if edges_data:
batch_size = 100
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
end_time = time.time()
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1211,9 +1242,7 @@ class EntorhinalCortex:
skipped_nodes = 0
# 从数据库加载所有节点
nodes = list(GraphNodes.select())
total_nodes = len(nodes)
nodes = list(session.execute(select(GraphNodes)).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1235,8 +1264,10 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
if update_data:
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1256,7 +1287,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(GraphEdges.select())
edges = list(session.execute(select(GraphEdges)).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1272,9 +1303,12 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
GraphEdges.update(**update_data).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.created_time or current_time
@@ -1398,7 +1432,6 @@ class ParahippocampalGyrus:
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= 0.7:
@@ -1502,7 +1535,7 @@ class ParahippocampalGyrus:
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
# 只有在有足够的节点和边时进行采样
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
try:
nodes_to_check = random.sample(all_nodes, check_nodes_count)
@@ -1548,6 +1581,11 @@ class ParahippocampalGyrus:
logger.info("[遗忘] 开始检查节点...")
node_check_start = time.time()
# 初始化整合相关变量
merged_count = 0
nodes_modified = set()
for node in nodes_to_check:
# 检查节点是否存在,以防在迭代中被移除(例如边移除导致)
if node not in self.memory_graph.G:
@@ -1567,64 +1605,91 @@ class ParahippocampalGyrus:
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
continue # 处理下一个节点
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
# 检查节点的最后修改时间,如果太旧则尝试遗忘
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
continue
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
# 随机遗忘一条记忆
if len(memory_items) > 1:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)")
elif len(memory_items) == 1:
# 如果只有一条记忆,检查是否应该完全移除节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node} (最后记忆)")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}")
# 检查节点内是否有相似的记忆项需要整合
if len(memory_items) > 1:
merged_in_this_node = False
items_to_remove = []
for i in range(len(memory_items)):
for j in range(i + 1, len(memory_items)):
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
if similarity > 0.8: # 相似度阈值
# 合并相似记忆项
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
# 保留更长的记忆项,标记短的用于删除
if shorter_item not in items_to_remove:
items_to_remove.append(shorter_item)
merged_count += 1
merged_in_this_node = True
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
# 移除被合并的记忆项
if items_to_remove:
for item in items_to_remove:
if item in memory_items:
memory_items.remove(item)
nodes_modified.add(node)
# 更新节点的记忆项
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
self.memory_graph.G.nodes[node]["last_modified"] = current_time
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
if any(edge_changes.values()) or any(node_changes.values()):
# 输出变化统计
if edge_changes["weakened"]:
logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接")
if edge_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接")
if node_changes["reduced"]:
logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆")
if node_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点")
# 检查是否有变化需要同步到数据库
has_changes = (
edge_changes["weakened"] or
edge_changes["removed"] or
node_changes["reduced"] or
node_changes["removed"] or
merged_count > 0
)
if has_changes:
logger.info("[遗忘] 开始将变更同步到数据库...")
sync_start = time.time()
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
sync_end = time.time()
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}")
# 汇总输出所有变化
logger.info("[遗忘] 遗忘操作统计:")
if edge_changes["weakened"]:
logger.info(
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
)
if edge_changes["removed"]:
logger.info(
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
)
if node_changes["reduced"]:
logger.info(
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
)
if node_changes["removed"]:
logger.info(
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
)
if merged_count > 0:
logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。")
sync_start = time.time()
logger.info("[整合] 开始将变更同步到数据库...")
# 使用 resync 更安全地处理删除和添加
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
sync_end = time.time()
logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}")
else:
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
end_time = time.time()
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
logger.info("[整合] 本次检查未发现需要合并的记忆项。")
@@ -1734,10 +1799,7 @@ class HippocampusManager:
"""获取所有节点名称的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus.get_all_node_names()
# 创建全局实例
hippocampus_manager = HippocampusManager()

View File

@@ -10,12 +10,13 @@ from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import model_config
from sqlalchemy import select
logger = get_logger(__name__)
session = get_session()
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
@@ -120,7 +121,8 @@ class InstantMemory:
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
memory.save()
session.add(memory)
session.commit()
async def get_memory(self, target: str):
from json_repair import repair_json
@@ -166,13 +168,13 @@ class InstantMemory:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
query = session.execute(select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)).scalars()
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
for mem in query:
# 对每条记忆

View File

@@ -8,8 +8,12 @@ from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from sqlalchemy import select, text
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
@@ -19,7 +23,7 @@ install(extra_lines=3)
logger = get_logger("chat_stream")
session = get_session()
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
@@ -131,7 +135,8 @@ class ChatManager:
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
session.commit()
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
@@ -231,7 +236,7 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
@@ -342,7 +347,28 @@ class ChatManager:
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
elif global_config.database.database_type == "mysql":
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
session.execute(stmt)
session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
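SQLAlchemy deliberately has no portable upsert statement, so the code above branches on the configured backend: SQLite (and the fallback path) uses INSERT ... ON CONFLICT DO UPDATE, while MySQL uses INSERT ... ON DUPLICATE KEY UPDATE. Reduced to its essentials, the dispatch looks like this (a sketch; it assumes stream_id is the table's unique key, as the index_elements argument implies):

from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

def upsert_chat_stream(session, db_type: str, stream_id: str, fields: dict) -> None:
    if db_type == "mysql":
        stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **fields)
        stmt = stmt.on_duplicate_key_update(**fields)  # MySQL syntax
    else:  # sqlite, also used as the default fallback
        stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **fields)
        stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields)
    session.execute(stmt)
    session.commit()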
@@ -361,7 +387,7 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
for model_instance in session.execute(select(ChatStreams)).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,

View File

@@ -1,16 +1,18 @@
import re
import json
import traceback
import json
from typing import Union
from src.common.database.database_model import Messages, Images
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, update, desc
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
@@ -33,15 +35,11 @@ class MessageStorage:
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 莫越权 救世啊
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# print(message)
processed_plain_text = message.processed_plain_text
# print(processed_plain_text)
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
@@ -93,11 +91,16 @@ class MessageStorage:
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
Messages.create(
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = json.dumps(priority_info) if priority_info else None
# 获取数据库会话
session = get_session()
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time), # type: ignore
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
@@ -111,18 +114,16 @@ class MessageStorage:
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
@@ -131,35 +132,44 @@ class MessageStorage:
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
)
session.add(new_message)
session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
# 如果需要其他存储相关的函数,可以在这里添加
@staticmethod
async def update_message(
message: MessageRecv,
) -> None: # 用于实时更新数据库的自身发送消息ID目前能处理text,reply,image和emoji
"""更新最新一条匹配消息的message_id"""
async def update_message(message):
"""更新消息ID"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
mmc_message_id = message.message_info.message_id  # 修复:正确访问 message_id
if message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
if matched_message:
session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
# session.commit() 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@@ -178,10 +188,12 @@ class MessageStorage:
def replace_match(match):
description = match.group(1).strip()
try:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -7,13 +7,14 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person,get_person_id
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, and_
install(extra_lines=3)
session = get_session()
def replace_user_references_sync(
content: str,
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
))
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
))
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
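Two details of this translation are worth flagging: each branch re-executes the full unordered select before building the ordered one, and action.__dict__ on an ORM instance also carries SQLAlchemy's internal _sa_instance_state key, which Peewee's __data__ did not. A tighter equivalent composes one statement and converts rows explicitly (a sketch):

from sqlalchemy import and_, select

def row_to_dict(obj) -> dict:
    """Column values only, without SQLAlchemy's _sa_instance_state."""
    return {c.name: getattr(obj, c.name) for c in obj.__table__.columns}

def fetch_actions(session, chat_id: str, start: float, end: float,
                  limit: int = 0, limit_mode: str = "latest") -> list[dict]:
    stmt = select(ActionRecords).where(
        and_(ActionRecords.chat_id == chat_id,
             ActionRecords.time > start,
             ActionRecords.time < end)
    )
    if limit > 0 and limit_mode == "latest":
        stmt = stmt.order_by(ActionRecords.time.desc()).limit(limit)
        rows = list(session.execute(stmt).scalars())
        rows.reverse()  # newest-first query, flipped back to ascending time
    else:
        stmt = stmt.order_by(ActionRecords.time.asc())
        if limit > 0:
            stmt = stmt.limit(limit)
        rows = list(session.execute(stmt).scalars())
    return [row_to_dict(r) for r in rows]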
def get_raw_msg_by_timestamp_random(
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
if image and image.description:
description = image.description
except Exception:
@@ -813,7 +854,7 @@ def build_readable_messages(
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_actions: bool = True,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
@@ -846,21 +887,21 @@ def build_readable_messages(
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
actions_in_range = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id
)
.order_by(ActionRecords.time)
)
).order_by(ActionRecords.time)).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
action_after_latest = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time > max_time,
ActionRecords.chat_id == chat_id
)
).order_by(ActionRecords.time).limit(1)).scalars()
# 合并两部分动作记录
actions = list(actions_in_range) + list(action_after_latest)

View File

@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数用于在非异步环境中调用异步数据库API
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建新的事件循环
import threading
result = None
exception = None
def run_in_thread():
nonlocal result, exception
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result = new_loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
new_loop.close()
except Exception as e:
exception = e
thread = threading.Thread(target=run_in_thread)
thread.start()
thread.join()
if exception:
raise exception
return result
else:
return loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
except RuntimeError:
# 没有事件循环,创建一个新的
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
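The wrapper exists because asyncio.run refuses to start inside a thread whose event loop is already running, so the code falls back to spawning a helper thread with its own loop. The same pattern, condensed (a sketch):

import asyncio
import threading

def run_coro_blocking(coro):
    """Run a coroutine to completion from sync code, even inside a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no loop in this thread: the simple case
    result, error = None, None
    def worker():
        nonlocal result, error
        try:
            result = asyncio.run(coro)  # fresh loop in a fresh thread
        except Exception as e:
            error = e
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    if error:
        raise error
    return result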
# 统计数据的键
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: int | None = None # Changed to int for Peewee's default ID
self.record_id: int | None = None
"""记录ID"""
self._init_database() # 初始化数据库
@staticmethod
def _init_database():
"""初始化数据库"""
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
if self.record_id:
# 如果有记录,则更新结束时间
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
updated_rows = await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
self.record_id = None
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
if not self.record_id:
# 查找最近一分钟内的记录
recent_threshold = current_time - timedelta(minutes=1)
recent_records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": recent_threshold}},
order_by="-end_timestamp",
limit=1,
single_result=True
)
if recent_record:
# 如果有记录,更新结束时间
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # 添加此行
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # 初始时长为5分钟
if recent_records:
# 找到近期记录,更新
self.record_id = recent_records['id']
await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
self.record_id = new_record.id
else:
# 创建新记录
new_record = await db_save(
model_class=OnlineTime,
data={
"timestamp": str(current_time),
"duration": 5, # 初始时长为5分钟
"start_timestamp": current_time,
"end_timestamp": extended_end_time,
}
)
if new_record:
self.record_id = new_record['id']
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_timestamp = record['timestamp'] # 从字典中获取
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
request_type = record.get('request_type') or "unknown"
user_id = record.get('user_id') or "unknown"
model_name = record.get('model_name') or "unknown"
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
module_name = request_type.split(".")[0] if "." in request_type else request_type
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
prompt_tokens = record.get('prompt_tokens') or 0
completion_tokens = record.get('completion_tokens') or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
records = _sync_db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp"
)
for record in records:
record_end_timestamp = record['end_timestamp']
record_start_timestamp = record['start_timestamp']
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time'] # This is a float timestamp
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get('chat_info_group_id'):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
chat_id = f"u{message['user_id']}" # SENDER's user_id
chat_name = message.get('user_nickname') # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
)
continue
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_time = record['timestamp']
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
model_name = record.get('model_name') or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
request_type = record.request_type or "unknown"
request_type = record.get('request_type') or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time']
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.chat_info_group_id:
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id:
chat_name = message.user_nickname or f"用户{message.user_id}"
if message.get('chat_info_group_id'):
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'):
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
else:
continue

View File

@@ -13,10 +13,12 @@ from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import get_db_session
from sqlalchemy import select, and_
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -41,9 +43,10 @@ class ImageManager:
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
# 使用SQLAlchemy创建表已在初始化时完成
logger.debug("使用SQLAlchemy进行表管理")
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
logger.error(f"数据库连接失败: {e}")
self._initialized = True
@@ -63,12 +66,13 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
record = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None
@staticmethod
@@ -82,16 +86,28 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
# 查找现有记录
existing = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
if existing:
# 更新现有记录
existing.description = description
existing.timestamp = current_timestamp
else:
# 创建新记录
new_desc = ImageDescriptions(
image_description_hash=image_hash,
type=description_type,
description=description,
timestamp=current_timestamp
)
session.add(new_desc)
# session.commit() 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -214,19 +230,29 @@ class ImageManager:
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = detailed_description # 保存详细描述
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist: # type: ignore
Images.create(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
existing_img = session.execute(select(Images).where(
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
)).scalar()
if existing_img:
existing_img.path = file_path
existing_img.description = detailed_description # 保存详细描述
existing_img.timestamp = current_timestamp
else:
new_img = Images(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
session.add(new_img)
# session.commit() 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
except Exception as e:
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
@@ -249,19 +275,19 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
@@ -300,10 +326,10 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
session.commit()
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
new_img = Images(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
@@ -313,6 +339,8 @@ class ImageManager:
vlm_processed=True,
count=1,
)
session.add(new_img)
session.commit()
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -465,31 +493,32 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
existing_image.count += 1
session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -503,7 +532,7 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
new_img = Images(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
@@ -512,6 +541,8 @@ class ImageManager:
vlm_processed=False,
count=1,
)
session.add(new_img)
session.commit()
# 启动异步VLM处理
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
@@ -536,60 +567,64 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
# 获取当前图片记录
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = session.execute(select(Images).where(
and_(
Images.emoji_hash == image_hash,
Images.description.isnot(None),
Images.description != "",
Images.id != image.id
)
)).scalar()
if existing_with_description:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
session.commit()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
session.commit()
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
return
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
except Exception as e:
logger.error(f"VLM处理图片失败: {str(e)}")

View File

@@ -1,14 +1,103 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_session
install(extra_lines=3)
_client = None
_db = None
_sql_engine = None
logger = get_logger("database")
# 兼容性为了不破坏现有代码保留db变量但指向SQLAlchemy
class DatabaseProxy:
"""数据库代理类提供Peewee到SQLAlchemy的兼容性接口"""
def __init__(self):
self._engine = None
self._session = None
def initialize(self, *args, **kwargs):
"""初始化数据库连接"""
return initialize_database_compat()
def connect(self, reuse_if_open=True):
"""连接数据库(兼容性方法)"""
try:
self._engine = get_engine()
return True
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return False
def is_closed(self):
"""检查数据库是否关闭(兼容性方法)"""
return self._engine is None
def create_tables(self, models, safe=True):
"""创建表(兼容性方法)"""
try:
from src.common.database.sqlalchemy_models import Base
engine = get_engine()
Base.metadata.create_all(bind=engine)
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def table_exists(self, model):
"""检查表是否存在(兼容性方法)"""
try:
from sqlalchemy import inspect
engine = get_engine()
inspector = inspect(engine)
            meta = getattr(model, '_meta', None)  # Peewee 模型的 Meta 不是 dict,不能调用 .get
            table_name = getattr(meta, 'table_name', None) or model.__name__.lower()
            return table_name in inspector.get_table_names()
except Exception:
return False
def execute_sql(self, sql):
"""执行SQL兼容性方法"""
try:
from sqlalchemy import text
session = get_session()
result = session.execute(text(sql))
session.close()
return result
except Exception as e:
logger.error(f"执行SQL失败: {e}")
raise
def atomic(self):
"""事务上下文管理器(兼容性方法)"""
return SQLAlchemyTransaction()
class SQLAlchemyTransaction:
"""SQLAlchemy事务上下文管理器"""
def __init__(self):
self.session = None
def __enter__(self):
self.session = get_session()
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.session.commit()
else:
self.session.rollback()
self.session.close()
# 创建全局数据库代理实例
db = DatabaseProxy()
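With the proxy in place, legacy Peewee-style call sites such as the ones in ImageManager and ChatManager keep working without modification, even though every call is now served by SQLAlchemy. For instance (a sketch; note that create_tables ignores its model list and creates all tables registered on Base, and some_record is a placeholder ORM instance):

# Legacy Peewee-style call sites keep working against the proxy:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)  # model list accepted but all Base tables are created
with db.atomic() as session:  # commits on success, rolls back on exception
    session.add(some_record)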
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
@@ -39,7 +128,7 @@ def __create_database_instance():
def get_db():
"""获取数据库连接实例,延迟初始化。"""
"""获取MongoDB连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
@@ -47,6 +136,47 @@ def get_db():
return _db
def initialize_sql_database(database_config):
"""
根据配置初始化SQL数据库连接SQLAlchemy版本
Args:
database_config: DatabaseConfig对象
"""
global _sql_engine
try:
logger.info("使用SQLAlchemy初始化SQL数据库...")
# 记录数据库配置信息
if database_config.database_type == "mysql":
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
logger.info("MySQL数据库连接配置:")
logger.info(f" 连接信息: {connection_info}")
logger.info(f" 字符集: {database_config.mysql_charset}")
else:
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if not os.path.isabs(database_config.sqlite_path):
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
else:
db_path = database_config.sqlite_path
logger.info("SQLite数据库连接配置:")
logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化
success = initialize_database_compat()
if success:
_sql_engine = get_engine()
logger.info("SQLAlchemy数据库初始化成功")
else:
logger.error("SQLAlchemy数据库初始化失败")
return _sql_engine
except Exception as e:
logger.error(f"初始化SQL数据库失败: {e}")
return None
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
@@ -57,26 +187,6 @@ class DBWrapper:
return get_db()[key] # type: ignore
# 全局数据库访问点
# 全局MongoDB数据库访问点
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(
_DB_FILE,
pragmas={
"journal_mode": "wal", # WAL模式提高并发性能
"cache_size": -64 * 1000, # 64MB缓存
"foreign_keys": 1,
"ignore_check_constraints": 0,
"synchronous": 0, # 异步写入提高性能
"busy_timeout": 1000, # 1秒超时而不是3秒
},
)

View File

@@ -0,0 +1,420 @@
"""SQLAlchemy数据库API模块
提供基于SQLAlchemy的数据库操作替换Peewee以解决MySQL连接问题
支持自动重连、连接池管理和更好的错误处理
"""
import traceback
import time
from typing import Dict, List, Any, Union, Type, Optional
from contextlib import contextmanager
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
from sqlalchemy import desc, asc, func, and_, or_
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
    Base, Messages, ActionRecords, PersonInfo, ChatStreams,
    LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
    Expression, ThinkingLog, GraphNodes, GraphEdges, get_session
)
logger = get_logger("sqlalchemy_database_api")
# 模型映射表,用于通过名称获取模型类
MODEL_MAPPING = {
'Messages': Messages,
'ActionRecords': ActionRecords,
'PersonInfo': PersonInfo,
'ChatStreams': ChatStreams,
'LLMUsage': LLMUsage,
'Emoji': Emoji,
'Images': Images,
'ImageDescriptions': ImageDescriptions,
'OnlineTime': OnlineTime,
'Memory': Memory,
'Expression': Expression,
'ThinkingLog': ThinkingLog,
'GraphNodes': GraphNodes,
'GraphEdges': GraphEdges,
}
@contextmanager
def get_db_session():
    """数据库会话上下文管理器,自动处理事务和连接错误"""
    max_retries = 3
    retry_delay = 1.0
    session = None
    # 注意:@contextmanager 生成器只能 yield 一次,不能把 yield 放进重试循环
    # (第二次 yield 会抛 RuntimeError),因此重试只包住会话获取这一步
    for attempt in range(max_retries):
        try:
            session = get_session()
            break
        except (DisconnectionError, OperationalError) as e:
            logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (attempt + 1))
            else:
                raise
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
for field_name, value in filters.items():
if not hasattr(model_class, field_name):
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
continue
field = getattr(model_class, field_name)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(~field.in_(op_value))
else:
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
else:
# 直接相等比较
conditions.append(field == value)
return conditions
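These Mongo-style operator dicts let callers express comparisons without importing column objects. For example, the statistics task above queries recent LLM usage like this (illustrative values; db_get is the async read helper defined below, and it accepts "-field" strings for descending order as the call sites show):

from datetime import datetime, timedelta

# inside an async function:
one_hour_ago = datetime.now() - timedelta(hours=1)
records = await db_get(
    LLMUsage,
    filters={"timestamp": {"$gte": one_hour_ago}, "model_name": {"$ne": "unknown"}},
    order_by="-timestamp",  # leading '-' selects descending order
)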
async def db_query(
model_class: Type[Base],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
query_type: 查询类型 ("get", "create", "update", "delete", "count")
filters: 过滤条件字典
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
with get_db_session() as session:
if query_type == "get":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 应用排序
if order_by:
for field_name in order_by:
if field_name.startswith("-"):
field_name = field_name[1:]
if hasattr(model_class, field_name):
query = query.order_by(desc(getattr(model_class, field_name)))
else:
if hasattr(model_class, field_name):
query = query.order_by(asc(getattr(model_class, field_name)))
# 应用限制
if limit and limit > 0:
query = query.limit(limit)
# 执行查询
results = query.all()
# 转换为字典格式
result_dicts = []
for result in results:
result_dict = {}
for column in result.__table__.columns:
result_dict[column.name] = getattr(result, column.name)
result_dicts.append(result_dict)
if single_result:
return result_dicts[0] if result_dicts else None
return result_dicts
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush() # 获取自动生成的ID
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行更新
affected_rows = query.update(data)
return affected_rows
elif query_type == "delete":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行删除
affected_rows = query.delete()
return affected_rows
elif query_type == "count":
query = session.query(func.count(model_class.id))
# 应用过滤条件
if filters:
base_query = session.query(model_class)
conditions = build_filters(session, model_class, filters)
if conditions:
base_query = base_query.filter(and_(*conditions))
query = session.query(func.count()).select_from(base_query.subquery())
return query.scalar()
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 意外错误: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
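
# 示意性示例(非项目原有代码,chat_id 为虚构值):演示 db_query 的典型调用,
# 覆盖带操作符的过滤、降序排序与计数三种用法。
async def _demo_db_query():
    recent = await db_query(
        Messages,
        query_type="get",
        filters={"chat_id": "stream_abc", "time": {"$gt": 0}},
        order_by=["-time"],
        limit=5,
    )
    total = await db_query(Messages, query_type="count", filters={"chat_id": "stream_abc"})
    logger.debug(f"示例查询: 最近 {len(recent or [])} 条, 总计 {total} 条")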
async def db_save(
model_class: Type[Base],
data: Dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
"""保存数据到数据库(创建或更新)
Args:
model_class: SQLAlchemy模型类
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名
key_value: 用于查找现有记录的字段值
Returns:
保存后的记录数据或None
"""
try:
with get_db_session() as session:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if hasattr(model_class, key_field):
existing_record = session.query(model_class).filter(
getattr(model_class, key_field) == key_value
).first()
if existing_record:
# 更新现有记录
for field, value in data.items():
if hasattr(existing_record, field):
setattr(existing_record, field, value)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in existing_record.__table__.columns:
result_dict[column.name] = getattr(existing_record, column.name)
return result_dict
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
traceback.print_exc()
return None
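
# 示意性示例(非项目原有代码,取值均为虚构):演示 db_save 的
# "存在则更新、不存在则创建" 语义,以 person_id 作为查找键。
async def _demo_db_save_upsert():
    return await db_save(
        PersonInfo,
        data={"person_id": "demo_pid", "platform": "qq", "user_id": "10001", "nickname": "示例昵称"},
        key_field="person_id",
        key_value="demo_pid",
    )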
async def db_get(
model_class: Type[Base],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
order_by_list = [order_by] if order_by else None
return await db_query(
model_class=model_class,
query_type="get",
filters=filters,
limit=limit,
order_by=order_by_list,
single_result=single_result
)
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
Args:
chat_stream: 聊天流对象
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
保存的记录数据或None
"""
try:
import json
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update({
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
})
else:
record_data.update({
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
})
# 保存记录
saved_record = await db_save(
ActionRecords,
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 兼容性函数方便从Peewee迁移
def get_model_class(model_name: str) -> Optional[Type[Base]]:
"""根据模型名称获取模型类"""
return MODEL_MAPPING.get(model_name)
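
# 示意性示例(非项目原有代码):演示按名称动态获取模型类再查询,
# 便于从 Peewee 时代按字符串分发的调用方式平滑迁移。
async def _demo_get_model_class():
    model = get_model_class("Messages")
    if model is None:
        return []
    return await db_get(model, filters={"is_emoji": True}, limit=3, order_by="-time")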

View File

@@ -0,0 +1,158 @@
"""SQLAlchemy数据库初始化模块
替换Peewee的数据库初始化逻辑
提供统一的数据库初始化接口
"""
from typing import Optional
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_engine, get_session, initialize_database
)
logger = get_logger("sqlalchemy_init")
def initialize_sqlalchemy_database() -> bool:
"""
初始化SQLAlchemy数据库
创建所有表结构
Returns:
bool: 初始化是否成功
"""
try:
logger.info("开始初始化SQLAlchemy数据库...")
# 初始化数据库引擎和会话
engine, session_local = initialize_database()
if engine is None:
logger.error("数据库引擎初始化失败")
return False
logger.info("SQLAlchemy数据库初始化成功")
return True
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
return False
except Exception as e:
logger.error(f"数据库初始化过程中发生未知错误: {e}")
return False
def create_all_tables() -> bool:
"""
创建所有数据库表
Returns:
bool: 创建是否成功
"""
try:
logger.info("开始创建数据库表...")
engine = get_engine()
if engine is None:
logger.error("无法获取数据库引擎")
return False
# 创建所有表
Base.metadata.create_all(bind=engine)
logger.info("数据库表创建成功")
return True
except SQLAlchemyError as e:
logger.error(f"创建数据库表失败: {e}")
return False
except Exception as e:
logger.error(f"创建数据库表过程中发生未知错误: {e}")
return False
def check_database_connection() -> bool:
"""
检查数据库连接是否正常
Returns:
bool: 连接是否正常
"""
    try:
        session = get_session()
        if session is None:
            logger.error("无法获取数据库会话")
            return False
        # 执行一条轻量查询,真正验证连接可用,而不是仅检查会话对象非空
        session.execute(text("SELECT 1"))
        session.close()
        logger.info("数据库连接检查通过")
        return True
except SQLAlchemyError as e:
logger.error(f"数据库连接检查失败: {e}")
return False
except Exception as e:
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
return False
def get_database_info() -> Optional[dict]:
"""
获取数据库信息
Returns:
dict: 数据库信息字典,包含引擎信息等
"""
try:
engine = get_engine()
if engine is None:
return None
        # 隐藏密码;注意空字符串调用 str.replace("") 会在每个字符间插入掩码,
        # 因此仅在密码非空时才做替换
        url_str = str(engine.url)
        if engine.url.password:
            url_str = url_str.replace(engine.url.password, '***')
        info = {
            'engine_name': engine.name,
            'driver': engine.driver,
            'url': url_str,
            'pool_size': getattr(engine.pool, 'size', None),
            'max_overflow': getattr(engine.pool, 'max_overflow', None),
        }
return info
except Exception as e:
logger.error(f"获取数据库信息失败: {e}")
return None
_database_initialized = False
def initialize_database_compat() -> bool:
"""
兼容性数据库初始化函数
用于替换原有的Peewee初始化代码
Returns:
bool: 初始化是否成功
"""
global _database_initialized
if _database_initialized:
return True
success = initialize_sqlalchemy_database()
if success:
success = create_all_tables()
if success:
success = check_database_connection()
if success:
_database_initialized = True
return success
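
# 示意性示例(非项目原有代码):在应用启动入口调用兼容初始化函数,
# 初始化失败时直接中止启动。
def _demo_startup_init():
    if not initialize_database_compat():
        raise RuntimeError("数据库初始化失败,无法继续启动")
    logger.info(f"数据库就绪: {get_database_info()}")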

View File

@@ -0,0 +1,555 @@
"""SQLAlchemy数据库模型定义
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
"""
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
import os
import datetime
from src.config.config import global_config
from src.common.logger import get_logger
import threading
from contextlib import contextmanager
logger = get_logger("sqlalchemy_models")
# 创建基类
Base = declarative_base()
# MySQL兼容的字段类型辅助函数
def get_string_field(max_length=255, **kwargs):
"""
根据数据库类型返回合适的字符串字段
MySQL需要指定长度的VARCHAR用于索引SQLite可以使用Text
"""
if global_config.database.database_type == "mysql":
return String(max_length, **kwargs)
else:
return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
            except Exception:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
# commit后不要清除session只是刷新状态
pass # 保持session活跃
return result
                    except Exception:
                        try:
                            if session and hasattr(session, 'rollback'):
                                session.rollback()
                        except Exception:
                            pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
# 测试session是否还可用
                # future=True 引擎下原生 SQL 必须用 text() 包装
                self._local.session.execute(text("SELECT 1"))
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
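
# 示意性示例(非项目原有代码):SessionProxy 为每个线程维护独立的底层会话,
# 下面用两个线程各自执行一条轻量查询来演示这一点。
def _demo_session_proxy():
    def _worker():
        s = get_session()  # 取到的是当前线程本地的 Session
        s.execute(text("SELECT 1"))
        s.close()

    threads = [threading.Thread(target=_worker) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()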
class ChatStreams(Base):
"""聊天流模型"""
__tablename__ = 'chat_streams'
id = Column(Integer, primary_key=True, autoincrement=True)
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
create_time = Column(Float, nullable=False)
group_platform = Column(Text, nullable=True)
group_id = Column(get_string_field(100), nullable=True, index=True)
group_name = Column(Text, nullable=True)
last_active_time = Column(Float, nullable=False)
platform = Column(Text, nullable=False)
user_platform = Column(Text, nullable=False)
user_id = Column(get_string_field(100), nullable=False, index=True)
user_nickname = Column(Text, nullable=False)
user_cardname = Column(Text, nullable=True)
__table_args__ = (
Index('idx_chatstreams_stream_id', 'stream_id'),
Index('idx_chatstreams_user_id', 'user_id'),
Index('idx_chatstreams_group_id', 'group_id'),
)
class LLMUsage(Base):
"""LLM使用记录模型"""
__tablename__ = 'llm_usage'
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(get_string_field(100), nullable=False, index=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
request_type = Column(get_string_field(50), nullable=False, index=True)
endpoint = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
total_tokens = Column(Integer, nullable=False)
cost = Column(Float, nullable=False)
status = Column(Text, nullable=False)
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
__table_args__ = (
Index('idx_llmusage_model_name', 'model_name'),
Index('idx_llmusage_user_id', 'user_id'),
Index('idx_llmusage_request_type', 'request_type'),
Index('idx_llmusage_timestamp', 'timestamp'),
)
class Emoji(Base):
"""表情包模型"""
__tablename__ = 'emoji'
id = Column(Integer, primary_key=True, autoincrement=True)
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
format = Column(Text, nullable=False)
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
query_count = Column(Integer, nullable=False, default=0)
is_registered = Column(Boolean, nullable=False, default=False)
is_banned = Column(Boolean, nullable=False, default=False)
emotion = Column(Text, nullable=True)
record_time = Column(Float, nullable=False)
register_time = Column(Float, nullable=True)
usage_count = Column(Integer, nullable=False, default=0)
last_used_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_emoji_full_path', 'full_path'),
Index('idx_emoji_hash', 'emoji_hash'),
)
class Messages(Base):
"""消息模型"""
__tablename__ = 'messages'
id = Column(Integer, primary_key=True, autoincrement=True)
message_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
reply_to = Column(Text, nullable=True)
interest_value = Column(Float, nullable=True)
is_mentioned = Column(Boolean, nullable=True)
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
chat_info_user_platform = Column(Text, nullable=False)
chat_info_user_id = Column(Text, nullable=False)
chat_info_user_nickname = Column(Text, nullable=False)
chat_info_user_cardname = Column(Text, nullable=True)
chat_info_group_platform = Column(Text, nullable=True)
chat_info_group_id = Column(Text, nullable=True)
chat_info_group_name = Column(Text, nullable=True)
chat_info_create_time = Column(Float, nullable=False)
chat_info_last_active_time = Column(Float, nullable=False)
# 从顶层 user_info 扁平化而来的字段
user_platform = Column(Text, nullable=True)
user_id = Column(get_string_field(100), nullable=True, index=True)
user_nickname = Column(Text, nullable=True)
user_cardname = Column(Text, nullable=True)
processed_plain_text = Column(Text, nullable=True)
display_message = Column(Text, nullable=True)
memorized_times = Column(Integer, nullable=False, default=0)
priority_mode = Column(Text, nullable=True)
priority_info = Column(Text, nullable=True)
additional_config = Column(Text, nullable=True)
is_emoji = Column(Boolean, nullable=False, default=False)
is_picid = Column(Boolean, nullable=False, default=False)
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_messages_message_id', 'message_id'),
Index('idx_messages_chat_id', 'chat_id'),
Index('idx_messages_time', 'time'),
Index('idx_messages_user_id', 'user_id'),
)
class ActionRecords(Base):
"""动作记录模型"""
__tablename__ = 'action_records'
id = Column(Integer, primary_key=True, autoincrement=True)
action_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
action_name = Column(Text, nullable=False)
action_data = Column(Text, nullable=False)
action_done = Column(Boolean, nullable=False, default=False)
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
action_prompt_display = Column(Text, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
__table_args__ = (
Index('idx_actionrecords_action_id', 'action_id'),
Index('idx_actionrecords_chat_id', 'chat_id'),
Index('idx_actionrecords_time', 'time'),
)
class Images(Base):
"""图像信息模型"""
__tablename__ = 'images'
id = Column(Integer, primary_key=True, autoincrement=True)
image_id = Column(Text, nullable=False, default="")
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=True)
path = Column(get_string_field(500), nullable=False, unique=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
type = Column(Text, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_images_emoji_hash', 'emoji_hash'),
Index('idx_images_path', 'path'),
)
class ImageDescriptions(Base):
"""图像描述信息模型"""
__tablename__ = 'image_descriptions'
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(Text, nullable=False)
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
timestamp = Column(Float, nullable=False)
__table_args__ = (
Index('idx_imagedesc_hash', 'image_description_hash'),
)
class OnlineTime(Base):
"""在线时长记录模型"""
__tablename__ = 'online_time'
id = Column(Integer, primary_key=True, autoincrement=True)
    # 注意:default 需要可调用对象以便在插入时求值;str(datetime.datetime.now)
    # 只会得到函数对象的字符串表示,而不是当前时间
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
duration = Column(Integer, nullable=False)
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp = Column(DateTime, nullable=False, index=True)
__table_args__ = (
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
)
class PersonInfo(Base):
"""人物信息模型"""
__tablename__ = 'person_info'
id = Column(Integer, primary_key=True, autoincrement=True)
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
person_name = Column(Text, nullable=True)
name_reason = Column(Text, nullable=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
nickname = Column(Text, nullable=True)
impression = Column(Text, nullable=True)
short_impression = Column(Text, nullable=True)
points = Column(Text, nullable=True)
forgotten_points = Column(Text, nullable=True)
info_list = Column(Text, nullable=True)
know_times = Column(Float, nullable=True)
know_since = Column(Float, nullable=True)
last_know = Column(Float, nullable=True)
attitude = Column(Integer, nullable=True, default=50)
__table_args__ = (
Index('idx_personinfo_person_id', 'person_id'),
Index('idx_personinfo_user_id', 'user_id'),
)
class Memory(Base):
"""记忆模型"""
__tablename__ = 'memory'
id = Column(Integer, primary_key=True, autoincrement=True)
memory_id = Column(get_string_field(64), nullable=False, index=True)
chat_id = Column(Text, nullable=True)
memory_text = Column(Text, nullable=True)
keywords = Column(Text, nullable=True)
create_time = Column(Float, nullable=True)
last_view_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_memory_memory_id', 'memory_id'),
)
class Expression(Base):
"""表达风格模型"""
__tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False)
style = Column(Text, nullable=False)
count = Column(Float, nullable=False)
last_active_time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False)
create_date = Column(Float, nullable=True)
__table_args__ = (
Index('idx_expression_chat_id', 'chat_id'),
)
class ThinkingLog(Base):
"""思考日志模型"""
__tablename__ = 'thinking_logs'
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(get_string_field(64), nullable=False, index=True)
trigger_text = Column(Text, nullable=True)
response_text = Column(Text, nullable=True)
trigger_info_json = Column(Text, nullable=True)
response_info_json = Column(Text, nullable=True)
timing_results_json = Column(Text, nullable=True)
chat_history_json = Column(Text, nullable=True)
chat_history_in_thinking_json = Column(Text, nullable=True)
chat_history_after_response_json = Column(Text, nullable=True)
heartflow_data_json = Column(Text, nullable=True)
reasoning_data_json = Column(Text, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_thinkinglog_chat_id', 'chat_id'),
)
class GraphNodes(Base):
"""记忆图节点模型"""
__tablename__ = 'graph_nodes'
id = Column(Integer, primary_key=True, autoincrement=True)
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphnodes_concept', 'concept'),
)
class GraphEdges(Base):
"""记忆图边模型"""
__tablename__ = 'graph_edges'
id = Column(Integer, primary_key=True, autoincrement=True)
source = Column(get_string_field(255), nullable=False, index=True)
target = Column(get_string_field(255), nullable=False, index=True)
strength = Column(Integer, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphedges_source', 'source'),
Index('idx_graphedges_target', 'target'),
)
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url():
"""获取数据库连接URL"""
config = global_config.database
if config.database_type == "mysql":
# 对用户名和密码进行URL编码处理特殊字符
from urllib.parse import quote_plus
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
else: # SQLite
# 如果是相对路径,则相对于项目根目录
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}"
def initialize_database():
"""初始化数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None:
return _engine, _SessionLocal
database_url = get_database_url()
config = global_config.database
# 配置引擎参数
engine_kwargs = {
'echo': False, # 生产环境关闭SQL日志
'future': True,
}
if config.database_type == "mysql":
# MySQL连接池配置
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': config.connection_pool_size,
'max_overflow': config.connection_pool_size * 2,
'pool_timeout': config.connection_timeout,
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'autocommit': config.mysql_autocommit,
'charset': config.mysql_charset,
'connect_timeout': config.connection_timeout,
'read_timeout': 30,
'write_timeout': 30,
}
})
else:
# SQLite配置 - 添加连接池设置以避免连接耗尽
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': 20, # 增加池大小
'max_overflow': 30, # 增加溢出连接数
'pool_timeout': 60, # 增加超时时间
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'check_same_thread': False,
'timeout': 30,
}
})
_engine = create_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
# 创建所有表
Base.metadata.create_all(bind=_engine)
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal
@contextmanager
def get_db_session():
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session = None
try:
_, SessionLocal = initialize_database()
session = SessionLocal()
yield session
session.commit()
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
def get_engine():
"""获取数据库引擎"""
engine, _ = initialize_database()
return engine
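
# 示意性示例(非项目原有代码,stream_id 为虚构值):推荐的会话用法,
# 离开 with 块时自动 commit/rollback 并归还连接。
def _demo_db_session_usage():
    with get_db_session() as session:
        return session.query(ChatStreams).filter(
            ChatStreams.stream_id == "demo_stream"
        ).first()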

View File

@@ -373,6 +373,7 @@ MODULE_COLORS = {
"base_command": "\033[38;5;208m", # 橙色
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"plugin_hot_reload": "\033[38;5;226m", #品红色
"config_api": "\033[38;5;226m", # 亮黄色
"heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
@@ -406,6 +407,7 @@ MODULE_COLORS = {
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
"database_model": "\033[38;5;94m", # 橙褐色
"database": "\033[38;5;46m", # 橙褐色
"maim_message": "\033[38;5;140m", # 紫褐色
# 日志系统
"logger": "\033[38;5;8m", # 深灰色
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
"memory_activator": "记忆",
"tool_use": "工具",
"expressor": "表达方式",
"plugin_hot_reload": "热重载",
"database": "数据库",
"database_model": "数据库",
"mood": "情绪",
"memory": "记忆",

View File

@@ -1,20 +1,26 @@
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from typing import List, Optional, Any, Dict
from sqlalchemy import not_, select, func
from sqlalchemy.orm import DeclarativeBase
from src.config.config import global_config
from src.common.database.database_model import Messages
# from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_session
from src.common.logger import get_logger
logger = get_logger(__name__)
class Base(DeclarativeBase):
pass
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
def _model_to_dict(instance: Base) -> Dict[str, Any]:
"""
Peewee 模型实例转换为字典。
SQLAlchemy 模型实例转换为字典。
"""
return model_instance.__data__
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
def find_messages(
@@ -38,7 +44,8 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
query = Messages.select()
session = get_session()
query = select(Messages)
# 应用过滤器
if message_filter:
@@ -77,42 +84,57 @@ def find_messages(
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not Messages.is_command)
query = query.where(not_(Messages.is_command))
if limit > 0:
# 确保limit是正整数
limit = max(1, int(limit))
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
peewee_results = list(query)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
peewee_sort_terms = []
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in peewee_results]
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
query = Messages.select()
session = get_session()
query = select(func.count(Messages.id))
# 应用过滤器
if message_filter:
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = query.count()
return count
count = session.execute(query).scalar()
return count or 0
except Exception as e:
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
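
# 示意性示例(非项目原有代码,message_id 等取值均为虚构):上述两种操作的最小写法。
def _demo_insert_and_find_one():
    session = get_session()
    session.add(Messages(
        message_id="demo_mid",
        time=1700000000.0,
        chat_id="demo_chat",
        chat_info_stream_id="demo_chat",
        chat_info_platform="qq",
        chat_info_user_platform="qq",
        chat_info_user_id="10001",
        chat_info_user_nickname="示例",
        chat_info_create_time=1700000000.0,
        chat_info_last_active_time=1700000000.0,
    ))
    session.commit()
    return session.execute(
        select(Messages).where(Messages.message_id == "demo_mid")
    ).scalar_one_or_none()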

View File

@@ -13,6 +13,7 @@ from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
DatabaseConfig,
BotConfig,
PersonalityConfig,
ExpressionConfig,
@@ -340,6 +341,7 @@ class Config(ConfigBase):
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
database: DatabaseConfig
bot: BotConfig
personality: PersonalityConfig
relationship: RelationshipConfig
@@ -466,4 +468,25 @@ update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
# 初始化数据库连接
logger.info("正在初始化数据库连接...")
from src.common.database.database import initialize_sql_database
try:
initialize_sql_database(global_config.database)
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
except Exception as e:
logger.error(f"数据库连接初始化失败: {e}")
raise e
# 初始化数据库表结构
logger.info("正在初始化数据库表结构...")
from src.common.database.sqlalchemy_models import initialize_database as init_db
try:
init_db()
logger.info("数据库表结构初始化完成")
except Exception as e:
logger.error(f"数据库表结构初始化失败: {e}")
raise e
logger.info("非常的新鲜,非常的美味!")

View File

@@ -13,6 +13,65 @@ from src.config.config_base import ConfigBase
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class DatabaseConfig(ConfigBase):
"""数据库配置类"""
database_type: Literal["sqlite", "mysql"] = "sqlite"
"""数据库类型,支持 sqlite 或 mysql"""
# SQLite 配置
sqlite_path: str = "data/MaiBot.db"
"""SQLite数据库文件路径"""
# MySQL 配置
mysql_host: str = "localhost"
"""MySQL服务器地址"""
mysql_port: int = 3306
"""MySQL服务器端口"""
mysql_database: str = "maibot"
"""MySQL数据库名"""
mysql_user: str = "root"
"""MySQL用户名"""
mysql_password: str = ""
"""MySQL密码"""
mysql_charset: str = "utf8mb4"
"""MySQL字符集"""
mysql_unix_socket: str = ""
"""MySQL Unix套接字路径可选用于本地连接优先于host/port"""
# MySQL SSL 配置
mysql_ssl_mode: str = "DISABLED"
"""SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""
mysql_ssl_ca: str = ""
"""SSL CA证书路径"""
mysql_ssl_cert: str = ""
"""SSL客户端证书路径"""
mysql_ssl_key: str = ""
"""SSL客户端密钥路径"""
# MySQL 高级配置
mysql_autocommit: bool = True
"""自动提交事务"""
mysql_sql_mode: str = "TRADITIONAL"
"""SQL模式"""
# 连接池配置
connection_pool_size: int = 10
"""连接池大小仅MySQL有效"""
connection_timeout: int = 10
"""连接超时时间(秒)"""
@dataclass
class BotConfig(ConfigBase):
@@ -72,6 +131,19 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
replyer_random_probability: float = 0.5
"""
发言时选择推理模型的概率0-1之间
选择普通模型的概率为 1 - reasoning_normal_model_probability
"""
thinking_timeout: int = 40
"""麦麦最长思考规划时间超过这个时间的思考会放弃往往是api反应太慢"""
talk_frequency: float = 1
"""回复频率阈值"""
mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复"""
@@ -93,17 +165,17 @@ class ChatConfig(ConfigBase):
"""
统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
全局配置示例:
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
特定聊天流配置示例:
[
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
]
说明:
- 当第一个元素为空字符串""时,表示全局默认配置
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
@@ -155,72 +227,11 @@ class ChatConfig(ConfigBase):
# 检查全局时段配置(第一个元素为空字符串的配置)
global_frequency = self._get_global_frequency()
return self.talk_frequency if global_frequency is None else global_frequency
def _get_global_focus_value(self) -> Optional[float]:
"""
获取全局默认专注度配置
if global_frequency is not None:
return global_frequency
Returns:
float: 专注度值,如果没有配置则返回 None
"""
for config_item in self.focus_value_adjust:
if not config_item or len(config_item) < 2:
continue
# 检查是否为全局默认配置(第一个元素为空字符串)
if config_item[0] == "":
return self._get_time_based_focus_value(config_item[1:])
return None
def _get_time_based_focus_value(self, time_focus_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的专注度
Args:
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
Returns:
float: 专注度值,如果没有配置则返回 None
"""
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
current_hour, current_minute = map(int, current_time.split(":"))
current_minutes = current_hour * 60 + current_minute
# 解析时间专注度配置
time_focus_pairs = []
for time_focus_str in time_focus_list:
try:
time_str, focus_str = time_focus_str.split(",")
hour, minute = map(int, time_str.split(":"))
focus_value = float(focus_str)
minutes = hour * 60 + minute
time_focus_pairs.append((minutes, focus_value))
except (ValueError, IndexError):
continue
if not time_focus_pairs:
return None
# 按时间排序
time_focus_pairs.sort(key=lambda x: x[0])
# 查找当前时间对应的专注度
current_focus_value = None
for minutes, focus_value in time_focus_pairs:
if current_minutes >= minutes:
current_focus_value = focus_value
else:
break
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
if current_focus_value is None and time_focus_pairs:
current_focus_value = time_focus_pairs[-1][1]
return current_focus_value
# 如果都没有匹配,返回默认值
return self.talk_frequency
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
"""
@@ -395,6 +406,14 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class NormalChatConfig(ConfigBase):
"""普通聊天配置类"""
willing_mode: str = "classical"
"""意愿模式"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
@@ -403,14 +422,14 @@ class ExpressionConfig(ConfigBase):
"""
表达学习配置列表,支持按聊天流配置
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
示例:
[
["", "enable", "enable", 1.0], # 全局配置使用表达启用学习学习强度1.0
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置使用表达启用学习学习强度1.5
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置使用表达禁用学习学习强度0.5
]
说明:
- 第一位: chat_stream_id空字符串表示全局配置
- 第二位: 是否使用学到的表达 ("enable"/"disable")
@@ -475,14 +494,14 @@ class ExpressionConfig(ConfigBase):
# 优先检查聊天流特定的配置
if chat_stream_id:
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
if specific_expression_config is not None:
return specific_expression_config
specific_config = self._get_stream_specific_config(chat_stream_id)
if specific_config is not None:
return specific_config
# 检查全局配置(第一个元素为空字符串的配置)
global_expression_config = self._get_global_config()
if global_expression_config is not None:
return global_expression_config
global_config = self._get_global_config()
if global_config is not None:
return global_config
# 如果都没有匹配,返回默认值
return True, True, 300
@@ -518,10 +537,10 @@ class ExpressionConfig(ConfigBase):
# 解析配置
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity: float = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
@@ -541,10 +560,10 @@ class ExpressionConfig(ConfigBase):
# 检查是否为全局配置(第一个元素为空字符串)
if config_item[0] == "":
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
@@ -558,7 +577,6 @@ class ToolConfig(ConfigBase):
enable_tool: bool = False
"""是否在聊天中启用工具"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@@ -703,7 +721,6 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class CustomPromptConfig(ConfigBase):
"""自定义提示词配置类"""
@@ -852,3 +869,4 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""

View File

@@ -5,8 +5,7 @@ from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.common.database.sqlalchemy_models import LLMUsage, get_session
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
@@ -143,16 +142,9 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
class LLMUsageRecorder:
"""
LLM使用情况记录器
LLM使用情况记录器SQLAlchemy版本
"""
def __init__(self):
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
@@ -160,9 +152,13 @@ class LLMUsageRecorder:
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
session = None
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
# 使用 SQLAlchemy 会话创建记录
session = get_session()
usage_record = LLMUsage(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
model_api_provider=model_info.api_provider,
@@ -175,8 +171,12 @@ class LLMUsageRecorder:
cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
)
session.add(usage_record)
session.commit()
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
@@ -184,6 +184,11 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
if session:
session.rollback()
logger.error(f"记录token使用情况失败: {str(e)}")
finally:
if session:
session.close()
llm_usage_recorder = LLMUsageRecorder()

View File

@@ -1,5 +1,7 @@
import asyncio
import time
import signal
import sys
from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask
@@ -17,8 +19,9 @@ from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server
# 导入新的插件管理器
# 导入新的插件管理器和热重载管理器
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
@@ -48,6 +51,28 @@ class MainSystem:
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
# 设置信号处理器用于优雅退出
self._setup_signal_handlers()
def _setup_signal_handlers(self):
"""设置信号处理器"""
def signal_handler(signum, frame):
logger.info("收到退出信号,正在优雅关闭系统...")
self._cleanup()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def _cleanup(self):
"""清理资源"""
try:
# 停止插件热重载系统
hot_reload_manager.stop()
logger.info("🛑 插件热重载系统已停止")
except Exception as e:
logger.error(f"停止热重载系统时出错: {e}")
async def initialize(self):
"""初始化系统组件"""
logger.info(f"正在唤醒{global_config.bot.nickname}......")
@@ -58,14 +83,7 @@ class MainSystem:
logger.info(f"""
--------------------------------
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
--------------------------------
如果想要自定义{global_config.bot.nickname}的功能,请查阅https://docs.mai-mai.org/manual/usage/
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
--------------------------------
如果你想要编写或了解插件相关内容请访问开发文档https://docs.mai-mai.org/develop/
--------------------------------
如果你需要查阅模型的消耗以及麦麦的统计数据请访问根目录的maibot_statistics.html文件
""")
--------------------------------""")
async def _init_components(self):
"""初始化其他组件"""
@@ -87,6 +105,10 @@ class MainSystem:
# 加载所有actions包括默认的和插件的
plugin_manager.load_all_plugins()
# 启动插件热重载系统
hot_reload_manager.start()
# 初始化表情管理器
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")

View File

View File

@@ -2,17 +2,17 @@ import hashlib
import asyncio
import json
import time
import random
from json_repair import repair_json
from typing import Union
from typing import Any, Callable, Dict, Union, Optional
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.common.database.sqlalchemy_models import PersonInfo
from src.common.database.sqlalchemy_database_api import get_session
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
session = get_session()
logger = get_logger("person_info")
@@ -380,36 +380,282 @@ class Person:
return relation_info
# 统一的会话管理函数
def with_session(func):
"""装饰器为函数自动注入session参数"""
if asyncio.iscoroutinefunction(func):
async def async_wrapper(*args, **kwargs):
return await func(session, *args, **kwargs)
return async_wrapper
else:
def sync_wrapper(*args, **kwargs):
return func(session, *args, **kwargs)
return sync_wrapper
# 全局会话获取函数用于替换所有裸露的session使用
def _get_session():
"""获取数据库会话的统一函数"""
return get_session()
class PersonInfoManager:
def __init__(self):
"""初始化PersonInfoManager"""
from src.common.database.sqlalchemy_models import PersonInfo
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
# 设置连接池参数仅对SQLite有效
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
# 检查数据库类型只对SQLite执行PRAGMA语句
if global_config.database.database_type == "sqlite":
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
from src.common.database.sqlalchemy_models import PersonInfo
# 在这里获取会话
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_not(None)
)).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e:
logger.error(f"Peewee 加载 person_name_list 失败: {e}")
logger.error(f"SQLAlchemy 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
platform = platform.split("-")[1]
components = [platform, str(user_id)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
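
    # 示意性示例(非项目原有代码):同一 platform+user_id 总是映射到同一个 person_id;
    # 带适配器前缀的平台名(如假设的 "napcat-qq")会先归一化为 "qq",因此:
    # PersonInfoManager.get_person_id("qq", 10001)
    #   == PersonInfoManager.get_person_id("napcat-qq", "10001")  # 同一 32 位 md5 hex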
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
def _db_check_known_sync(p_id: str):
# 在需要时获取会话
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
# 在需要时获取会话
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
return ""
@staticmethod
async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
final_data = {"person_id": person_id}
# Start with defaults for all model fields
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
"""安全地创建用户信息,处理竞态条件"""
if not person_id:
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
final_data = {"person_id": person_id}
# Start with defaults for all model fields
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(p_data: dict):
try:
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义的字段。")
return
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
def _db_update_sync(p_id: str, f_name: str, val_to_set):
start_time = time.time()
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
session.commit()
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
return True, False # Found and updated, no creation needed
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
session.rollback()
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
if needs_creation:
logger.info(f"{person_id} 不存在,将新建。")
creation_data = data if data is not None else {}
# Ensure platform and user_id are present for context if available from 'data'
# but primarily, set the field that triggered the update.
# The create_person_info will handle defaults and serialization.
creation_data[field_name] = value # Pass original value to create_person_info
# Ensure platform and user_id are in creation_data if available,
# otherwise create_person_info will use defaults.
if data and "platform" in data:
creation_data["platform"] = data["platform"]
if data and "user_id" in data:
creation_data["user_id"] = data["user_id"]
# 使用安全的创建方法,处理竞态条件
await self._safe_create_person_info(person_id, creation_data)
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
return False
@staticmethod
def _extract_json_from_text(text: str) -> dict:
@@ -513,12 +759,13 @@ class PersonInfoManager:
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
person.person_name = generated_nickname
person.name_reason = result.get("reason", "未提供理由")
@@ -547,4 +794,304 @@ class PersonInfoManager:
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
person_info_manager = PersonInfoManager()
@staticmethod
async def del_one_document(person_id: str):
"""删除指定 person_id 的文档"""
if not person_id:
logger.debug("删除失败person_id 不能为空")
return
def _db_delete_sync(p_id: str):
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
session.delete(record)
session.commit()
return 1
return 0
except Exception as e:
session.rollback()
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
        if deleted_count > 0:
            logger.debug(f"删除成功: person_id={person_id} (SQLAlchemy)")
        else:
            logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (SQLAlchemy)")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return []
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_values获取失败person_id不能为空")
return {}
result = {}
def _db_get_record_sync(p_id: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
for field_name in field_names:
if field_name not in model_fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result
@staticmethod
async def get_specific_value_list(
field_name: str,
way: Callable[[Any], bool],
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
except Exception as e_query:
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e:
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {}
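# 用法示意(假设性示例,非本提交代码):用任意谓词筛选字段值,返回 {person_id: value} 字典
#   active = await PersonInfoManager.get_specific_value_list(
#       "know_times", lambda v: v is not None and v > 10
#   )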
async def get_or_create_person(
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。
使用try-except处理竞态条件避免重复创建错误。
"""
person_id = self.get_person_id(platform, user_id)
def _db_get_or_create_sync(p_id: str, init_data: dict):
"""原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
session.commit()
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
except Exception as e:
session.rollback()
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
unique_nickname = await self._generate_unique_person_name(nickname)
initial_data = {
"person_id": person_id,
"platform": platform,
"user_id": str(user_id),
"nickname": nickname,
"person_name": unique_nickname, # 使用群昵称作为person_name
"name_reason": "从群昵称获取",
"know_times": 0,
"know_since": int(datetime.datetime.now().timestamp()),
"last_know": int(datetime.datetime.now().timestamp()),
"impression": None,
"points": [],
"forgotten_points": [],
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
else:
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
return person_id
async def get_person_info_by_name(self, person_name: str) -> dict | None:
"""根据 person_name 查找用户并返回基本信息 (如果找到)"""
if not person_name:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None
found_person_id = None
for pid, name_in_cache in self.person_name_list.items():
if name_in_cache == person_name:
found_person_id = pid
break
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record:
found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
valid_fields_to_get = [
f
for f in required_fields
if f in model_fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None
person_info_manager = None
def get_person_info_manager():
global person_info_manager
if person_info_manager is None:
person_info_manager = PersonInfoManager()
return person_info_manager
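# 用法示意(假设性演示,非本提交代码;QQ 号沿用本仓库配置示例中的占位值):
async def _demo_person_info():
    pm = get_person_info_manager()
    person_id = await pm.get_or_create_person("qq", 1145141919810, nickname="麦麦", user_cardname="麦麦")
    info = await pm.get_person_info_by_name("麦麦")  # 按 person_name 反查基本信息
    return person_id, info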

View File

@@ -5,385 +5,25 @@
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
注意此模块现在使用SQLAlchemy实现提供更好的连接管理和错误处理
"""
import traceback
from typing import Dict, List, Any, Union, Type, Optional
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
from src.common.database.sqlalchemy_database_api import (
db_query,
db_save,
db_get,
store_action_info,
get_model_class,
MODEL_MAPPING
)
logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数
# =============================================================================
async def db_query(
model_class: Type[Model],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
data: 用于创建或更新的数据字典
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
return None if query_type == "get" and single_result else []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if existing_records := list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
):
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。
Args:
chat_stream: 聊天流对象,包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
try:
import time
import json
from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
# 如果没有chat_stream设置默认值
record_data.update(
{
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
}
)
# 使用已有的db_save函数保存记录
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 保持向后兼容性
__all__ = [
'db_query',
'db_save',
'db_get',
'store_action_info',
'get_model_class',
'MODEL_MAPPING'
]
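# 用法示意(假设性演示,非本提交代码):先 db_save 创建/更新记录,再 db_get 读回同一条;
# 字段名沿用上文 docstring 示例中的 ActionRecords 字段
async def _demo_db_roundtrip():
    import time
    from src.common.database.database_model import ActionRecords

    saved = await db_save(
        ActionRecords,
        data={"action_id": "demo_1", "time": time.time(), "action_name": "DemoAction"},
        key_field="action_id",
        key_value="demo_1",
    )
    fetched = await db_get(ActionRecords, filters={"action_id": "demo_1"}, single_result=True)
    return saved, fetched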

View File

@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"hot_reload_manager",
]
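# 用法示意(假设性示例,非本提交代码):通过本包导出的单例启动插件热重载并查看状态
#   from src.plugin_system.core import hot_reload_manager
#   hot_reload_manager.start()
#   print(hot_reload_manager.get_status())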

View File

@@ -237,35 +237,55 @@ class ComponentRegistry:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
# 根据组件类型进行特定的清理操作
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name)
self._default_actions.pop(component_name)
# 移除Action注册
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
logger.debug(f"已移除Action组件: {component_name}")
case ComponentType.COMMAND:
self._command_registry.pop(component_name)
# 移除Command注册和模式
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
self._command_patterns.pop(key, None)
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
# 移除Tool注册
self._tool_registry.pop(component_name, None)
self._llm_available_tools.pop(component_name, None)
logger.debug(f"已移除Tool组件: {component_name}")
case ComponentType.EVENT_HANDLER:
# 移除EventHandler注册和事件订阅
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name)
self._enabled_event_handlers.pop(component_name)
await events_manager.unregister_event_subscriber(component_name)
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
try:
await events_manager.unregister_event_subscriber(component_name)
logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
case _:
logger.warning(f"未知的组件类型: {component_type}")
return False
# 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name)
self._components_by_type[component_type].pop(component_name)
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None)
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
return True
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
@@ -615,5 +635,54 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# === 组件移除相关 ===
async def unregister_plugin(self, plugin_name: str) -> bool:
"""卸载插件及其所有组件
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功卸载
"""
plugin_info = self.get_plugin_info(plugin_name)
if not plugin_info:
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
return False
logger.info(f"开始卸载插件: {plugin_name}")
# 记录卸载失败的组件
failed_components = []
# 逐个移除插件的所有组件
for component_info in plugin_info.components:
try:
success = await self.remove_component(
component_info.name,
component_info.component_type,
plugin_name,
)
if not success:
failed_components.append(f"{component_info.component_type}.{component_info.name}")
except Exception as e:
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
failed_components.append(f"{component_info.component_type}.{component_info.name}")
# 移除插件注册信息
plugin_removed = self.remove_plugin_registry(plugin_name)
if failed_components:
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
return False
elif not plugin_removed:
logger.error(f"插件 {plugin_name} 注册信息移除失败")
return False
else:
logger.info(f"插件 {plugin_name} 卸载成功")
return True
# 创建全局组件注册中心实例
component_registry = ComponentRegistry()
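# 用法示意(假设性示例,非本提交代码):异步卸载插件及其全部组件,返回 bool 表示是否完全成功
#   import asyncio
#   ok = asyncio.run(component_registry.unregister_plugin("tts_plugin"))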

View File

@@ -33,7 +33,7 @@ class EventsManager:
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
return True
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")

View File

@@ -0,0 +1,242 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import time
from pathlib import Path
from threading import Thread
from typing import Dict, Optional, Set
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 1.0 # 防抖延迟(秒)
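# 防抖说明:同一插件在 debounce_delay 时间窗内的多次文件事件只会触发一次重载,
# 例如编辑器保存时连续产生的 modified 事件会被合并处理。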
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_name = self._get_plugin_name_from_path(file_path)
if not plugin_name:
return
current_time = time.time()
last_time = self.last_reload_time.get(plugin_name, 0)
# 防抖处理,避免频繁重载
if current_time - last_time < self.debounce_delay:
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
if file_name == "plugin.py":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
elif file_name == "manifest.toml":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,避免文件正在写入时重载
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name,),
daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}")
def _delayed_reload(self, plugin_name: str):
"""延迟重载插件"""
try:
time.sleep(self.debounce_delay)
if plugin_name in self.pending_reloads:
self.pending_reloads.remove(plugin_name)
self.hot_reload_manager._reload_plugin(plugin_name)
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
def _get_plugin_name_from_path(self, file_path: str) -> str:
"""从文件路径获取插件名称"""
try:
path = Path(file_path)
# 检查是否在监听的插件目录中
plugin_root = Path(self.hot_reload_manager.watch_directory)
if not path.is_relative_to(plugin_root):
return ""
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
return plugin_name
return ""
except Exception:
return ""
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directory: Optional[str] = None):
self.watch_directory = watch_directory or os.path.join(os.getcwd(), "plugins")
self.observer = None
self.file_handler = None
self.is_running = False
# 确保监听目录存在
if not os.path.exists(self.watch_directory):
os.makedirs(self.watch_directory, exist_ok=True)
logger.info(f"创建插件监听目录: {self.watch_directory}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
self.observer = Observer()
self.file_handler = PluginFileHandler(self)
self.observer.schedule(
self.file_handler,
self.watch_directory,
recursive=True
)
self.observer.start()
self.is_running = True
logger.info("🚀 插件热重载已启动,监听目录: plugins")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running:
return
if self.observer:
self.observer.stop()
self.observer.join()
self.is_running = False
def _reload_plugin(self, plugin_name: str):
"""重载指定插件"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
if plugin_manager.reload_plugin(plugin_name):
logger.info(f"✅ 插件重载成功: {plugin_name}")
else:
logger.error(f"❌ 插件重载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
def _unload_plugin(self, plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
if plugin_manager.reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}")
def get_status(self) -> dict:
"""获取热重载状态"""
return {
"is_running": self.is_running,
"watch_directory": self.watch_directory,
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
}
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()
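# 用法示意(假设性演示,非本提交代码):热重载的典型生命周期
#   hot_reload_manager.start()                # 开始监听 plugins 目录
#   status = hot_reload_manager.get_status()  # {"is_running": True, ...}
#   hot_reload_manager.reload_all_plugins()   # 手动全量重载
#   hot_reload_manager.stop()                 # 停止监听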

View File

@@ -1,5 +1,6 @@
import os
import traceback
import sys
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
@@ -488,6 +489,105 @@ class PluginManager:
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# === 插件卸载和重载管理 ===
def unload_plugin(self, plugin_name: str) -> bool:
"""卸载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 卸载是否成功
"""
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
return False
try:
# 获取插件实例
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
component_registry.unregister_plugin(plugin_name)
# 从已加载插件中移除
del self.loaded_plugins[plugin_name]
# 从失败列表中移除(如果存在)
if plugin_name in self.failed_plugins:
del self.failed_plugins[plugin_name]
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
# 先卸载插件
if plugin_name in self.loaded_plugins:
self.unload_plugin(plugin_name)
# 清除Python模块缓存
plugin_path = self.plugin_paths.get(plugin_name)
if plugin_path:
plugin_file = os.path.join(plugin_path, "plugin.py")
if os.path.exists(plugin_file):
# 从sys.modules中移除相关模块
modules_to_remove = []
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
for module_name in sys.modules:
if module_name.startswith(plugin_module_prefix):
modules_to_remove.append(module_name)
for module_name in modules_to_remove:
del sys.modules[module_name]
# 从插件类注册表中移除
if plugin_name in self.plugin_classes:
del self.plugin_classes[plugin_name]
# 重新加载插件模块
if self._load_plugin_module_file(plugin_file):
# 重新加载插件实例
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"🔄 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
logger.debug("详细错误信息: ", exc_info=True)
return False
# 全局插件管理器实例
plugin_manager = PluginManager()
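# 用法示意(假设性示例,非本提交代码):手动重载/卸载插件,两个方法均返回 bool
#   if plugin_manager.reload_plugin("tts_plugin"):
#       logger.info("重载完成")
#   plugin_manager.unload_plugin("tts_plugin")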

View File

@@ -1,149 +1,149 @@
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type
logger = get_logger("tts")
class TTSAction(BaseAction):
"""TTS语音转换动作处理类"""
# 激活设置
focus_activation_type = ActionActivationType.LLM_JUDGE
normal_activation_type = ActionActivationType.KEYWORD
mode_enable = ChatMode.ALL
parallel_action = False
# 动作基本信息
action_name = "tts_action"
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
# 关键词配置 - Normal模式下使用关键词触发
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"]
keyword_case_sensitive = False
# 动作参数定义
action_parameters = {
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
}
# 动作使用场景
action_require = [
"当需要发送语音信息时使用",
"当用户明确要求使用语音功能时使用",
"当表达内容更适合用语音而不是文字传达时使用",
"当用户想听到语音回答而非阅读文本时使用",
]
# 关联类型
associated_types = ["tts_text"]
async def execute(self) -> Tuple[bool, str]:
"""处理TTS文本转语音动作"""
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
# 获取要转换的文本
text = self.action_data.get("text")
if not text:
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
return False, "执行TTS动作失败未提供文本内容"
# 确保文本适合TTS使用
processed_text = self._process_text_for_tts(text)
try:
# 发送TTS消息
await self.send_custom(message_type="tts_text", content=processed_text)
# 记录动作信息
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
)
logger.info(f"{self.log_prefix} TTS动作执行成功文本长度: {len(processed_text)}")
return True, "TTS动作执行成功"
except Exception as e:
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
return False, f"执行TTS动作时出错: {e}"
def _process_text_for_tts(self, text: str) -> str:
"""
处理文本使其更适合TTS使用
- 移除不必要的特殊字符和表情符号
- 修正标点符号以提高语音质量
- 优化文本结构使语音更流畅
"""
# 这里可以添加文本处理逻辑
# 例如:移除多余的标点、表情符号,优化语句结构等
# 简单示例实现
processed_text = text
# 移除多余的标点符号
import re
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
# 确保句子结尾有合适的标点
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "?", "!"]):
processed_text = f"{processed_text}。"
return processed_text
@register_plugin
class TTSPlugin(BasePlugin):
"""TTS插件
- 这是文字转语音插件
- Normal模式下依靠关键词触发
- Focus模式下由LLM判断触发
- 具有一定的文本预处理能力
"""
# 插件基本信息
plugin_name: str = "tts_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息配置",
"components": "组件启用控制",
"logging": "日志记录相关配置",
}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
},
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
"logging": {
"level": ConfigField(
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
),
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# 从配置获取组件启用状态
enable_tts = self.get_config("components.enable_tts", True)
components = [] # 添加Action组件
if enable_tts:
components.append((TTSAction.get_action_info(), TTSAction))
return components
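# 处理效果示意(假设性示例,基于上面 _process_text_for_tts 的逻辑):
#   输入: "你好!!今天天气不错"
#   输出: "你好!今天天气不错。"  # 重复标点折叠为一个,结尾自动补句号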

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.4.0"
version = "6.2.3"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -11,8 +11,38 @@ version = "6.4.0"
# 修订号:配置文件内容小更新
#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
[database]
# 数据库配置
database_type = "sqlite" # 数据库类型,支持 "sqlite" 或 "mysql"
# SQLite 配置(当 database_type = "sqlite" 时使用)
sqlite_path = "data/MaiBot.db" # SQLite数据库文件路径
# MySQL 配置(当 database_type = "mysql" 时使用)
mysql_host = "localhost" # MySQL服务器地址
mysql_port = 3306 # MySQL服务器端口
mysql_database = "maibot" # MySQL数据库名
mysql_user = "root" # MySQL用户名
mysql_password = "" # MySQL密码
mysql_charset = "utf8mb4" # MySQL字符集
mysql_unix_socket = "" # MySQL Unix套接字路径可选用于本地连接优先于host/port
# MySQL SSL 配置
mysql_ssl_mode = "DISABLED" # SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY
mysql_ssl_ca = "" # SSL CA证书路径
mysql_ssl_cert = "" # SSL客户端证书路径
mysql_ssl_key = "" # SSL客户端密钥路径
# MySQL 高级配置
mysql_autocommit = true # 自动提交事务
mysql_sql_mode = "TRADITIONAL" # SQL模式
# 连接池配置
connection_pool_size = 10 # 连接池大小仅MySQL有效
connection_timeout = 10 # 连接超时时间(秒)
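# 用法示意(假设性片段,非本提交内容):切换到 MySQL 时通常只需改动以下键,
# 其余 mysql_* 项保持默认即可:
#   database_type = "mysql"
#   mysql_host = "127.0.0.1"
#   mysql_user = "maibot"
#   mysql_password = "******"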
[bot]
platform = "qq"
platform = "qq"
qq_account = 1145141919810 # 麦麦的QQ账号
nickname = "麦麦" # 麦麦的昵称
alias_names = ["麦叠", "牢麦"] # 麦麦的别名
@@ -53,11 +83,11 @@ expression_groups = [
]
[chat] #麦麦的聊天设置
talk_frequency = 0.5
# 麦麦活跃度越高麦麦回复越多范围0-1
focus_value = 0.5
# 麦麦的专注度越高越容易持续连续对话可能消耗更多token, 范围0-1
[chat] #麦麦的聊天通用设置
focus_value = 1
# 麦麦的专注思考能力越高越容易专注可能消耗更多token
# 专注时能更好把握发言时机,能够进行持久的连续对话
max_context_size = 20 # 上下文长度
@@ -120,6 +150,7 @@ mood_update_threshold = 1 # 情绪更新阈值,越高,更新越慢
[emoji]
emoji_chance = 0.6 # 麦麦激活表情包动作的概率
emoji_activate_type = "llm" # 表情包激活类型可选randomllm ; random下表情包动作随机启用llm下表情包动作根据llm判断是否启用
max_reg_num = 60 # 表情包最大注册数量
do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包

uv.lock (generated, 3813 行变更)

File diff suppressed because it is too large