This commit is contained in:
Windpicker-owo
2025-10-19 22:49:14 +08:00
48 changed files with 1700 additions and 2279 deletions

View File

@@ -1,70 +1,51 @@
name: Docker CI name: Docker Build and Push
on: on:
# push: push:
# branches: branches:
# - master - master
# - develop - dev
# tags: tags:
# - "v*.*.*" - "v*.*.*"
# - "v*" - "v*"
# - "*.*.*" - "*.*.*"
# - "*.*.*-*" - "*.*.*-*"
workflow_dispatch: # 允许手动触发工作流 workflow_dispatch: # 允许手动触发工作流
# Workflow's jobs
jobs: jobs:
build-amd64: build-amd64:
name: 构建 AMD64 镜像 name: Build AMD64 Image
runs-on: ubuntu-24.04 runs-on: ubuntu-24.04
outputs: outputs:
digest: ${{ steps.build.outputs.digest }} digest: ${{ steps.build.outputs.digest }}
steps: steps:
- name: 检出 Git 仓库 - name: Check out git repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: 克隆 maim_message - name: Set up Docker Buildx
uses: actions/checkout@v4
with:
repository: MaiM-with-u/maim_message
path: maim_message
- name: 克隆 MaiMBot-LPMM
uses: actions/checkout@v4
with:
repository: MaiM-with-u/MaiMBot-LPMM
path: MaiMBot-LPMM
- name: 设置 Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
with: with:
buildkitd-flags: --debug buildkitd-flags: --debug
- name: 登录到 Docker Hub # Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Docker 元数据 # Generate metadata for Docker images
- name: Docker meta
id: meta id: meta
uses: docker/metadata-action@v5 uses: docker/metadata-action@v5
with: with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
- name: 动态生成镜像标签 # Build and push AMD64 image by digest
id: tag - name: Build and push AMD64
run: |
if [ "$GITHUB_REF" == "refs/heads/master" ]; then
echo "tag=latest" >> $GITHUB_ENV
elif [ "$GITHUB_REF" == "refs/heads/develop" ]; then
echo "tag=dev" >> $GITHUB_ENV
else
echo "tag=${{ github.ref_name }}" >> $GITHUB_ENV
fi
- name: 构建并推送 AMD64 镜像
id: build id: build
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
@@ -72,10 +53,97 @@ jobs:
platforms: linux/amd64 platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
file: ./Dockerfile file: ./Dockerfile
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:amd64-buildcache cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:amd64-buildcache,mode=max cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot:${{ env.tag }},name-canonical=true,push=true outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
build-args: | build-args: |
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ') BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
VCS_REF=${{ github.sha }} VCS_REF=${{ github.sha }}
BRANCH_NAME=${{ github.ref_name }}
build-arm64:
name: Build ARM64 Image
runs-on: ubuntu-24.04-arm
outputs:
digest: ${{ steps.build.outputs.digest }}
steps:
- name: Check out git repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
buildkitd-flags: --debug
# Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Generate metadata for Docker images
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
# Build and push ARM64 image by digest
- name: Build and push ARM64
id: build
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/arm64/v8
labels: ${{ steps.meta.outputs.labels }}
file: ./Dockerfile
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
build-args: |
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
VCS_REF=${{ github.sha }}
create-manifest:
name: Create Multi-Arch Manifest
runs-on: ubuntu-24.04
needs:
- build-amd64
- build-arm64
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Generate metadata for Docker images
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
tags: |
type=ref,event=branch
type=ref,event=tag
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
- name: Create and Push Manifest
run: |
# 为每个标签创建多架构镜像
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
echo "Creating manifest for $tag"
docker buildx imagetools create -t $tag \
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
done

View File

@@ -2,31 +2,19 @@ FROM python:3.13.5-slim-bookworm
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
# 工作目录 # 工作目录
WORKDIR /mmc WORKDIR /app
# 复制依赖列表 # 复制依赖列表
COPY requirements.txt . COPY pyproject.toml .
# 同级目录下需要有 maim_message MaiMBot-LPMM
#COPY maim_message /maim_message
COPY MaiMBot-LPMM /MaiMBot-LPMM
# 编译器 # 编译器
RUN apt-get update && apt-get install -y build-essential RUN apt-get update && apt-get install -y build-essential
# lpmm编译安装
RUN cd /MaiMBot-LPMM && uv pip install --system -r requirements.txt
RUN uv pip install --system Cython py-cpuinfo setuptools
RUN cd /MaiMBot-LPMM/lib/quick_algo && python build_lib.py --cleanup --cythonize --install
# 安装依赖 # 安装依赖
RUN uv pip install --system --upgrade pip RUN uv sync
#RUN uv pip install --system -e /maim_message
RUN uv pip install --system -r requirements.txt
# 复制项目代码
COPY . . COPY . .
EXPOSE 8000 EXPOSE 8000
ENTRYPOINT [ "python","bot.py" ] ENTRYPOINT [ "uv","run","bot.py" ]

View File

@@ -38,12 +38,12 @@
**MoFox_Bot** 是一个基于 [MaiCore](https://github.com/MaiM-with-u/MaiBot) `0.10.0 snapshot.5` 的增强型 fork 项目。我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个**更稳定、更智能、更具趣味性**的 AI 智能体。 **MoFox_Bot** 是一个基于 [MaiCore](https://github.com/MaiM-with-u/MaiBot) `0.10.0 snapshot.5` 的增强型 fork 项目。我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个**更稳定、更智能、更具趣味性**的 AI 智能体。
> [!IMPORTANT] > [IMPORTANT]
> **第三方项目声明** > **第三方项目声明**
> >
> 本项目由 **MoFox Studio** 独立维护,为 **MaiBot 的第三方分支**,并非官方版本。所有更新与支持均由我们团队负责,与 MaiBot 官方无直接关系。 > 本项目由 **MoFox Studio** 独立维护,为 **MaiBot 的第三方分支**,并非官方版本。所有更新与支持均由我们团队负责,与 MaiBot 官方无直接关系。
> [!WARNING] > [WARNING]
> **迁移风险提示** > **迁移风险提示**
> >
> 由于我们对数据库结构进行了重构与优化,从官方 MaiBot 直接迁移至 MoFox_Bot **可能导致数据不兼容**。请在迁移前**务必备份原始数据**,以避免信息丢失。 > 由于我们对数据库结构进行了重构与优化,从官方 MaiBot 直接迁移至 MoFox_Bot **可能导致数据不兼容**。请在迁移前**务必备份原始数据**,以避免信息丢失。
@@ -63,8 +63,6 @@
<td width="50%"> <td width="50%">
### 🔧 原版功能(全部保留) ### 🔧 原版功能(全部保留)
- 🧠 **智能对话系统** - 基于 LLM 的自然语言交互,支持 normal 和 focus 统一化处理
- 🔌 **强大插件系统** - 全面重构的插件架构,支持完整的管理 API 和权限控制 - 🔌 **强大插件系统** - 全面重构的插件架构,支持完整的管理 API 和权限控制
- 💭 **实时思维系统** - 模拟人类思考过程 - 💭 **实时思维系统** - 模拟人类思考过程
- 📚 **表达学习功能** - 学习群友的说话风格和表达方式 - 📚 **表达学习功能** - 学习群友的说话风格和表达方式
@@ -78,6 +76,7 @@
### 🚀 拓展功能 ### 🚀 拓展功能
- 🧠 **AFC 智能对话** - 基于亲和力流,实现兴趣感知和动态关系构建
- 🔄 **数据库切换** - 支持 SQLite 与 MySQL 自由切换,采用 SQLAlchemy 2.0 重新构建 - 🔄 **数据库切换** - 支持 SQLite 与 MySQL 自由切换,采用 SQLAlchemy 2.0 重新构建
- 🛡️ **反注入集成** - 内置一整套回复前注入过滤系统,为人格保驾护航 - 🛡️ **反注入集成** - 内置一整套回复前注入过滤系统,为人格保驾护航
- 🎥 **视频分析** - 支持多种视频识别模式,拓展原版视觉 - 🎥 **视频分析** - 支持多种视频识别模式,拓展原版视觉

View File

@@ -1,47 +1,22 @@
services: services:
adapters:
container_name: maim-bot-adapters
#### prod ####
image: unclas/maimbot-adapter:latest
# image: infinitycat/maimbot-adapter:latest
#### dev ####
# image: unclas/maimbot-adapter:dev
# image: infinitycat/maimbot-adapter:dev
environment:
- TZ=Asia/Shanghai
# ports:
# - "8095:8095"
volumes:
- ./docker-config/adapters/config.toml:/adapters/config.toml # 持久化adapters配置文件
- ./data/adapters:/adapters/data # adapters 数据持久化
restart: always
networks:
- maim_bot
core: core:
container_name: maim-bot-core container_name: MoFox-Bot
#### prod #### #### prod ####
image: sengokucola/maibot:latest image: hunuon/mofox:latest
# image: infinitycat/maibot:latest
#### dev #### #### dev ####
# image: sengokucola/maibot:dev # image: hunuon/mofox:dev
# image: infinitycat/maibot:dev
environment: environment:
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
# ports:
# - "8000:8000"
volumes: volumes:
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件 - ./docker-config/core/.env:/app/.env # 持久化env配置文件
- ./docker-config/mmc:/MaiMBot/config # 持久化bot配置文件 - ./docker-config/core:/app/config # 持久化bot配置文件
- ./data/MaiMBot/maibot_statistics.html:/MaiMBot/maibot_statistics.html #统计数据输出 - ./data/core/maibot_statistics.html:/app/maibot_statistics.html #统计数据输出
- ./data/MaiMBot:/MaiMBot/data # 共享目录 - ./data/app:/app/data # 共享目录
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录 - ./data/core/plugins:/app/plugins # 插件目录
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录 - ./data/core/logs:/app/logs # 日志目录
- site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包
restart: always restart: always
networks: networks:
- maim_bot - mofox
napcat: napcat:
environment: environment:
- NAPCAT_UID=1000 - NAPCAT_UID=1000
@@ -52,25 +27,12 @@ services:
volumes: volumes:
- ./docker-config/napcat:/app/napcat/config # 持久化napcat配置文件 - ./docker-config/napcat:/app/napcat/config # 持久化napcat配置文件
- ./data/qq:/app/.config/QQ # 持久化QQ本体 - ./data/qq:/app/.config/QQ # 持久化QQ本体
- ./data/MaiMBot:/MaiMBot/data # 共享目录 - ./data/app:/app/data # 共享目录
container_name: maim-bot-napcat container_name: mofox-napcat
restart: always restart: always
image: mlikiowa/napcat-docker:latest image: mlikiowa/napcat-docker:latest
networks: networks:
- maim_bot - mofox
sqlite-web:
# 注意coleifer/sqlite-web 镜像不支持arm64
image: coleifer/sqlite-web
container_name: sqlite-web
restart: always
ports:
- "8120:8080"
volumes:
- ./data/MaiMBot:/data/MaiMBot
environment:
- SQLITE_DATABASE=MaiMBot/MaiBot.db # 你的数据库文件
networks:
- maim_bot
# chat2db占用相对较高但是功能强大 # chat2db占用相对较高但是功能强大
# 内存占用约600m内存充足推荐选此 # 内存占用约600m内存充足推荐选此
@@ -81,11 +43,11 @@ services:
# ports: # ports:
# - "10824:10824" # - "10824:10824"
# volumes: # volumes:
# - ./data/MaiMBot:/data/MaiMBot # - ./data/chat2db:/data/app
# networks: # networks:
# - maim_bot # - mofox
volumes: volumes:
site-packages: site-packages:
networks: networks:
maim_bot: mofox:
driver: bridge driver: bridge

View File

@@ -3,10 +3,11 @@ import random
from typing import Any from typing import Any
from src.plugin_system import ( from src.plugin_system import (
ActionActivationType,
BaseAction, BaseAction,
BaseEventHandler, BaseEventHandler,
BasePlugin, BasePlugin,
BasePrompt,
ToolParamType,
BaseTool, BaseTool,
ChatType, ChatType,
CommandArgs, CommandArgs,
@@ -37,7 +38,17 @@ class GetSystemInfoTool(BaseTool):
name = "get_system_info" name = "get_system_info"
description = "获取当前系统的模拟版本和状态信息。" description = "获取当前系统的模拟版本和状态信息。"
available_for_llm = True available_for_llm = True
parameters = [] parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None),
(
"time_range",
ToolParamType.STRING,
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'",
False,
["any", "week", "month"],
),
] # type: ignore
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"} return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
@@ -100,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction):
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool: async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
"""LLM 判断激活:判断用户是否情绪低落""" """LLM 判断激活:判断用户是否情绪低落"""
return await self._llm_judge_activation( return await self._llm_judge_activation(
chat_content=chat_content,
judge_prompt=""" judge_prompt="""
判断用户是否表达了以下情绪或需求: 判断用户是否表达了以下情绪或需求:
1. 感到难过、沮丧或失落 1. 感到难过、沮丧或失落
@@ -133,11 +143,11 @@ class CombinedActivationExampleAction(BaseAction):
# 先尝试随机激活 # 先尝试随机激活
if await self._random_activation(0.2): if await self._random_activation(0.2):
return True return True
# 如果随机未激活,尝试关键词匹配 # 如果随机未激活,尝试关键词匹配
if await self._keyword_match(chat_content, ["表情", "emoji", "😊"], case_sensitive=False): if await self._keyword_match(chat_content, ["表情", "emoji", "😊"], case_sensitive=False):
return True return True
# 都不满足则不激活 # 都不满足则不激活
return False return False
@@ -170,6 +180,19 @@ class RandomEmojiAction(BaseAction):
return True, "成功发送了一个随机表情" return True, "成功发送了一个随机表情"
class WeatherPrompt(BasePrompt):
"""一个简单的Prompt组件用于向Planner注入天气信息。"""
prompt_name = "weather_info_prompt"
prompt_description = "向Planner注入当前天气信息以丰富对话上下文。"
injection_point = "planner_prompt"
async def execute(self) -> str:
# 在实际应用中这里可以调用天气API
# 为了演示,我们返回一个固定的天气信息
return "当前天气晴朗温度25°C。"
@register_plugin @register_plugin
class HelloWorldPlugin(BasePlugin): class HelloWorldPlugin(BasePlugin):
"""一个包含四大核心组件和高级配置功能的入门示例插件。""" """一个包含四大核心组件和高级配置功能的入门示例插件。"""
@@ -179,7 +202,6 @@ class HelloWorldPlugin(BasePlugin):
dependencies = [] dependencies = []
python_dependencies = [] python_dependencies = []
config_file_name = "config.toml" config_file_name = "config.toml"
enable_plugin = False
config_schema = { config_schema = {
"meta": { "meta": {
@@ -209,4 +231,7 @@ class HelloWorldPlugin(BasePlugin):
if self.get_config("components.random_emoji_action_enabled", True): if self.get_config("components.random_emoji_action_enabled", True):
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction)) components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
# 注册新的Prompt组件
components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt))
return components return components

View File

@@ -2,17 +2,16 @@
name = "MoFox-Bot" name = "MoFox-Bot"
version = "0.8.1" version = "0.8.1"
description = "MoFox-Bot 是一个基于大语言模型的可交互智能体" description = "MoFox-Bot 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.11" requires-python = ">=3.11,<=3.13"
dependencies = [ dependencies = [
"aiohttp>=3.12.14", "aiohttp>=3.12.14",
"aiohttp-cors>=0.8.1", "aiohttp-cors>=0.8.1",
"apscheduler>=3.11.0", "apscheduler>=3.11.0",
"asyncddgs>=0.1.0a1",
"asyncio>=4.0.0", "asyncio>=4.0.0",
"beautifulsoup4>=4.13.4", "beautifulsoup4>=4.13.4",
"chromadb>=0.5.0", "chromadb>=1.2.0",
"colorama>=0.4.6", "colorama>=0.4.6",
"cryptography>=45.0.5", "cryptography>=46.0.3",
"customtkinter>=5.2.2", "customtkinter>=5.2.2",
"dotenv>=0.9.9", "dotenv>=0.9.9",
"exa-py>=1.14.20", "exa-py>=1.14.20",
@@ -21,11 +20,10 @@ dependencies = [
"google>=3.0.0", "google>=3.0.0",
"google-genai>=1.29.0", "google-genai>=1.29.0",
"httpx>=0.28.1", "httpx>=0.28.1",
"jieba>=0.1.13",
"json-repair>=0.47.6", "json-repair>=0.47.6",
"json5>=0.12.1", "json5>=0.12.1",
"jsonlines>=4.0.0", "jsonlines>=4.0.0",
"langfuse==2.46.2", "langfuse==3.7.0",
"lunar-python>=1.4.4", "lunar-python>=1.4.4",
"lxml>=6.0.0", "lxml>=6.0.0",
"maim-message>=0.3.8", "maim-message>=0.3.8",
@@ -33,16 +31,16 @@ dependencies = [
"networkx>=3.4.2", "networkx>=3.4.2",
"orjson>=3.10", "orjson>=3.10",
"numpy>=2.2.6", "numpy>=2.2.6",
"openai>=1.95.0", "openai>=2.5.0",
"opencv-python>=4.11.0.86", "opencv-python>=4.11.0.86",
"packaging>=23.2", "packaging>=25.0",
"pandas>=2.3.1", "pandas>=2.3.1",
"peewee>=3.18.2", "peewee>=3.18.2",
"pillow>=11.3.0", "pillow>=12.0.0",
"pip-check-reqs>=2.5.5", "pip-check-reqs>=2.5.5",
"psutil>=7.0.0", "psutil>=7.0.0",
"pyarrow>=20.0.0", "pyarrow>=21.0.0",
"pydantic>=2.11.7", "pydantic>=2.12.3",
"pygments>=2.19.2", "pygments>=2.19.2",
"pymongo>=4.13.2", "pymongo>=4.13.2",
"pymysql>=1.1.1", "pymysql>=1.1.1",
@@ -76,8 +74,8 @@ dependencies = [
"aiosqlite>=0.21.0", "aiosqlite>=0.21.0",
"inkfox>=0.1.1", "inkfox>=0.1.1",
"rjieba>=0.1.13", "rjieba>=0.1.13",
"mcp>=0.9.0", "mcp>=1.18.0",
"sse-starlette>=2.2.1", "sse-starlette>=3.0.2",
] ]
[[tool.uv.index]] [[tool.uv.index]]

106
scripts/convert_manifest.py Normal file
View File

@@ -0,0 +1,106 @@
import os
import shutil
import sys
from pathlib import Path
import orjson
# 将脚本所在的目录添加到系统路径中,以便导入项目模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
logger = get_logger("convert_manifest")
def convert_and_copy_plugin(plugin_dir: Path, output_dir: Path):
"""
转换插件的 _manifest.json 文件,并将其整个目录复制到输出位置。
"""
manifest_path = plugin_dir / "_manifest.json"
if not manifest_path.is_file():
logger.warning(f"在目录 '{plugin_dir.name}' 中未找到 '_manifest.json',已跳过。")
return
try:
# 1. 复制整个插件目录
target_plugin_dir = output_dir / plugin_dir.name
if target_plugin_dir.exists():
shutil.rmtree(target_plugin_dir) # 如果目标已存在,先删除
shutil.copytree(plugin_dir, target_plugin_dir)
logger.info(f"已将插件 '{plugin_dir.name}' 完整复制到 '{target_plugin_dir}'")
# 2. 读取 manifest 并生成 __init__.py 内容
with open(manifest_path, "rb") as f:
manifest = orjson.loads(f.read())
plugin_name = manifest.get("name", "Unknown Plugin")
description = manifest.get("description", "No description provided.")
version = manifest.get("version", "1.0.0")
author = manifest.get("author", {}).get("name", "Unknown Author")
license_type = manifest.get("license")
repository_url = manifest.get("repository_url")
keywords = manifest.get("keywords", [])
categories = manifest.get("categories", [])
plugin_type = manifest.get("plugin_info", {}).get("plugin_type")
meta_template = f"""from src.plugin_system.base.plugin_metadata import PluginMetadata
__plugin_meta__ = PluginMetadata(
name="{plugin_name}",
description="{description}",
usage="暂无说明",
type={repr(plugin_type)},
version="{version}",
author="{author}",
license={repr(license_type)},
repository_url={repr(repository_url)},
keywords={keywords},
categories={categories},
)
"""
# 3. 在复制后的目录中创建或覆盖 __init__.py
output_init_path = target_plugin_dir / "__init__.py"
with open(output_init_path, "w", encoding="utf-8") as f:
f.write(meta_template)
# 4. 删除复制后的 _manifest.json
copied_manifest_path = target_plugin_dir / "_manifest.json"
if copied_manifest_path.is_file():
copied_manifest_path.unlink()
logger.info(f"成功为 '{plugin_dir.name}' 创建元数据文件并清理清单。")
except FileNotFoundError:
logger.error(f"错误: 在 '{manifest_path}' 未找到清单文件")
except orjson.JSONDecodeError:
logger.error(f"错误: 无法解析 '{manifest_path}' 的 JSON 内容")
except Exception as e:
logger.error(f"处理 '{plugin_dir.name}' 时发生意外错误: {e}")
def main():
"""
主函数,扫描 "plugins" 目录,并将合格的插件转换并复制到 "completed_plugins" 目录。
"""
# 使用相对于脚本位置的固定路径
script_dir = Path(__file__).parent
input_path = script_dir / "pending_plugins"
output_path = script_dir / "completed_plugins"
if not input_path.is_dir():
logger.error(f"错误: 输入目录 '{input_path}' 不存在。")
input_path.mkdir(parents=True, exist_ok=True)
logger.info("请在新建的文件夹里面投入插件文件夹并重新启动脚本")
return
output_path.mkdir(parents=True, exist_ok=True)
logger.info(f"正在扫描 '{input_path}' 中的插件...")
for item in input_path.iterdir():
if item.is_dir():
logger.info(f"发现插件目录: '{item.name}',开始处理...")
convert_and_copy_plugin(item, output_path)
logger.info("所有插件处理完成。")
if __name__ == "__main__":
main()

View File

@@ -1,218 +0,0 @@
"""批量将经典 SQLAlchemy 模型字段写法
field = Column(Integer, nullable=False, default=0)
转换为 2.0 推荐的带类型注解写法:
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
脚本特点:
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
3. 已经是 Mapped 写法的行会跳过。
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
5. nullable=True 时自动添加 "| None"
6. 保留 Column(...) 内的原始参数顺序与内容。
7. 生成 .bak 备份文件,确保可回滚。
8. 支持 --dry-run 查看差异,不写回文件。
局限/注意:
- 简单基于正则/括号计数,不解析完整 AST非常规写法(例如变量中构造 Column 再赋值)不会处理。
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
- 不自动添加 from __future__ import annotations如需 Python 3.10 以下更先进类型表达式,请自行处理。
使用方式: (在项目根目录执行)
python scripts/convert_sqlalchemy_models.py \
--file src/common/database/sqlalchemy_models.py --dry-run
确认无误后去掉 --dry-run 真实写入。
"""
from __future__ import annotations
import argparse
import re
import shutil
from pathlib import Path
TYPE_MAP = {
"Integer": "int",
"Float": "float",
"Boolean": "bool",
"Text": "str",
"String": "str",
"DateTime": "datetime.datetime",
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
"get_string_field": "str",
}
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
"""检测从 start_index 开始的 Column(...) 语句跨越的行范围 (包含结束行)。
使用括号计数法处理多行。
返回 (start, end) 行号 (包含 end)。"""
line = lines[start_index]
if "Column(" not in line:
return None
open_parens = line.count("(") - line.count(")")
i = start_index
while open_parens > 0 and i + 1 < len(lines):
i += 1
l2 = lines[i]
open_parens += l2.count("(") - l2.count(")")
return (start_index, i)
def extract_column_body(block_lines: list[str]) -> str:
"""提取 Column(...) 内部参数文本 (去掉首尾 Column( 和 最后一个 ) )。"""
joined = "\n".join(block_lines)
# 找到第一次出现 Column(
start_pos = joined.find("Column(")
if start_pos == -1:
return ""
inner = joined[start_pos + len("Column(") :]
# 去掉最后一个 ) —— 简单方式: 找到最后一个 ) 并截断
last_paren = inner.rfind(")")
if last_paren != -1:
inner = inner[:last_paren]
return inner.strip()
def guess_python_type(column_body: str) -> str:
# 简单取第一个类型标识符 (去掉前导装饰/空格)
# 可能形式: Integer, Text, get_string_field(50), DateTime, Boolean
# 利用正则抓取第一个标识符
m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
if not m:
return "Any"
type_token = m.group(1)
py_type = TYPE_MAP.get(type_token, "Any")
# nullable 检测
if "nullable=True" in column_body or "nullable = True" in column_body:
# 避免重复 Optional
if py_type != "Any" and not py_type.endswith(" | None"):
py_type = f"{py_type} | None"
elif py_type == "Any":
py_type = "Any | None"
return py_type
def convert_block(block_lines: list[str]) -> list[str]:
first_line = block_lines[0]
m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
if not m_name:
return block_lines
indent = m_name.group("indent")
name = m_name.group("name")
body = extract_column_body(block_lines)
py_type = guess_python_type(body)
# 构造新的多行 mapped_column 写法
# 保留内部参数的换行缩进: 重新缩进为 indent + 4 空格 (延续原风格: 在 indent 基础上再加 4 空格)
inner_lines = body.split("\n")
if len(inner_lines) == 1:
new_line = f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n"
return [new_line]
else:
# 多行情况
ind2 = indent + " "
rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(",]
for il in inner_lines:
if il.strip():
rebuilt.append(f"{ind2}{il.rstrip()}")
rebuilt.append(f"{indent})\n")
return [l + ("\n" if not l.endswith("\n") else "") for l in rebuilt]
def ensure_imports(content: str) -> str:
if "Mapped," in content or "Mapped[" in content:
# 已经可能存在导入
if "from sqlalchemy.orm import Mapped, mapped_column" not in content:
# 简单插到第一个 import sqlalchemy 之后
lines = content.splitlines()
for i, line in enumerate(lines):
if "sqlalchemy" in line and line.startswith("from sqlalchemy"):
lines.insert(i + 1, "from sqlalchemy.orm import Mapped, mapped_column")
return "\n".join(lines)
return content
def process_file(path: Path) -> tuple[str, str]:
original = path.read_text(encoding="utf-8")
lines = original.splitlines(keepends=True)
i = 0
out: list[str] = []
changed = 0
while i < len(lines):
line = lines[i]
# 跳过已是 Mapped 风格
if ALREADY_MAPPED_RE.match(line):
out.append(line)
i += 1
continue
if "= Column(" in line and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", line):
start, end = detect_column_block(lines, i) or (i, i)
block = lines[start : end + 1]
converted = convert_block(block)
out.extend(converted)
i = end + 1
# 如果转换结果与原始不同,计数
if "".join(converted) != "".join(block):
changed += 1
else:
out.append(line)
i += 1
new_content = "".join(out)
new_content = ensure_imports(new_content)
# 在文件末尾或头部预留统计信息打印(不写入文件,只返回)
return original, new_content if changed else original
def main():
parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
parser.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
parser.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
args = parser.parse_args()
target = Path(args.file)
if not target.exists():
raise SystemExit(f"文件不存在: {target}")
original, new_content = process_file(target)
if original == new_content:
print("[INFO] 没有需要转换的内容或转换后无差异。")
return
# 简单差异输出 (行对比)
if args.dry_run or not args.write:
print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
import difflib
diff = difflib.unified_diff(
original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
)
count = 0
for d in diff:
print(d)
count += 1
if count == 0:
print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
if not args.write:
print("\n未写回。若确认无误,添加 --write 执行替换。")
return
backup = target.with_suffix(target.suffix + ".bak")
shutil.copyfile(target, backup)
target.write_text(new_content, encoding="utf-8")
print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
if __name__ == "__main__": # pragma: no cover
main()

View File

@@ -5,6 +5,7 @@ import shutil
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Lock from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed
import orjson import orjson
from json_repair import repair_json from json_repair import repair_json
@@ -191,43 +192,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
return None, pg_hash return None, pg_hash
async def extract_information(paragraphs_dict, model_set): def extract_info_sync(pg_hash, paragraph, model_set):
llm_api = LLMRequest(model_set=model_set)
return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
def extract_information(paragraphs_dict, model_set):
logger.info("--- 步骤 2: 开始信息提取 ---") logger.info("--- 步骤 2: 开始信息提取 ---")
os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True) os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True) os.makedirs(TEMP_DIR, exist_ok=True)
llm_api = LLMRequest(model_set=model_set)
failed_hashes, open_ie_docs = [], [] failed_hashes, open_ie_docs = [], []
tasks = [ with ThreadPoolExecutor(max_workers=3) as executor:
extract_info_async(p_hash, p, llm_api) f_to_hash = {
for p_hash, p in paragraphs_dict.items() executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
] for p_hash, p in paragraphs_dict.items()
}
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
TaskProgressColumn(), TaskProgressColumn(),
MofNCompleteColumn(), MofNCompleteColumn(),
"", "",
TimeElapsedColumn(), TimeElapsedColumn(),
"<", "<",
TimeRemainingColumn(), TimeRemainingColumn(),
) as progress: ) as progress:
prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks)) task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
for future in asyncio.as_completed(tasks): for future in as_completed(f_to_hash):
doc_item, failed_hash = await future doc_item, failed_hash = future.result()
if failed_hash: if failed_hash:
failed_hashes.append(failed_hash) failed_hashes.append(failed_hash)
elif doc_item: elif doc_item:
open_ie_docs.append(doc_item) open_ie_docs.append(doc_item)
progress.update(prog_task, advance=1) progress.update(task, advance=1)
if open_ie_docs: if open_ie_docs:
all_entities = [ all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
e for doc in open_ie_docs for e in doc["extracted_entities"]
]
num_entities = len(all_entities) num_entities = len(all_entities)
avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0 avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
@@ -312,7 +315,7 @@ async def import_data(openie_obj: OpenIE | None = None):
logger.info("--- 数据导入完成 ---") logger.info("--- 数据导入完成 ---")
async def import_from_specific_file(): def import_from_specific_file():
"""从用户指定的 openie.json 文件导入数据""" """从用户指定的 openie.json 文件导入数据"""
file_path = input("请输入 openie.json 文件的完整路径: ").strip() file_path = input("请输入 openie.json 文件的完整路径: ").strip()
@@ -327,7 +330,7 @@ async def import_from_specific_file():
try: try:
logger.info(f"正在从 {file_path} 加载 OpenIE 数据...") logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
openie_obj = OpenIE.load() openie_obj = OpenIE.load()
await import_data(openie_obj=openie_obj) asyncio.run(import_data(openie_obj=openie_obj))
except Exception as e: except Exception as e:
logger.error(f"从指定文件导入数据时发生错误: {e}") logger.error(f"从指定文件导入数据时发生错误: {e}")
@@ -335,20 +338,14 @@ async def import_from_specific_file():
# --- 主函数 --- # --- 主函数 ---
async def async_main(): def main():
# 使用 os.path.relpath 创建相对于项目根目录的友好路径 # 使用 os.path.relpath 创建相对于项目根目录的友好路径
raw_data_relpath = os.path.relpath( raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
RAW_DATA_PATH, os.path.join(ROOT_PATH, "..") openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
)
openie_output_relpath = os.path.relpath(
OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
)
print("=== LPMM 知识库学习工具 ===") print("=== LPMM 知识库学习工具 ===")
print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)") print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
print( print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)"
)
print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识") print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3") print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
print("5. [指定导入] -> 从特定的 openie.json 文件导入知识") print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
@@ -362,20 +359,16 @@ async def async_main():
elif choice == "2": elif choice == "2":
paragraphs = preprocess_raw_data() paragraphs = preprocess_raw_data()
if paragraphs: if paragraphs:
await extract_information( extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
paragraphs, model_config.model_task_config.lpmm_qa
)
elif choice == "3": elif choice == "3":
await import_data() asyncio.run(import_data())
elif choice == "4": elif choice == "4":
paragraphs = preprocess_raw_data() paragraphs = preprocess_raw_data()
if paragraphs: if paragraphs:
await extract_information( extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
paragraphs, model_config.model_task_config.lpmm_qa asyncio.run(import_data())
)
await import_data()
elif choice == "5": elif choice == "5":
await import_from_specific_file() import_from_specific_file()
elif choice == "6": elif choice == "6":
clear_cache() clear_cache()
elif choice == "0": elif choice == "0":
@@ -385,4 +378,4 @@ async def async_main():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(async_main()) main()

View File

@@ -1,62 +0,0 @@
"""
更新Prompt类导入脚本
将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt
"""
import os
# 需要更新的文件列表
files_to_update = [
"src/person_info/relationship_fetcher.py",
"src/mood/mood_manager.py",
"src/mais4u/mais4u_chat/body_emotion_action_manager.py",
"src/chat/express/expression_learner.py",
"src/chat/planner_actions/planner.py",
"src/mais4u/mais4u_chat/s4u_prompt.py",
"src/chat/message_receive/bot.py",
"src/chat/replyer/default_generator.py",
"src/chat/express/expression_selector.py",
"src/mais4u/mai_think.py",
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py",
]
def update_prompt_imports(file_path):
"""更新文件中的Prompt导入"""
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return False
with open(file_path, encoding="utf-8") as f:
content = f.read()
# 替换导入语句
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
if old_import in content:
new_content = content.replace(old_import, new_import)
with open(file_path, "w", encoding="utf-8") as f:
f.write(new_content)
print(f"已更新: {file_path}")
return True
else:
print(f"无需更新: {file_path}")
return False
def main():
"""主函数"""
updated_count = 0
for file_path in files_to_update:
if update_prompt_imports(file_path):
updated_count += 1
print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__":
main()

View File

@@ -30,8 +30,8 @@ from .utils.hash import get_sha256
install(extra_lines=3) install(extra_lines=3)
# 多线程embedding配置常量 # 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 DEFAULT_MAX_WORKERS = 3 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小 MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小 MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数 MIN_WORKERS = 1 # 最小线程数
@@ -124,60 +124,124 @@ class EmbeddingStore:
self.faiss_index = None self.faiss_index = None
self.idx2hash = None self.idx2hash = None
@staticmethod
def _get_embedding(s: str) -> list[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 创建新的LLMRequest实例
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 使用新的事件循环运行异步方法
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
if embedding and len(embedding) > 0:
return embedding
else:
logger.error(f"获取嵌入失败: {s}")
return []
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
finally:
# 确保事件循环被正确关闭
try:
loop.close()
except Exception:
...
@staticmethod @staticmethod
def _get_embeddings_batch_threaded( def _get_embeddings_batch_threaded(
strs: list[str], strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
main_loop: asyncio.AbstractEventLoop,
chunk_size: int = 10,
max_workers: int = 10,
progress_callback=None,
) -> list[tuple[str, list[float]]]: ) -> list[tuple[str, list[float]]]:
"""使用多线程批量获取嵌入向量, 并通过 run_coroutine_threadsafe 在主事件循环中运行异步任务""" """使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
chunk_size: 每个线程处理的数据块大小
max_workers: 最大线程数
progress_callback: 进度回调函数,接收一个参数表示完成的数量
Returns:
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
"""
if not strs: if not strs:
return [] return []
# 导入必要的模块
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
# 在主线程即主事件循环所在的线程中创建LLMRequest实例
# 这样可以确保它绑定到正确的事件循环
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
# 分块 # 分块
chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)] chunks = []
for i in range(0, len(strs), chunk_size):
chunk = strs[i : i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序
results = {} results = {}
def process_chunk(chunk_data): def process_chunk(chunk_data):
"""在工作线程中运行的函数""" """处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data start_idx, chunk_strs = chunk_data
chunk_results = [] chunk_results = []
for i, s in enumerate(chunk_strs): # 为每个线程创建独立的LLMRequest实例
embedding = [] from src.config.config import model_config
try: from src.llm_models.utils_model import LLMRequest
# 将异步的 get_embedding 调用提交到主事件循环
future = asyncio.run_coroutine_threadsafe(llm.get_embedding(s), main_loop)
# 同步等待结果,延长超时时间
embedding_result, _ = future.result(timeout=60)
if embedding_result and len(embedding_result) > 0: try:
embedding = embedding_result # 创建线程专用的LLM实例
else: llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
logger.error(f"获取嵌入失败(返回为空): {s}")
except Exception as e: for i, s in enumerate(chunk_strs):
logger.error(f"在线程中获取嵌入时发生异常: {s}, 错误: {type(e).__name__}: {e}") try:
finally: # 在线程中创建独立的事件循环
chunk_results.append((start_idx + i, s, embedding)) loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding = loop.run_until_complete(llm.get_embedding(s))
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
for i, s in enumerate(chunk_strs):
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback: if progress_callback:
progress_callback(1) progress_callback(1)
return chunk_results return chunk_results
# 使用线程池处理
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# 收集结果进度已在process_chunk中实时更新
for future in as_completed(future_to_chunk): for future in as_completed(future_to_chunk):
try: try:
chunk_results = future.result() chunk_results = future.result()
@@ -185,14 +249,22 @@ class EmbeddingStore:
results[idx] = (s, embedding) results[idx] = (s, embedding)
except Exception as e: except Exception as e:
chunk = future_to_chunk[future] chunk = future_to_chunk[future]
logger.error(f"处理数据块时发生严重异常: {chunk}, 错误: {e}") logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
# 为失败的块添加空结果
start_idx, chunk_strs = chunk start_idx, chunk_strs = chunk
for i, s_item in enumerate(chunk_strs): for i, s in enumerate(chunk_strs):
if (start_idx + i) not in results: results[start_idx + i] = (s, [])
results[start_idx + i] = (s_item, [])
# 按原始顺序返回结果 # 按原始顺序返回结果
return [results.get(i, (strs[i], [])) for i in range(len(strs))] ordered_results = []
for i in range(len(strs)):
if i in results:
ordered_results.append(results[i])
else:
# 防止遗漏
ordered_results.append((strs[i], []))
return ordered_results
@staticmethod @staticmethod
def get_test_file_path(): def get_test_file_path():
@@ -202,17 +274,9 @@ class EmbeddingStore:
"""保存测试字符串的嵌入到本地(使用多线程优化)""" """保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...") logger.info("开始保存测试字符串的嵌入向量...")
# 获取当前正在运行的事件循环
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
return
# 使用多线程批量获取测试字符串的嵌入 # 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
main_loop,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
@@ -224,6 +288,8 @@ class EmbeddingStore:
test_vectors[str(idx)] = embedding test_vectors[str(idx)] = embedding
else: else:
logger.error(f"获取测试字符串嵌入失败: {s}") logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f: with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8")) f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
@@ -255,17 +321,9 @@ class EmbeddingStore:
logger.info("开始检验嵌入模型一致性...") logger.info("开始检验嵌入模型一致性...")
# 获取当前正在运行的事件循环
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
return False
# 使用多线程批量获取当前模型的嵌入 # 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
main_loop,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)), max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
@@ -325,20 +383,11 @@ class EmbeddingStore:
progress.update(task, advance=already_processed) progress.update(task, advance=already_processed)
if new_strs: if new_strs:
try:
main_loop = asyncio.get_running_loop()
except RuntimeError:
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
# 更新进度条以反映未处理的项目
progress.update(task, advance=len(new_strs))
return
# 使用实例配置的参数,智能调整分块和线程数 # 使用实例配置的参数,智能调整分块和线程数
optimal_chunk_size = max( optimal_chunk_size = max(
MIN_CHUNK_SIZE, MIN_CHUNK_SIZE,
min( min(
self.chunk_size, self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size,
), ),
) )
optimal_max_workers = min( optimal_max_workers = min(
@@ -355,13 +404,12 @@ class EmbeddingStore:
# 批量获取嵌入,并实时更新进度 # 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
new_strs, new_strs,
main_loop,
chunk_size=optimal_chunk_size, chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers, max_workers=optimal_max_workers,
progress_callback=update_progress, progress_callback=update_progress,
) )
# 存入结果 # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results: for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s) item_hash = self.namespace + "-" + get_sha256(s)
if embedding: # 只有成功获取到嵌入才存入 if embedding: # 只有成功获取到嵌入才存入

View File

@@ -88,6 +88,8 @@ class QAManager:
else: else:
logger.info("未找到相关关系,将使用文段检索结果") logger.info("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res result = paragraph_search_res
if result and result[0][1] < global_config.lpmm_knowledge.qa_paragraph_threshold:
result = []
ppr_node_weights = None ppr_node_weights = None
# 过滤阈值 # 过滤阈值

View File

@@ -45,8 +45,8 @@ class MessageManager:
self.chatter_manager = ChatterManager(self.action_manager) self.chatter_manager = ChatterManager(self.action_manager)
# 消息缓存系统 - 直接集成到消息管理器 # 消息缓存系统 - 直接集成到消息管理器
self.message_caches: Dict[str, deque] = defaultdict(deque) # 每个流的消息缓存 self.message_caches: dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
self.stream_processing_status: Dict[str, bool] = defaultdict(bool) # 流的处理状态 self.stream_processing_status: dict[str, bool] = defaultdict(bool) # 流的处理状态
self.cache_stats = { self.cache_stats = {
"total_cached_messages": 0, "total_cached_messages": 0,
"total_flushed_messages": 0, "total_flushed_messages": 0,

View File

@@ -1,10 +1,10 @@
from datetime import datetime, time, timedelta
import random import random
from typing import Optional, Tuple from datetime import datetime, timedelta
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.schedule.schedule_manager import schedule_manager from src.schedule.schedule_manager import schedule_manager
from .state_manager import SleepState, sleep_state_manager from .state_manager import SleepState, sleep_state_manager
logger = get_logger("sleep_logic") logger = get_logger("sleep_logic")
@@ -77,7 +77,7 @@ class SleepLogic:
logger.info(f"当前时间 {now.strftime('%H:%M')} 已到达或超过预定起床时间 {wake_up_time.strftime('%H:%M')}") logger.info(f"当前时间 {now.strftime('%H:%M')} 已到达或超过预定起床时间 {wake_up_time.strftime('%H:%M')}")
sleep_state_manager.set_state(SleepState.AWAKE) sleep_state_manager.set_state(SleepState.AWAKE)
def _should_be_sleeping(self, now: datetime) -> Tuple[bool, Optional[datetime]]: def _should_be_sleeping(self, now: datetime) -> tuple[bool, datetime | None]:
""" """
判断在当前时刻,是否应该处于睡眠时间。 判断在当前时刻,是否应该处于睡眠时间。
@@ -108,10 +108,10 @@ class SleepLogic:
return True, wake_up_time return True, wake_up_time
# 如果当前时间大于入睡时间,说明已经进入睡眠窗口 # 如果当前时间大于入睡时间,说明已经进入睡眠窗口
return True, wake_up_time return True, wake_up_time
return False, None return False, None
def _get_fixed_sleep_times(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]: def _get_fixed_sleep_times(self, now: datetime) -> tuple[datetime | None, datetime | None]:
""" """
当使用“固定时间”模式时,从此方法计算睡眠和起床时间。 当使用“固定时间”模式时,从此方法计算睡眠和起床时间。
会加入配置中的随机偏移量,让作息更自然。 会加入配置中的随机偏移量,让作息更自然。
@@ -129,7 +129,7 @@ class SleepLogic:
wake_up_t = datetime.strptime(sleep_config.fixed_wake_up_time, "%H:%M").time() wake_up_t = datetime.strptime(sleep_config.fixed_wake_up_time, "%H:%M").time()
sleep_time = datetime.combine(now.date(), sleep_t) + timedelta(minutes=sleep_offset) sleep_time = datetime.combine(now.date(), sleep_t) + timedelta(minutes=sleep_offset)
# 如果起床时间比睡觉时间早,说明是第二天 # 如果起床时间比睡觉时间早,说明是第二天
wake_up_day = now.date() + timedelta(days=1) if wake_up_t < sleep_t else now.date() wake_up_day = now.date() + timedelta(days=1) if wake_up_t < sleep_t else now.date()
wake_up_time = datetime.combine(wake_up_day, wake_up_t) + timedelta(minutes=wake_up_offset) wake_up_time = datetime.combine(wake_up_day, wake_up_t) + timedelta(minutes=wake_up_offset)
@@ -139,7 +139,7 @@ class SleepLogic:
logger.error(f"解析固定睡眠时间失败: {e}") logger.error(f"解析固定睡眠时间失败: {e}")
return None, None return None, None
def _get_sleep_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]: def _get_sleep_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
""" """
当使用“日程表”模式时,从此方法获取睡眠时间。 当使用“日程表”模式时,从此方法获取睡眠时间。
实现了核心逻辑: 实现了核心逻辑:
@@ -164,8 +164,8 @@ class SleepLogic:
wake_up_time = None wake_up_time = None
return sleep_time, wake_up_time return sleep_time, wake_up_time
def _get_wakeup_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]: def _get_wakeup_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
""" """
当使用“日程表”模式时,从此方法获取睡眠时间。 当使用“日程表”模式时,从此方法获取睡眠时间。
实现了核心逻辑: 实现了核心逻辑:
@@ -192,4 +192,4 @@ class SleepLogic:
# 全局单例 # 全局单例
sleep_logic = SleepLogic() sleep_logic = SleepLogic()

View File

@@ -1,6 +1,6 @@
import enum import enum
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Optional from typing import Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
@@ -43,7 +43,7 @@ class SleepStateManager:
""" """
初始化状态管理器,定义状态数据结构并从本地加载历史状态。 初始化状态管理器,定义状态数据结构并从本地加载历史状态。
""" """
self.state: Dict[str, Any] = {} self.state: dict[str, Any] = {}
self._default_state() self._default_state()
self.load_state() self.load_state()
@@ -115,9 +115,9 @@ class SleepStateManager:
def set_state( def set_state(
self, self,
new_state: SleepState, new_state: SleepState,
duration_seconds: Optional[float] = None, duration_seconds: float | None = None,
sleep_start: Optional[datetime] = None, sleep_start: datetime | None = None,
wake_up: Optional[datetime] = None, wake_up: datetime | None = None,
): ):
""" """
核心函数:切换到新的睡眠状态,并更新相关的状态数据。 核心函数:切换到新的睡眠状态,并更新相关的状态数据。
@@ -132,7 +132,7 @@ class SleepStateManager:
if new_state == SleepState.AWAKE: if new_state == SleepState.AWAKE:
self._default_state() # 醒来时重置所有状态 self._default_state() # 醒来时重置所有状态
self.state["state"] = SleepState.AWAKE # 确保状态正确 self.state["state"] = SleepState.AWAKE # 确保状态正确
elif new_state == SleepState.SLEEPING: elif new_state == SleepState.SLEEPING:
self.state["sleep_start_time"] = (sleep_start or datetime.now()).isoformat() self.state["sleep_start_time"] = (sleep_start or datetime.now()).isoformat()
self.state["wake_up_time"] = wake_up.isoformat() if wake_up else None self.state["wake_up_time"] = wake_up.isoformat() if wake_up else None
@@ -153,7 +153,7 @@ class SleepStateManager:
self.state["last_checked"] = datetime.now().isoformat() self.state["last_checked"] = datetime.now().isoformat()
self.save_state() self.save_state()
def get_wake_up_time(self) -> Optional[datetime]: def get_wake_up_time(self) -> datetime | None:
"""获取预定的起床时间,如果已设置的话。""" """获取预定的起床时间,如果已设置的话。"""
wake_up_str = self.state.get("wake_up_time") wake_up_str = self.state.get("wake_up_time")
if wake_up_str: if wake_up_str:
@@ -163,7 +163,7 @@ class SleepStateManager:
return None return None
return None return None
def get_sleep_start_time(self) -> Optional[datetime]: def get_sleep_start_time(self) -> datetime | None:
"""获取本次睡眠的开始时间,如果已设置的话。""" """获取本次睡眠的开始时间,如果已设置的话。"""
sleep_start_str = self.state.get("sleep_start_time") sleep_start_str = self.state.get("sleep_start_time")
if sleep_start_str: if sleep_start_str:
@@ -187,4 +187,4 @@ class SleepStateManager:
# 全局单例 # 全局单例
sleep_state_manager = SleepStateManager() sleep_state_manager = SleepStateManager()

View File

@@ -1,5 +1,6 @@
from src.common.logger import get_logger from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from .sleep_logic import sleep_logic from .sleep_logic import sleep_logic
logger = get_logger("sleep_tasks") logger = get_logger("sleep_tasks")

View File

@@ -402,19 +402,31 @@ class ChatBot:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
platform = message_data["message_info"].get("platform") # 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
if not isinstance(message_data, dict):
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
return
message_info = message_data.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
", ".join(message_data.keys()),
)
return
platform = message_info.get("platform")
if platform == "amaidesu_default": if platform == "amaidesu_default":
await self.do_s4u(message_data) await self.do_s4u(message_data)
return return
if message_data["message_info"].get("group_info") is not None: if message_info.get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str( message_info["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"] message_info["group_info"]["group_id"]
) )
if message_data["message_info"].get("user_info") is not None: if message_info.get("user_info") is not None:
message_data["message_info"]["user_info"]["user_id"] = str( message_info["user_info"]["user_id"] = str(
message_data["message_info"]["user_info"]["user_id"] message_info["user_info"]["user_id"]
) )
# print(message_data) # print(message_data)
# logger.debug(str(message_data)) # logger.debug(str(message_data))

View File

@@ -11,7 +11,7 @@ from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo from src.plugin_system.base.component_types import ActionInfo
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -207,18 +207,18 @@ class ActionModifier:
List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表 List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表
""" """
deactivated_actions = [] deactivated_actions = []
# 获取 Action 类注册表 # 获取 Action 类注册表
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.core.component_registry import component_registry
actions_to_check = list(actions_with_info.items()) actions_to_check = list(actions_with_info.items())
random.shuffle(actions_to_check) random.shuffle(actions_to_check)
# 创建并行任务列表 # 创建并行任务列表
activation_tasks = [] activation_tasks = []
task_action_names = [] task_action_names = []
for action_name, action_info in actions_to_check: for action_name, action_info in actions_to_check:
# 获取 Action 类 # 获取 Action 类
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION) action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
@@ -226,7 +226,7 @@ class ActionModifier:
logger.warning(f"{self.log_prefix}未找到 Action 类: {action_name},默认不激活") logger.warning(f"{self.log_prefix}未找到 Action 类: {action_name},默认不激活")
deactivated_actions.append((action_name, "未找到 Action 类")) deactivated_actions.append((action_name, "未找到 Action 类"))
continue continue
# 创建一个临时实例来调用 go_activate 方法 # 创建一个临时实例来调用 go_activate 方法
# 注意:这里只是为了调用 go_activate不需要完整的初始化 # 注意:这里只是为了调用 go_activate不需要完整的初始化
try: try:
@@ -237,24 +237,24 @@ class ActionModifier:
action_instance.log_prefix = self.log_prefix action_instance.log_prefix = self.log_prefix
# 设置聊天内容,用于激活判断 # 设置聊天内容,用于激活判断
action_instance._activation_chat_content = chat_content action_instance._activation_chat_content = chat_content
# 调用 go_activate 方法(不再需要传入 chat_content # 调用 go_activate 方法(不再需要传入 chat_content
task = action_instance.go_activate( task = action_instance.go_activate(
llm_judge_model=self.llm_judge, llm_judge_model=self.llm_judge,
) )
activation_tasks.append(task) activation_tasks.append(task)
task_action_names.append(action_name) task_action_names.append(action_name)
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}创建 Action 实例 {action_name} 失败: {e}") logger.error(f"{self.log_prefix}创建 Action 实例 {action_name} 失败: {e}")
deactivated_actions.append((action_name, f"创建实例失败: {e}")) deactivated_actions.append((action_name, f"创建实例失败: {e}"))
# 并行执行所有激活判断 # 并行执行所有激活判断
if activation_tasks: if activation_tasks:
logger.debug(f"{self.log_prefix}并行执行激活判断,任务数: {len(activation_tasks)}") logger.debug(f"{self.log_prefix}并行执行激活判断,任务数: {len(activation_tasks)}")
try: try:
task_results = await asyncio.gather(*activation_tasks, return_exceptions=True) task_results = await asyncio.gather(*activation_tasks, return_exceptions=True)
# 处理结果 # 处理结果
for action_name, result in zip(task_action_names, task_results, strict=False): for action_name, result in zip(task_action_names, task_results, strict=False):
if isinstance(result, Exception): if isinstance(result, Exception):
@@ -267,7 +267,7 @@ class ActionModifier:
else: else:
# go_activate 返回 True激活 # go_activate 返回 True激活
logger.debug(f"{self.log_prefix}激活动作: {action_name}") logger.debug(f"{self.log_prefix}激活动作: {action_name}")
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}并行激活判断失败: {e}") logger.error(f"{self.log_prefix}并行激活判断失败: {e}")
# 如果并行执行失败,为所有任务默认不激活 # 如果并行执行失败,为所有任务默认不激活

View File

@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.memory_mappings import get_memory_type_chinese_label from src.chat.utils.memory_mappings import get_memory_type_chinese_label
# 导入新的统一Prompt系统 # 导入新的统一Prompt系统
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.prompt_params import PromptParameters
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -1312,7 +1313,7 @@ class DefaultReplyer:
} }
# 设置超时 # 设置超时
timeout = 15.0 # 秒 timeout = 45.0 # 秒
async def get_task_result(task_name, task): async def get_task_result(task_name, task):
try: try:

View File

@@ -8,13 +8,14 @@ import contextvars
import re import re
import time import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass, field from typing import Any, Optional
from typing import Any, Literal, Optional
from rich.traceback import install from rich.traceback import install
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
@@ -23,81 +24,6 @@ install(extra_lines=3)
logger = get_logger("unified_prompt") logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
bot_name: str = ""
bot_nickname: str = ""
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
notice_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
safety_guidelines_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
class PromptContext: class PromptContext:
"""提示词上下文管理器""" """提示词上下文管理器"""
@@ -132,7 +58,7 @@ class PromptContext:
context_id = None context_id = None
previous_context = self._current_context previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else: else:
previous_context = self._current_context previous_context = self._current_context
token = None token = None
@@ -185,16 +111,42 @@ class PromptManager:
async with self._context.async_scope(message_id): async with self._context.async_scope(message_id):
yield self yield self
async def get_prompt_async(self, name: str) -> "Prompt": async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
"""异步获取提示模板""" """
异步获取提示模板,并动态注入插件内容
"""
original_prompt = None
context_prompt = await self._context.get_prompt_async(name) context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None: if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt original_prompt = context_prompt
elif name in self._prompts:
if name not in self._prompts: original_prompt = self._prompts[name]
else:
raise KeyError(f"Prompt '{name}' not found") raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
# 动态注入插件内容
if original_prompt.name:
# 确保我们有有效的parameters实例
params_for_injection = parameters or original_prompt.parameters
components_prefix = await prompt_component_manager.execute_components_for(
injection_point=original_prompt.name, params=params_for_injection
)
logger.info(components_prefix)
if components_prefix:
logger.info(f"'{name}'注入插件内容: \n{components_prefix}")
# 创建一个新的临时Prompt实例不进行注册
new_template = f"{components_prefix}\n\n{original_prompt.template}"
temp_prompt = Prompt(
template=new_template,
name=original_prompt.name,
parameters=original_prompt.parameters,
should_register=False, # 确保不重新注册
)
return temp_prompt
return original_prompt
def generate_name(self, template: str) -> str: def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称""" """为未命名的prompt生成名称"""
@@ -216,7 +168,9 @@ class PromptManager:
async def format_prompt(self, name: str, **kwargs) -> str: async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板""" """格式化提示模板"""
prompt = await self.get_prompt_async(name) # 提取parameters用于注入
parameters = kwargs.get("parameters")
prompt = await self.get_prompt_async(name, parameters=parameters)
result = prompt.format(**kwargs) result = prompt.format(**kwargs)
return result return result
@@ -304,11 +258,14 @@ class Prompt:
start_time = time.time() start_time = time.time()
try: try:
# 构建上下文数据 # 1. 构建核心上下文数据
context_data = await self._build_context_data() context_data = await self._build_context_data()
# 格式化模板 # 2. 格式化模板
result = await self._format_with_context(context_data) main_formatted_prompt = await self._format_with_context(context_data)
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
result = main_formatted_prompt
total_time = time.time() - start_time total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
@@ -470,9 +427,13 @@ class Prompt:
if not self.parameters.message_list_before_now_long: if not self.parameters.message_list_before_now_long:
return return
target_user_id = ""
if self.parameters.target_user_info:
target_user_id = self.parameters.target_user_info.get("user_id") or ""
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts( read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long, self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", target_user_id,
self.parameters.sender, self.parameters.sender,
self.parameters.chat_id, self.parameters.chat_id,
) )
@@ -498,11 +459,14 @@ class Prompt:
# 创建临时生成器实例来使用其方法 # 创建临时生成器实例来使用其方法
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building") temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
return await temp_generator.build_s4u_chat_history_prompts( if temp_generator:
message_list_before_now, target_user_id, sender, chat_id return await temp_generator.build_s4u_chat_history_prompts(
) message_list_before_now, target_user_id, sender, chat_id
)
return "", ""
except Exception as e: except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}") logger.error(f"构建S4U历史消息prompt失败: {e}")
return "", ""
async def _build_expression_habits(self) -> dict[str, Any]: async def _build_expression_habits(self) -> dict[str, Any]:
"""构建表达习惯""" """构建表达习惯"""
@@ -589,10 +553,10 @@ class Prompt:
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True) running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
# 处理可能的异常结果 # 处理可能的异常结果
if isinstance(running_memories, Exception): if isinstance(running_memories, BaseException):
logger.warning(f"长期记忆查询失败: {running_memories}") logger.warning(f"长期记忆查询失败: {running_memories}")
running_memories = [] running_memories = []
if isinstance(instant_memory, Exception): if isinstance(instant_memory, BaseException):
logger.warning(f"即时记忆查询失败: {instant_memory}") logger.warning(f"即时记忆查询失败: {instant_memory}")
instant_memory = None instant_memory = None
@@ -763,20 +727,15 @@ class Prompt:
return {"knowledge_prompt": ""} return {"knowledge_prompt": ""}
try: try:
from src.chat.knowledge.knowledge_lib import QAManager from src.chat.knowledge.knowledge_lib import qa_manager
# 获取问题文本(当前消息) # 获取问题文本(当前消息)
question = self.parameters.target or "" question = self.parameters.target or ""
if not question: if not question or not qa_manager:
return {"knowledge_prompt": ""} return {"knowledge_prompt": ""}
# 创建QA管理器
qa_manager = QAManager()
# 搜索相关知识 # 搜索相关知识
knowledge_results = await qa_manager.get_knowledge( knowledge_results = await qa_manager.get_knowledge(question=question)
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
)
# 构建知识块 # 构建知识块
if knowledge_results and knowledge_results.get("knowledge_items"): if knowledge_results and knowledge_results.get("knowledge_items"):
@@ -786,12 +745,17 @@ class Prompt:
content = item.get("content", "") content = item.get("content", "")
source = item.get("source", "") source = item.get("source", "")
relevance = item.get("relevance", 0.0) relevance = item.get("relevance", 0.0)
if content: if content:
try:
relevance_float = float(relevance)
relevance_str = f"{relevance_float:.2f}"
except (ValueError, TypeError):
relevance_str = str(relevance)
if source: if source:
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})") knowledge_parts.append(f"- [{relevance_str}] {content} (来源: {source})")
else: else:
knowledge_parts.append(f"- [{relevance:.2f}] {content}") knowledge_parts.append(f"- [{relevance_str}] {content}")
if knowledge_results.get("summary"): if knowledge_results.get("summary"):
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}") knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
@@ -1108,8 +1072,24 @@ def create_prompt(
async def create_prompt_async( async def create_prompt_async(
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt: ) -> Prompt:
"""异步创建Prompt实例""" """异步创建Prompt实例,并动态注入插件内容"""
prompt = create_prompt(template, name, parameters, **kwargs) # 确保有可用的parameters实例
final_params = parameters or PromptParameters(**kwargs)
# 动态注入插件内容
if name:
components_prefix = await prompt_component_manager.execute_components_for(
injection_point=name, params=final_params
)
if components_prefix:
logger.debug(f"'{name}'注入插件内容: \n{components_prefix}")
template = f"{components_prefix}\n\n{template}"
# 使用可能已修改的模板创建实例
prompt = create_prompt(template, name, final_params)
# 如果在特定上下文中,则异步注册
if global_prompt_manager._context._current_context: if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt) await global_prompt_manager._context.register_async(prompt)
return prompt return prompt

View File

@@ -0,0 +1,109 @@
import asyncio
from typing import Type
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.component_types import ComponentType, PromptInfo
from src.plugin_system.core.component_registry import component_registry
logger = get_logger("prompt_component_manager")
class PromptComponentManager:
"""
管理所有 `BasePrompt` 组件的单例类。
该管理器负责:
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
2. 根据注入点目标Prompt名称对它们进行筛选。
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
"""
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
"""
获取指定注入点的所有已注册组件类。
Args:
injection_point: 目标Prompt的名称。
Returns:
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
"""
# 从组件注册中心获取所有启用的Prompt组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_components: list[Type[BasePrompt]] = []
for prompt_name, prompt_info in enabled_prompts.items():
# 确保 prompt_info 是 PromptInfo 类型
if not isinstance(prompt_info, PromptInfo):
continue
# 获取注入点信息
injection_points = prompt_info.injection_point
if isinstance(injection_points, str):
injection_points = [injection_points]
# 检查当前注入点是否匹配
if injection_point in injection_points:
# 获取组件类
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
if component_class and issubclass(component_class, BasePrompt):
matching_components.append(component_class)
return matching_components
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
"""
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
Args:
injection_point: 目标Prompt的名称。
params: 用于初始化组件的 PromptParameters 对象。
Returns:
str: 所有相关组件生成的、用换行符连接的文本内容。
"""
component_classes = self.get_components_for(injection_point)
if not component_classes:
return ""
tasks = []
for component_class in component_classes:
try:
# 从注册中心获取组件信息
prompt_info = component_registry.get_component_info(
component_class.prompt_name, ComponentType.PROMPT
)
if not isinstance(prompt_info, PromptInfo):
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
plugin_config = {}
else:
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
instance = component_class(params=params, plugin_config=plugin_config)
tasks.append(instance.execute())
except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
if not tasks:
return ""
# 并行执行所有组件
results = await asyncio.gather(*tasks, return_exceptions=True)
# 过滤掉执行失败的结果和空字符串
valid_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
elif result and isinstance(result, str) and result.strip():
valid_results.append(result.strip())
# 使用换行符拼接所有有效结果
return "\n".join(valid_results)
# 创建全局单例
prompt_component_manager = PromptComponentManager()

View File

@@ -0,0 +1,79 @@
"""
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
"""
from dataclasses import dataclass, field
from typing import Any, Literal
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
bot_name: str = ""
bot_nickname: str = ""
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
safety_guidelines_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors

View File

@@ -298,14 +298,14 @@ def random_remove_punctuation(text: str) -> str:
def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]: def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
"""识别并保护数学公式和代码块,返回处理后的文本和映射""" """识别并保护数学公式和代码块,返回处理后的文本和映射"""
placeholder_map = {} placeholder_map = {}
# 第一层防护优先保护标准Markdown格式 # 第一层防护优先保护标准Markdown格式
# 使用 re.S 来让 . 匹配换行符 # 使用 re.S 来让 . 匹配换行符
markdown_patterns = { markdown_patterns = {
'code': r"```.*?```", "code": r"```.*?```",
'math': r"\$\$.*?\$\$", "math": r"\$\$.*?\$\$",
} }
placeholder_idx = 0 placeholder_idx = 0
for block_type, pattern in markdown_patterns.items(): for block_type, pattern in markdown_patterns.items():
matches = re.findall(pattern, text, re.S) matches = re.findall(pattern, text, re.S)
@@ -318,7 +318,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
# 第二层防护:保护非标准的、可能是公式或代码的片段 # 第二层防护:保护非标准的、可能是公式或代码的片段
# 这个正则表达式寻找连续5个以上的、主要由非中文字符组成的片段 # 这个正则表达式寻找连续5个以上的、主要由非中文字符组成的片段
general_pattern = r"(?:[a-zA-Z0-9\s.,;:(){}\[\]_+\-*/=<>^|&%?!'\"√²³ⁿ∑∫≠≥≤]){5,}" general_pattern = r"(?:[a-zA-Z0-9\s.,;:(){}\[\]_+\-*/=<>^|&%?!'\"√²³ⁿ∑∫≠≥≤]){5,}"
# 为了避免与已保护的占位符冲突,我们在剩余的文本上进行查找 # 为了避免与已保护的占位符冲突,我们在剩余的文本上进行查找
# 这是一个简化的处理,更稳妥的方式是分段查找,但目前这样足以应对多数情况 # 这是一个简化的处理,更稳妥的方式是分段查找,但目前这样足以应对多数情况
try: try:
@@ -327,7 +327,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
# 避免将包含占位符的片段再次保护 # 避免将包含占位符的片段再次保护
if "__SPECIAL_" in match: if "__SPECIAL_" in match:
continue continue
placeholder = f"__SPECIAL_GENERAL_{placeholder_idx}__" placeholder = f"__SPECIAL_GENERAL_{placeholder_idx}__"
text = text.replace(match, placeholder, 1) text = text.replace(match, placeholder, 1)
placeholder_map[placeholder] = match placeholder_map[placeholder] = match
@@ -352,23 +352,23 @@ def protect_quoted_content(text: str) -> tuple[str, dict[str, str]]:
placeholder_map = {} placeholder_map = {}
# 匹配中英文单双引号,使用非贪婪模式 # 匹配中英文单双引号,使用非贪婪模式
quote_pattern = re.compile(r'(".*?")|(\'.*?\')|(“.*?”)|(.*?)') quote_pattern = re.compile(r'(".*?")|(\'.*?\')|(“.*?”)|(.*?)')
matches = quote_pattern.finditer(text) matches = quote_pattern.finditer(text)
# 为了避免替换时索引错乱,我们从后往前替换 # 为了避免替换时索引错乱,我们从后往前替换
# finditer 找到的是 match 对象,我们需要转换为 list 来反转 # finditer 找到的是 match 对象,我们需要转换为 list 来反转
match_list = list(matches) match_list = list(matches)
for idx, match in enumerate(reversed(match_list)): for idx, match in enumerate(reversed(match_list)):
original_quoted_text = match.group(0) original_quoted_text = match.group(0)
placeholder = f"__QUOTE_{len(match_list) - 1 - idx}__" placeholder = f"__QUOTE_{len(match_list) - 1 - idx}__"
# 直接在原始文本上操作,替换 match 对象的 span # 直接在原始文本上操作,替换 match 对象的 span
start, end = match.span() start, end = match.span()
text = text[:start] + placeholder + text[end:] text = text[:start] + placeholder + text[end:]
placeholder_map[placeholder] = original_quoted_text placeholder_map[placeholder] = original_quoted_text
return text, placeholder_map return text, placeholder_map
@@ -389,13 +389,13 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
# --- 三层防护系统 --- # --- 三层防护系统 ---
# 第一层:保护颜文字 # 第一层:保护颜文字
protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {}) protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})
# 第二层:保护引号内容 # 第二层:保护引号内容
protected_text, quote_mapping = protect_quoted_content(protected_text) protected_text, quote_mapping = protect_quoted_content(protected_text)
# 第三层:保护数学公式和代码块 # 第三层:保护数学公式和代码块
protected_text, special_blocks_mapping = protect_special_blocks(protected_text) protected_text, special_blocks_mapping = protect_special_blocks(protected_text)
# 提取被 () 或 [] 或 ()包裹且包含中文的内容 # 提取被 () 或 [] 或 ()包裹且包含中文的内容
pattern = re.compile(r"[(\[](?=.*[一-鿿]).*?[)\]]") pattern = re.compile(r"[(\[](?=.*[一-鿿]).*?[)\]]")
_extracted_contents = pattern.findall(protected_text) _extracted_contents = pattern.findall(protected_text)
@@ -412,7 +412,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
# 对清理后的文本进行进一步处理 # 对清理后的文本进行进一步处理
max_sentence_num = global_config.response_splitter.max_sentence_num max_sentence_num = global_config.response_splitter.max_sentence_num
# --- 移除总长度检查 --- # --- 移除总长度检查 ---
# 原有的总长度检查会导致长回复被直接丢弃,现已移除,由后续的智能合并逻辑处理。 # 原有的总长度检查会导致长回复被直接丢弃,现已移除,由后续的智能合并逻辑处理。
# max_length = global_config.response_splitter.max_length * 2 # max_length = global_config.response_splitter.max_length * 2
@@ -472,7 +472,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
break break
# 寻找最短的相邻句子对 # 寻找最短的相邻句子对
min_len = float('inf') min_len = float("inf")
merge_idx = -1 merge_idx = -1
for i in range(len(sentences) - 1): for i in range(len(sentences) - 1):
combined_len = len(sentences[i]) + len(sentences[i+1]) combined_len = len(sentences[i]) + len(sentences[i+1])
@@ -488,7 +488,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
sentences[merge_idx] = merged_sentence sentences[merge_idx] = merged_sentence
# 删除后一个句子 # 删除后一个句子
del sentences[merge_idx + 1] del sentences[merge_idx + 1]
logger.info(f"智能合并完成,最终消息数量: {len(sentences)}") logger.info(f"智能合并完成,最终消息数量: {len(sentences)}")
# if extracted_contents: # if extracted_contents:

View File

@@ -79,7 +79,7 @@ class Server:
logger.warning(f"端口 {self.port} 已被占用,正在尝试下一个端口...") logger.warning(f"端口 {self.port} 已被占用,正在尝试下一个端口...")
self.port += 1 self.port += 1
logger.info(f"将在 http://{self.host}:{self.port} 上启动服务器") logger.info(f"将在 {self.host}:{self.port} 上启动服务器")
# 禁用 uvicorn 默认日志和访问日志 # 禁用 uvicorn 默认日志和访问日志
config = Config(app=self.app, host=self.host, port=self.port, log_config=None, access_log=False) config = Config(app=self.app, host=self.host, port=self.port, log_config=None, access_log=False)
self._server = UvicornServer(config=config) self._server = UvicornServer(config=config)

View File

@@ -7,7 +7,7 @@ from src.config.config_base import ValidatedConfigBase
""" """
须知: 须知:
1. 本文件中记录了所有的配置项 1. 本文件中记录了所有的配置项
2. 重要的配置类继承自ValidatedConfigBase进行Pydantic验证 2. 所有配置类必须继承自ValidatedConfigBase进行Pydantic验证
3. 所有新增的class都应在config.py中的Config类中添加字段 3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default 4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
""" """
@@ -492,6 +492,7 @@ class LPMMKnowledgeConfig(ValidatedConfigBase):
info_extraction_workers: int = Field(default=3, description="信息提取工作线程数") info_extraction_workers: int = Field(default=3, description="信息提取工作线程数")
qa_relation_search_top_k: int = Field(default=10, description="QA关系搜索Top K") qa_relation_search_top_k: int = Field(default=10, description="QA关系搜索Top K")
qa_relation_threshold: float = Field(default=0.75, description="QA关系阈值") qa_relation_threshold: float = Field(default=0.75, description="QA关系阈值")
qa_paragraph_threshold: float = Field(default=0.3, description="QA段落阈值")
qa_paragraph_search_top_k: int = Field(default=1000, description="QA段落搜索Top K") qa_paragraph_search_top_k: int = Field(default=1000, description="QA段落搜索Top K")
qa_paragraph_node_weight: float = Field(default=0.05, description="QA段落节点权重") qa_paragraph_node_weight: float = Field(default=0.05, description="QA段落节点权重")
qa_ent_filter_top_k: int = Field(default=10, description="QA实体过滤Top K") qa_ent_filter_top_k: int = Field(default=10, description="QA实体过滤Top K")

View File

@@ -13,6 +13,7 @@ from rich.traceback import install
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.memory_system.memory_manager import memory_manager from src.chat.memory_system.memory_manager import memory_manager
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
@@ -29,7 +30,6 @@ from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.plugin_manager import plugin_manager
from src.schedule.monthly_plan_manager import monthly_plan_manager from src.schedule.monthly_plan_manager import monthly_plan_manager
from src.schedule.schedule_manager import schedule_manager from src.schedule.schedule_manager import schedule_manager
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
# 插件系统现在使用统一的插件加载器 # 插件系统现在使用统一的插件加载器
install(extra_lines=3) install(extra_lines=3)

View File

@@ -26,6 +26,7 @@ from .base import (
ActionInfo, ActionInfo,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BasePrompt,
BaseEventHandler, BaseEventHandler,
BasePlugin, BasePlugin,
BaseTool, BaseTool,
@@ -64,6 +65,7 @@ __all__ = [
"BaseEventHandler", "BaseEventHandler",
# 基础类 # 基础类
"BasePlugin", "BasePlugin",
"BasePrompt",
"BaseTool", "BaseTool",
"ChatMode", "ChatMode",
"ChatType", "ChatType",

View File

@@ -8,6 +8,7 @@ from .base_action import BaseAction
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
from .base_prompt import BasePrompt
from .base_tool import BaseTool from .base_tool import BaseTool
from .command_args import CommandArgs from .command_args import CommandArgs
from .component_types import ( from .component_types import (
@@ -37,6 +38,7 @@ __all__ = [
"BaseCommand", "BaseCommand",
"BaseEventHandler", "BaseEventHandler",
"BasePlugin", "BasePlugin",
"BasePrompt",
"BaseTool", "BaseTool",
"ChatMode", "ChatMode",
"ChatType", "ChatType",

View File

@@ -615,15 +615,15 @@ class BaseAction(ABC):
""" """
# 尝试从不同的实例属性中获取聊天内容 # 尝试从不同的实例属性中获取聊天内容
# 优先级_activation_chat_content > action_data['chat_content'] > "" # 优先级_activation_chat_content > action_data['chat_content'] > ""
# 1. 如果有专门设置的激活用聊天内容(由 ActionModifier 设置) # 1. 如果有专门设置的激活用聊天内容(由 ActionModifier 设置)
if hasattr(self, '_activation_chat_content'): if hasattr(self, "_activation_chat_content"):
return getattr(self, '_activation_chat_content', "") return getattr(self, "_activation_chat_content", "")
# 2. 尝试从 action_data 中获取 # 2. 尝试从 action_data 中获取
if hasattr(self, 'action_data') and isinstance(self.action_data, dict): if hasattr(self, "action_data") and isinstance(self.action_data, dict):
return self.action_data.get('chat_content', "") return self.action_data.get("chat_content", "")
# 3. 默认返回空字符串 # 3. 默认返回空字符串
return "" return ""
@@ -729,7 +729,7 @@ class BaseAction(ABC):
# 自动获取聊天内容 # 自动获取聊天内容
chat_content = self._get_chat_content() chat_content = self._get_chat_content()
search_text = chat_content search_text = chat_content
if not case_sensitive: if not case_sensitive:
search_text = search_text.lower() search_text = search_text.lower()
@@ -786,7 +786,7 @@ class BaseAction(ABC):
try: try:
# 自动获取聊天内容 # 自动获取聊天内容
chat_content = self._get_chat_content() chat_content = self._get_chat_content()
# 如果没有提供 LLM 模型,创建一个默认的 # 如果没有提供 LLM 模型,创建一个默认的
if llm_judge_model is None: if llm_judge_model is None:
from src.config.config import model_config from src.config.config import model_config

View File

@@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import (
EventHandlerInfo, EventHandlerInfo,
InterestCalculatorInfo, InterestCalculatorInfo,
PlusCommandInfo, PlusCommandInfo,
PromptInfo,
ToolInfo, ToolInfo,
) )
@@ -15,6 +16,7 @@ from .base_action import BaseAction
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .base_interest_calculator import BaseInterestCalculator from .base_interest_calculator import BaseInterestCalculator
from .base_prompt import BasePrompt
from .base_tool import BaseTool from .base_tool import BaseTool
from .plugin_base import PluginBase from .plugin_base import PluginBase
from .plus_command import PlusCommand from .plus_command import PlusCommand
@@ -80,6 +82,13 @@ class BasePlugin(PluginBase):
logger.warning("EventHandler的get_info逻辑尚未实现") logger.warning("EventHandler的get_info逻辑尚未实现")
return None return None
elif component_type == ComponentType.PROMPT:
if hasattr(component_class, "get_prompt_info"):
return component_class.get_prompt_info()
else:
logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法")
return None
else: else:
logger.error(f"不支持的组件类型: {component_type}") logger.error(f"不支持的组件类型: {component_type}")
return None return None
@@ -109,6 +118,7 @@ class BasePlugin(PluginBase):
| tuple[EventHandlerInfo, type[BaseEventHandler]] | tuple[EventHandlerInfo, type[BaseEventHandler]]
| tuple[ToolInfo, type[BaseTool]] | tuple[ToolInfo, type[BaseTool]]
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]] | tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
| tuple[PromptInfo, type[BasePrompt]]
]: ]:
"""获取插件包含的组件列表 """获取插件包含的组件列表

View File

@@ -0,0 +1,95 @@
from abc import ABC, abstractmethod
from typing import Any
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, PromptInfo
logger = get_logger("base_prompt")
class BasePrompt(ABC):
"""Prompt组件基类
Prompt是插件的一种组件类型用于动态地向现有的核心Prompt模板中注入额外的上下文信息。
它的主要作用是在不修改核心代码的情况下,扩展和定制模型的行为。
子类可以通过类属性定义其行为:
- prompt_name: Prompt组件的唯一名称。
- injection_point: 指定要注入的目标Prompt名称或名称列表
"""
prompt_name: str = ""
"""Prompt组件的名称"""
prompt_description: str = ""
"""Prompt组件的描述"""
# 定义此组件希望注入到哪个或哪些核心Prompt中
# 可以是一个字符串(单个目标)或字符串列表(多个目标)
# 例如: "planner_prompt" 或 ["s4u_style_prompt", "normal_style_prompt"]
injection_point: str | list[str] = ""
"""要注入的目标Prompt名称或列表"""
def __init__(self, params: PromptParameters, plugin_config: dict | None = None):
"""初始化Prompt组件
Args:
params: 统一提示词参数,包含所有构建提示词所需的上下文信息。
plugin_config: 插件配置字典。
"""
self.params = params
self.plugin_config = plugin_config or {}
self.log_prefix = "[PromptComponent]"
logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成")
@abstractmethod
async def execute(self) -> str:
"""执行Prompt生成的抽象方法子类必须实现。
此方法应根据初始化时传入的 `self.params` 来构建并返回一个字符串。
返回的字符串将被拼接到目标Prompt的最前面。
Returns:
str: 生成的文本内容。
"""
pass
def get_config(self, key: str, default: Any = None) -> Any:
"""获取插件配置值,支持嵌套键访问。
Args:
key: 配置键名,使用点号进行嵌套访问,如 "section.subsection.key"
default: 未找到键时返回的默认值。
Returns:
Any: 配置值或默认值。
"""
if not self.plugin_config:
return default
keys = key.split(".")
current = self.plugin_config
for k in keys:
if isinstance(current, dict) and k in current:
current = current[k]
else:
return default
return current
@classmethod
def get_prompt_info(cls) -> "PromptInfo":
"""从类属性生成PromptInfo用于组件注册和管理。
Returns:
PromptInfo: 生成的Prompt信息对象。
"""
if not cls.prompt_name:
raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。")
return PromptInfo(
name=cls.prompt_name,
component_type=ComponentType.PROMPT,
description=cls.prompt_description,
injection_point=cls.injection_point,
)

View File

@@ -20,6 +20,7 @@ class ComponentType(Enum):
EVENT_HANDLER = "event_handler" # 事件处理组件 EVENT_HANDLER = "event_handler" # 事件处理组件
CHATTER = "chatter" # 聊天处理器组件 CHATTER = "chatter" # 聊天处理器组件
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件 INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
PROMPT = "prompt" # Prompt组件
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@@ -143,7 +144,7 @@ class ActionInfo(ComponentInfo):
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} ) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: list[str] = field(default_factory=list) # 动作需求说明 action_require: list[str] = field(default_factory=list) # 动作需求说明
associated_types: list[str] = field(default_factory=list) # 关联的消息类型 associated_types: list[str] = field(default_factory=list) # 关联的消息类型
# ================================================================================== # ==================================================================================
# 激活类型相关字段(已废弃,建议使用 go_activate() 方法) # 激活类型相关字段(已废弃,建议使用 go_activate() 方法)
# 保留这些字段是为了向后兼容BaseAction.go_activate() 的默认实现会使用这些字段 # 保留这些字段是为了向后兼容BaseAction.go_activate() 的默认实现会使用这些字段
@@ -155,7 +156,7 @@ class ActionInfo(ComponentInfo):
llm_judge_prompt: str = "" # 已废弃,建议在 go_activate() 中使用 _llm_judge_activation() llm_judge_prompt: str = "" # 已废弃,建议在 go_activate() 中使用 _llm_judge_activation()
activation_keywords: list[str] = field(default_factory=list) # 已废弃,建议在 go_activate() 中使用 _keyword_match() activation_keywords: list[str] = field(default_factory=list) # 已废弃,建议在 go_activate() 中使用 _keyword_match()
keyword_case_sensitive: bool = False # 已废弃 keyword_case_sensitive: bool = False # 已废弃
# 模式和并行设置 # 模式和并行设置
mode_enable: ChatMode = ChatMode.ALL mode_enable: ChatMode = ChatMode.ALL
parallel_action: bool = False parallel_action: bool = False
@@ -266,6 +267,18 @@ class EventInfo(ComponentInfo):
self.component_type = ComponentType.EVENT_HANDLER self.component_type = ComponentType.EVENT_HANDLER
@dataclass
class PromptInfo(ComponentInfo):
"""Prompt组件信息"""
injection_point: str | list[str] = ""
"""要注入的目标Prompt名称或列表"""
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.PROMPT
@dataclass @dataclass
class PluginInfo: class PluginInfo:
"""插件信息""" """插件信息"""

View File

@@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
ActionInfo, ActionInfo,
@@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import (
InterestCalculatorInfo, InterestCalculatorInfo,
PluginInfo, PluginInfo,
PlusCommandInfo, PlusCommandInfo,
PromptInfo,
ToolInfo, ToolInfo,
) )
from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.plus_command import PlusCommand
@@ -37,6 +39,7 @@ ComponentClassType = (
| type[PlusCommand] | type[PlusCommand]
| type[BaseChatter] | type[BaseChatter]
| type[BaseInterestCalculator] | type[BaseInterestCalculator]
| type[BasePrompt]
) )
@@ -183,6 +186,10 @@ class ComponentRegistry:
assert isinstance(component_info, InterestCalculatorInfo) assert isinstance(component_info, InterestCalculatorInfo)
assert issubclass(component_class, BaseInterestCalculator) assert issubclass(component_class, BaseInterestCalculator)
ret = self._register_interest_calculator_component(component_info, component_class) ret = self._register_interest_calculator_component(component_info, component_class)
case ComponentType.PROMPT:
assert isinstance(component_info, PromptInfo)
assert issubclass(component_class, BasePrompt)
ret = self._register_prompt_component(component_info, component_class)
case _: case _:
logger.warning(f"未知组件类型: {component_type}") logger.warning(f"未知组件类型: {component_type}")
ret = False ret = False
@@ -346,6 +353,31 @@ class ComponentRegistry:
logger.debug(f"已注册InterestCalculator组件: {calculator_name}") logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
return True return True
def _register_prompt_component(
self, prompt_info: PromptInfo, prompt_class: "ComponentClassType"
) -> bool:
"""注册Prompt组件到Prompt特定注册表"""
prompt_name = prompt_info.name
if not prompt_name:
logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称")
return False
if not hasattr(self, "_prompt_registry"):
self._prompt_registry: dict[str, type[BasePrompt]] = {}
if not hasattr(self, "_enabled_prompt_registry"):
self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {}
_assign_plugin_attrs(
prompt_class, prompt_info.plugin_name, self.get_plugin_config(prompt_info.plugin_name) or {}
)
self._prompt_registry[prompt_name] = prompt_class # type: ignore
if prompt_info.enabled:
self._enabled_prompt_registry[prompt_name] = prompt_class # type: ignore
logger.debug(f"已注册Prompt组件: {prompt_name}")
return True
# === 组件移除相关 === # === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
@@ -580,7 +612,17 @@ class ComponentRegistry:
component_name: str, component_name: str,
component_type: ComponentType | None = None, component_type: ComponentType | None = None,
) -> ( ) -> (
type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None type[
BaseCommand
| BaseAction
| BaseEventHandler
| BaseTool
| PlusCommand
| BaseChatter
| BaseInterestCalculator
| BasePrompt
]
| None
): ):
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
@@ -829,6 +871,7 @@ class ComponentRegistry:
events_handlers: int = 0 events_handlers: int = 0
plus_command_components: int = 0 plus_command_components: int = 0
chatter_components: int = 0 chatter_components: int = 0
prompt_components: int = 0
for component in self._components.values(): for component in self._components.values():
if component.component_type == ComponentType.ACTION: if component.component_type == ComponentType.ACTION:
action_components += 1 action_components += 1
@@ -842,6 +885,8 @@ class ComponentRegistry:
plus_command_components += 1 plus_command_components += 1
elif component.component_type == ComponentType.CHATTER: elif component.component_type == ComponentType.CHATTER:
chatter_components += 1 chatter_components += 1
elif component.component_type == ComponentType.PROMPT:
prompt_components += 1
return { return {
"action_components": action_components, "action_components": action_components,
"command_components": command_components, "command_components": command_components,
@@ -849,6 +894,7 @@ class ComponentRegistry:
"event_handlers": events_handlers, "event_handlers": events_handlers,
"plus_command_components": plus_command_components, "plus_command_components": plus_command_components,
"chatter_components": chatter_components, "chatter_components": chatter_components,
"prompt_components": prompt_components,
"total_components": len(self._components), "total_components": len(self._components),
"total_plugins": len(self._plugins), "total_plugins": len(self._plugins),
"components_by_type": { "components_by_type": {

View File

@@ -358,13 +358,14 @@ class PluginManager:
event_handler_count = stats.get("event_handlers", 0) event_handler_count = stats.get("event_handlers", 0)
plus_command_count = stats.get("plus_command_components", 0) plus_command_count = stats.get("plus_command_components", 0)
chatter_count = stats.get("chatter_components", 0) chatter_count = stats.get("chatter_components", 0)
prompt_count = stats.get("prompt_components", 0)
total_components = stats.get("total_components", 0) total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览 # 📋 显示插件加载总览
if total_registered > 0: if total_registered > 0:
logger.info("🎉 插件系统加载完成!") logger.info("🎉 插件系统加载完成!")
logger.info( logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})" f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
) )
# 显示详细的插件列表 # 显示详细的插件列表
@@ -382,6 +383,13 @@ class PluginManager:
# 组件列表 # 组件列表
if plugin_info.components: if plugin_info.components:
def format_component(c):
desc = c.description
if len(desc) > 15:
desc = desc[:15] + "..."
return f"{c.name} ({desc})" if desc else c.name
action_components = [ action_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ACTION c for c in plugin_info.components if c.component_type == ComponentType.ACTION
] ]
@@ -395,29 +403,35 @@ class PluginManager:
plus_command_components = [ plus_command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
] ]
prompt_components = [
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
]
if action_components: if action_components:
action_names = [c.name for c in action_components] action_details = [format_component(c) for c in action_components]
logger.info(f" 🎯 Action组件: {', '.join(action_names)}") logger.info(f" 🎯 Action组件: {', '.join(action_details)}")
if command_components: if command_components:
command_names = [c.name for c in command_components] command_details = [format_component(c) for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}") logger.info(f" ⚡ Command组件: {', '.join(command_details)}")
if tool_components: if tool_components:
tool_names = [c.name for c in tool_components] tool_details = [format_component(c) for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}") logger.info(f" 🛠️ Tool组件: {', '.join(tool_details)}")
if plus_command_components: if plus_command_components:
plus_command_names = [c.name for c in plus_command_components] plus_command_details = [format_component(c) for c in plus_command_components]
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}") logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_details)}")
chatter_components = [ chatter_components = [
c for c in plugin_info.components if c.component_type == ComponentType.CHATTER c for c in plugin_info.components if c.component_type == ComponentType.CHATTER
] ]
if chatter_components: if chatter_components:
chatter_names = [c.name for c in chatter_components] chatter_details = [format_component(c) for c in chatter_components]
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}") logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_details)}")
if event_handler_components: if event_handler_components:
event_handler_names = [c.name for c in event_handler_components] event_handler_details = [format_component(c) for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
if prompt_components:
prompt_details = [format_component(c) for c in prompt_components]
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
# 权限节点信息 # 权限节点信息
if plugin_instance := self.loaded_plugins.get(plugin_name): if plugin_instance := self.loaded_plugins.get(plugin_name):

View File

@@ -155,88 +155,22 @@ class ChatterPlanFilter:
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
schedule_block = "" schedule_block = ""
# 优先检查是否被吵醒 if global_config.planning_system.schedule_enable:
from src.chat.message_manager.message_manager import message_manager
angry_prompt_addition = ""
try:
from src.plugins.built_in.sleep_system.api import get_wakeup_manager
wakeup_mgr = get_wakeup_manager()
except ImportError:
logger.debug("无法导入睡眠系统API将跳过相关检查。")
wakeup_mgr = None
if wakeup_mgr:
# 双重检查确保愤怒状态不会丢失
# 检查1: 直接从 wakeup_manager 获取
if wakeup_mgr.is_in_angry_state():
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
# 检查2: 如果上面没获取到,再从 mood_manager 确认
if not angry_prompt_addition:
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
if chat_mood_for_check.is_angry_from_wakeup:
angry_prompt_addition = global_config.sleep_system.angry_prompt
if angry_prompt_addition:
schedule_block = angry_prompt_addition
elif global_config.planning_system.schedule_enable:
if activity_info := schedule_manager.get_current_activity(): if activity_info := schedule_manager.get_current_activity():
activity = activity_info.get("activity", "未知活动") activity = activity_info.get("activity", "未知活动")
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。" schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
mood_block = "" mood_block = ""
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块 # 需要情绪模块打开才能获得情绪,否则会引发报错
if not angry_prompt_addition and global_config.mood.enable_mood: if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id) chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
mood_block = f"你现在的心情是:{chat_mood.mood_state}" mood_block = f"你现在的心情是:{chat_mood.mood_state}"
if plan.mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context()
chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
truncate=False,
show_actions=False,
)
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),
limit=5,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
prompt = prompt_template.format(
time_block=time_block,
identity_block=identity_block,
schedule_block=schedule_block,
mood_block=mood_block,
long_term_memory_block=long_term_memory_block,
chat_content_block=chat_content_block or "最近没有聊天内容。",
actions_before_now_block=actions_before_now_block,
)
return prompt, message_id_list
# 构建已读/未读历史消息 # 构建已读/未读历史消息
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks( read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
plan plan
) )
# 为了兼容性保留原有的chat_content_block
chat_content_block, _ = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
read_mark=self.last_obs_time_mark,
truncate=True,
show_actions=True,
)
actions_before_now = await get_actions_by_timestamp_with_chat( actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id, chat_id=plan.chat_id,
timestamp_start=time.time() - 3600, timestamp_start=time.time() - 3600,
@@ -286,7 +220,7 @@ class ChatterPlanFilter:
is_group_chat = plan.chat_type == ChatType.GROUP is_group_chat = plan.chat_type == ChatType.GROUP
chat_context_description = "你现在正在一个群聊中" chat_context_description = "你现在正在一个群聊中"
if not is_group_chat and plan.target_info: if not is_group_chat and plan.target_info:
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
chat_context_description = f"你正在和 {chat_target_name} 私聊" chat_context_description = f"你正在和 {chat_target_name} 私聊"
action_options_block = await self._build_action_options(plan.available_actions) action_options_block = await self._build_action_options(plan.available_actions)

View File

@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan, TargetPersonInfo from src.common.data_models.info_data_model import Plan, TargetPersonInfo
from src.config.config import global_config from src.config.config import global_config
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
@@ -55,6 +55,11 @@ class ChatterPlanGenerator:
try: try:
# 获取聊天类型和目标信息 # 获取聊天类型和目标信息
chat_type, target_info = await get_chat_type_and_target_info(self.chat_id) chat_type, target_info = await get_chat_type_and_target_info(self.chat_id)
if chat_type:
chat_type = ChatType.GROUP
else:
#遇到未知类型也当私聊处理
chat_type = ChatType.PRIVATE
# 获取可用动作列表 # 获取可用动作列表
available_actions = await self._get_available_actions(chat_type, mode) available_actions = await self._get_available_actions(chat_type, mode)
@@ -62,12 +67,16 @@ class ChatterPlanGenerator:
# 获取聊天历史记录 # 获取聊天历史记录
recent_messages = await self._get_recent_messages() recent_messages = await self._get_recent_messages()
# 构建计划对象
# 使用 target_info 字典创建 TargetPersonInfo 实例
target_person_info = TargetPersonInfo(**target_info) if target_info else TargetPersonInfo()
# 构建计划对象 # 构建计划对象
plan = Plan( plan = Plan(
chat_id=self.chat_id, chat_id=self.chat_id,
chat_type=chat_type, chat_type=chat_type,
mode=mode, mode=mode,
target_info=target_info, target_info=target_person_info,
available_actions=available_actions, available_actions=available_actions,
chat_history=recent_messages, chat_history=recent_messages,
) )
@@ -77,6 +86,7 @@ class ChatterPlanGenerator:
except Exception: except Exception:
# 如果生成失败,返回一个基本的空计划 # 如果生成失败,返回一个基本的空计划
return Plan( return Plan(
chat_type = ChatType.PRIVATE,#空计划默认当成私聊
chat_id=self.chat_id, chat_id=self.chat_id,
mode=mode, mode=mode,
target_info=TargetPersonInfo(), target_info=TargetPersonInfo(),
@@ -124,7 +134,7 @@ class ChatterPlanGenerator:
try: try:
# 获取最近的消息记录 # 获取最近的消息记录
raw_messages = await get_raw_msg_before_timestamp_with_chat( raw_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length chat_id=self.chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size
) )
# 转换为 DatabaseMessages 对象 # 转换为 DatabaseMessages 对象

View File

@@ -70,6 +70,7 @@ class ChatterActionPlanner:
"replies_generated": 0, "replies_generated": 0,
"other_actions_executed": 0, "other_actions_executed": 0,
} }
self._background_tasks: set[asyncio.Task] = set()
async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]: async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]:
""" """
@@ -157,7 +158,9 @@ class ChatterActionPlanner:
) )
if interest_updates: if interest_updates:
asyncio.create_task(self._commit_interest_updates(interest_updates)) task = asyncio.create_task(self._commit_interest_updates(interest_updates))
self._background_tasks.add(task)
task.add_done_callback(self._handle_task_result)
# 检查兴趣度是否达到非回复动作阈值 # 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
@@ -266,6 +269,17 @@ class ChatterActionPlanner:
return final_actions_dict, final_target_message_dict return final_actions_dict, final_target_message_dict
def _handle_task_result(self, task: asyncio.Task) -> None:
"""处理后台任务的结果,记录异常。"""
try:
task.result()
except asyncio.CancelledError:
pass # 任务被取消是正常现象
except Exception as e:
logger.error(f"后台任务执行失败: {e}", exc_info=True)
finally:
self._background_tasks.discard(task)
def get_planner_stats(self) -> dict[str, Any]: def get_planner_stats(self) -> dict[str, Any]:
"""获取规划器统计""" """获取规划器统计"""
return self.planner_stats.copy() return self.planner_stats.copy()

View File

@@ -15,7 +15,7 @@ logger = get_logger(__name__)
@register_plugin @register_plugin
class ProactiveThinkerPlugin(BasePlugin): class ProactiveThinkerPlugin(BasePlugin):
"""一个主动思考的插件,但现在还只是个空壳子""" """一个主动思考的插件"""
plugin_name: str = "proactive_thinker" plugin_name: str = "proactive_thinker"
enable_plugin: bool = True enable_plugin: bool = True

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from maim_message import UserInfo from maim_message import UserInfo
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -13,7 +14,6 @@ from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system import BaseEventHandler, EventType from src.plugin_system import BaseEventHandler, EventType
from src.plugin_system.apis import chat_api, message_api, person_api from src.plugin_system.apis import chat_api, message_api, person_api
from src.plugin_system.base.base_event import HandlerResult from src.plugin_system.base.base_event import HandlerResult
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
from .proactive_thinker_executor import ProactiveThinkerExecutor from .proactive_thinker_executor import ProactiveThinkerExecutor

View File

@@ -3,7 +3,7 @@ Base search engine interface
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, Optional
class BaseSearchEngine(ABC): class BaseSearchEngine(ABC):
@@ -24,6 +24,12 @@ class BaseSearchEngine(ABC):
""" """
pass pass
async def read_url(self, url: str) -> Optional[str]:
"""
读取URL内容如果引擎不支持则返回None
"""
return None
@abstractmethod @abstractmethod
def is_available(self) -> bool: def is_available(self) -> bool:
""" """

View File

@@ -0,0 +1,107 @@
"""
Metaso Search Engine (Chat Completions Mode)
"""
import json
from typing import Any, List
import httpx
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
from ..utils.api_key_manager import create_api_key_manager_from_config
from .base import BaseSearchEngine
logger = get_logger(__name__)
class MetasoClient:
"""A client to interact with the Metaso API."""
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://metaso.cn/api/v1"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
"Content-Type": "application/json",
}
async def search(self, query: str, **kwargs) -> List[dict[str, Any]]:
"""Perform a search using the Metaso Chat Completions API."""
payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]}
search_url = f"{self.base_url}/chat/completions"
full_response_content = ""
async with httpx.AsyncClient(timeout=90.0) as client:
try:
async with client.stream("POST", search_url, headers=self.headers, json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:"):
data_str = line[len("data:") :].strip()
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
delta = data.get("choices", [{}])[0].get("delta", {})
content_chunk = delta.get("content")
if content_chunk:
full_response_content += content_chunk
except json.JSONDecodeError:
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
continue
if not full_response_content:
logger.warning("Metaso search returned an empty stream.")
return []
return [
{
"title": query,
"url": "https://metaso.cn/",
"snippet": full_response_content,
"provider": "Metaso (Chat)",
}
]
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error occurred while searching with Metaso Chat: {e.response.text}")
return []
except Exception as e:
logger.error(f"An error occurred while searching with Metaso Chat: {e}", exc_info=True)
return []
class MetasoSearchEngine(BaseSearchEngine):
"""Metaso Search Engine implementation."""
def __init__(self):
self._initialize_clients()
def _initialize_clients(self):
"""Initialize Metaso clients."""
metaso_api_keys = config_api.get_global_config("web_search.metaso_api_keys", None)
self.api_manager = create_api_key_manager_from_config(
metaso_api_keys, lambda key: MetasoClient(api_key=key), "Metaso"
)
def is_available(self) -> bool:
"""Check if the Metaso search engine is available."""
return self.api_manager.is_available()
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""Execute a Metaso search."""
if not self.is_available():
return []
query = args["query"]
try:
metaso_client = self.api_manager.get_next_client()
if not metaso_client:
logger.error("Could not get Metaso client.")
return []
return await metaso_client.search(query)
except Exception as e:
logger.error(f"Metaso search failed: {e}", exc_info=True)
return []

View File

@@ -22,6 +22,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
提供网络搜索和URL解析功能支持多种搜索引擎 提供网络搜索和URL解析功能支持多种搜索引擎
- Exa (需要API密钥) - Exa (需要API密钥)
- Tavily (需要API密钥) - Tavily (需要API密钥)
- Metaso (需要API密钥)
- DuckDuckGo (免费) - DuckDuckGo (免费)
- Bing (免费) - Bing (免费)
""" """
@@ -43,6 +44,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.exa_engine import ExaSearchEngine from .engines.exa_engine import ExaSearchEngine
from .engines.searxng_engine import SearXNGSearchEngine from .engines.searxng_engine import SearXNGSearchEngine
from .engines.tavily_engine import TavilySearchEngine from .engines.tavily_engine import TavilySearchEngine
from .engines.metaso_engine import MetasoSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化 # 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine() exa_engine = ExaSearchEngine()
@@ -50,14 +52,16 @@ class WEBSEARCHPLUGIN(BasePlugin):
ddg_engine = DDGSearchEngine() ddg_engine = DDGSearchEngine()
bing_engine = BingSearchEngine() bing_engine = BingSearchEngine()
searxng_engine = SearXNGSearchEngine() searxng_engine = SearXNGSearchEngine()
metaso_engine = MetasoSearchEngine()
# 报告每个引擎的状态
# 报告每个引擎的状态
engines_status = { engines_status = {
"Exa": exa_engine.is_available(), "Exa": exa_engine.is_available(),
"Tavily": tavily_engine.is_available(), "Tavily": tavily_engine.is_available(),
"DuckDuckGo": ddg_engine.is_available(), "DuckDuckGo": ddg_engine.is_available(),
"Bing": bing_engine.is_available(), "Bing": bing_engine.is_available(),
"SearXNG": searxng_engine.is_available(), "SearXNG": searxng_engine.is_available(),
"Metaso": metaso_engine.is_available(),
} }
available_engines = [name for name, available in engines_status.items() if available] available_engines = [name for name, available in engines_status.items() if available]

View File

@@ -15,6 +15,7 @@ from ..engines.ddg_engine import DDGSearchEngine
from ..engines.exa_engine import ExaSearchEngine from ..engines.exa_engine import ExaSearchEngine
from ..engines.searxng_engine import SearXNGSearchEngine from ..engines.searxng_engine import SearXNGSearchEngine
from ..engines.tavily_engine import TavilySearchEngine from ..engines.tavily_engine import TavilySearchEngine
from ..engines.metaso_engine import MetasoSearchEngine
from ..utils.formatters import deduplicate_results, format_search_results from ..utils.formatters import deduplicate_results, format_search_results
logger = get_logger("web_search_tool") logger = get_logger("web_search_tool")
@@ -51,6 +52,7 @@ class WebSurfingTool(BaseTool):
"ddg": DDGSearchEngine(), "ddg": DDGSearchEngine(),
"bing": BingSearchEngine(), "bing": BingSearchEngine(),
"searxng": SearXNGSearchEngine(), "searxng": SearXNGSearchEngine(),
"metaso": MetasoSearchEngine(),
} }
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:

View File

@@ -1,5 +1,5 @@
[inner] [inner]
version = "7.3.2" version = "7.3.3"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读---- #----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值 #如果你想要修改配置文件请递增version的值
@@ -326,6 +326,7 @@ info_extraction_workers = 3 # 实体提取同时执行线程数非Pro模型
qa_relation_search_top_k = 10 # 关系搜索TopK qa_relation_search_top_k = 10 # 关系搜索TopK
qa_relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系) qa_relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系)
qa_paragraph_search_top_k = 1000 # 段落搜索TopK不能过小可能影响搜索结果 qa_paragraph_search_top_k = 1000 # 段落搜索TopK不能过小可能影响搜索结果
qa_paragraph_threshold = 0.4 # 段落阈值(相似度高于此阈值的段落才会被认为是相关的)
qa_paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重当搜索仅使用DPR时此参数不起作用 qa_paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重当搜索仅使用DPR时此参数不起作用
qa_ent_filter_top_k = 10 # 实体过滤TopK qa_ent_filter_top_k = 10 # 实体过滤TopK
qa_ppr_damping = 0.8 # PPR阻尼系数 qa_ppr_damping = 0.8 # PPR阻尼系数
@@ -473,11 +474,12 @@ enable_web_search_tool = true # 是否启用联网搜索tool
enable_url_tool = true # 是否启用URL解析tool enable_url_tool = true # 是否启用URL解析tool
tavily_api_keys = ["None"]# Tavily API密钥列表支持轮询机制 tavily_api_keys = ["None"]# Tavily API密钥列表支持轮询机制
exa_api_keys = ["None"]# EXA API密钥列表支持轮询机制 exa_api_keys = ["None"]# EXA API密钥列表支持轮询机制
metaso_api_keys = ["None"]# Metaso API密钥列表支持轮询机制
searxng_instances = [] # SearXNG 实例 URL 列表 searxng_instances = [] # SearXNG 实例 URL 列表
searxng_api_keys = []# SearXNG 实例 API 密钥列表 searxng_api_keys = []# SearXNG 实例 API 密钥列表
# 搜索引擎配置 # 搜索引擎配置
enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg","bing" enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg","bing", "metaso"
search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个) search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个)
[sleep_system] [sleep_system]

View File

@@ -107,7 +107,7 @@ class UILogHandler(logging.Handler):
# if not success: # if not success:
# print(f"[UI日志适配器] 日志发送失败: {ui_level} - {formatted_msg[:50]}...") # print(f"[UI日志适配器] 日志发送失败: {ui_level} - {formatted_msg[:50]}...")
except Exception as e: except Exception:
# 静默失败,不影响主程序 # 静默失败,不影响主程序
pass pass

2006
uv.lock generated

File diff suppressed because it is too large Load Diff