Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
156
.github/workflows/docker-image.yml
vendored
156
.github/workflows/docker-image.yml
vendored
@@ -1,70 +1,51 @@
|
||||
name: Docker CI
|
||||
name: Docker Build and Push
|
||||
|
||||
on:
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
# - develop
|
||||
# tags:
|
||||
# - "v*.*.*"
|
||||
# - "v*"
|
||||
# - "*.*.*"
|
||||
# - "*.*.*-*"
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
- "v*"
|
||||
- "*.*.*"
|
||||
- "*.*.*-*"
|
||||
workflow_dispatch: # 允许手动触发工作流
|
||||
|
||||
# Workflow's jobs
|
||||
jobs:
|
||||
build-amd64:
|
||||
name: 构建 AMD64 镜像
|
||||
name: Build AMD64 Image
|
||||
runs-on: ubuntu-24.04
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
steps:
|
||||
- name: 检出 Git 仓库
|
||||
- name: Check out git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: 克隆 maim_message
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: MaiM-with-u/maim_message
|
||||
path: maim_message
|
||||
|
||||
- name: 克隆 MaiMBot-LPMM
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: MaiM-with-u/MaiMBot-LPMM
|
||||
path: MaiMBot-LPMM
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: --debug
|
||||
|
||||
- name: 登录到 Docker Hub
|
||||
# Log in docker hub
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Docker 元数据
|
||||
# Generate metadata for Docker images
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
||||
|
||||
- name: 动态生成镜像标签
|
||||
id: tag
|
||||
run: |
|
||||
if [ "$GITHUB_REF" == "refs/heads/master" ]; then
|
||||
echo "tag=latest" >> $GITHUB_ENV
|
||||
elif [ "$GITHUB_REF" == "refs/heads/develop" ]; then
|
||||
echo "tag=dev" >> $GITHUB_ENV
|
||||
else
|
||||
echo "tag=${{ github.ref_name }}" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: 构建并推送 AMD64 镜像
|
||||
# Build and push AMD64 image by digest
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
@@ -72,10 +53,97 @@ jobs:
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
file: ./Dockerfile
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:amd64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/maibot:amd64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/maibot:${{ env.tag }},name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
|
||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
|
||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
|
||||
build-args: |
|
||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||
VCS_REF=${{ github.sha }}
|
||||
BRANCH_NAME=${{ github.ref_name }}
|
||||
|
||||
build-arm64:
  name: Build ARM64 Image
  runs-on: ubuntu-24.04-arm
  outputs:
    # Digest of the pushed single-arch image; consumed by the manifest job.
    digest: ${{ steps.build.outputs.digest }}
  steps:
    - name: Check out git repository
      uses: actions/checkout@v4
      with:
        fetch-depth: 0

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v3
      with:
        buildkitd-flags: --debug

    # Log in to Docker Hub
    - name: Log in to Docker Hub
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKERHUB_USERNAME }}
        password: ${{ secrets.DOCKERHUB_TOKEN }}

    # Generate metadata (labels) for the Docker image
    - name: Docker meta
      id: meta
      uses: docker/metadata-action@v5
      with:
        images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox

    # Action inputs are plain strings and are never shell-evaluated, so a
    # "$(date ...)" placed directly in build-args would be passed literally.
    # Compute the timestamp in a real shell step and expose it as an output.
    - name: Compute build date
      id: build_date
      run: echo "value=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> "$GITHUB_OUTPUT"

    # Build and push the ARM64 image by digest (tags are added by the manifest job)
    - name: Build and push ARM64
      id: build
      uses: docker/build-push-action@v5
      with:
        context: .
        platforms: linux/arm64/v8
        labels: ${{ steps.meta.outputs.labels }}
        file: ./Dockerfile
        cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
        cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
        outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
        build-args: |
          BUILD_DATE=${{ steps.build_date.outputs.value }}
          VCS_REF=${{ github.sha }}
|
||||
|
||||
create-manifest:
  name: Create Multi-Arch Manifest
  runs-on: ubuntu-24.04
  needs:
    - build-amd64
    - build-arm64
  steps:
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v3

    # Log in to Docker Hub
    - name: Log in to Docker Hub
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKERHUB_USERNAME }}
        password: ${{ secrets.DOCKERHUB_TOKEN }}

    # Generate the list of tags the multi-arch manifest is published under
    - name: Docker meta
      id: meta
      uses: docker/metadata-action@v5
      with:
        images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
        # "latest" follows the master branch: this workflow is triggered by
        # pushes to master and dev only, so a check against refs/heads/main
        # would never be true and "latest" would never be published.
        tags: |
          type=ref,event=branch
          type=ref,event=tag
          type=raw,value=latest,enable=${{ github.ref == 'refs/heads/master' }}
          type=semver,pattern={{version}}
          type=semver,pattern={{major}}.{{minor}}
          type=semver,pattern={{major}}
          type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}

    - name: Create and Push Manifest
      # Hand expression values to the script through the environment instead
      # of interpolating them into the shell source.
      env:
        TAGS: ${{ steps.meta.outputs.tags }}
        IMAGE: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
        AMD64_DIGEST: ${{ needs.build-amd64.outputs.digest }}
        ARM64_DIGEST: ${{ needs.build-arm64.outputs.digest }}
      run: |
        # Create one multi-arch manifest per generated tag.
        # $TAGS is intentionally unquoted: it is a newline-separated list.
        for tag in $TAGS; do
          echo "Creating manifest for $tag"
          docker buildx imagetools create -t "$tag" \
            "$IMAGE@$AMD64_DIGEST" \
            "$IMAGE@$ARM64_DIGEST"
        done
|
||||
20
Dockerfile
20
Dockerfile
@@ -2,31 +2,19 @@ FROM python:3.13.5-slim-bookworm
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# 工作目录
|
||||
WORKDIR /mmc
|
||||
WORKDIR /app
|
||||
|
||||
# 复制依赖列表
|
||||
COPY requirements.txt .
|
||||
# 同级目录下需要有 maim_message MaiMBot-LPMM
|
||||
#COPY maim_message /maim_message
|
||||
COPY MaiMBot-LPMM /MaiMBot-LPMM
|
||||
COPY pyproject.toml .
|
||||
|
||||
# 编译器
|
||||
RUN apt-get update && apt-get install -y build-essential
|
||||
|
||||
# lpmm编译安装
|
||||
RUN cd /MaiMBot-LPMM && uv pip install --system -r requirements.txt
|
||||
RUN uv pip install --system Cython py-cpuinfo setuptools
|
||||
RUN cd /MaiMBot-LPMM/lib/quick_algo && python build_lib.py --cleanup --cythonize --install
|
||||
|
||||
|
||||
# 安装依赖
|
||||
RUN uv pip install --system --upgrade pip
|
||||
#RUN uv pip install --system -e /maim_message
|
||||
RUN uv pip install --system -r requirements.txt
|
||||
|
||||
# 复制项目代码
|
||||
RUN uv sync
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT [ "python","bot.py" ]
|
||||
ENTRYPOINT [ "uv","run","bot.py" ]
|
||||
@@ -38,12 +38,12 @@
|
||||
|
||||
**MoFox_Bot** 是一个基于 [MaiCore](https://github.com/MaiM-with-u/MaiBot) `0.10.0 snapshot.5` 的增强型 fork 项目。我们保留了原项目几乎所有核心功能,并在此基础上进行了深度优化与功能扩展,致力于打造一个**更稳定、更智能、更具趣味性**的 AI 智能体。
|
||||
|
||||
> [!IMPORTANT]
|
||||
> [IMPORTANT]
|
||||
> **第三方项目声明**
|
||||
>
|
||||
> 本项目由 **MoFox Studio** 独立维护,为 **MaiBot 的第三方分支**,并非官方版本。所有更新与支持均由我们团队负责,与 MaiBot 官方无直接关系。
|
||||
|
||||
> [!WARNING]
|
||||
> [WARNING]
|
||||
> **迁移风险提示**
|
||||
>
|
||||
> 由于我们对数据库结构进行了重构与优化,从官方 MaiBot 直接迁移至 MoFox_Bot **可能导致数据不兼容**。请在迁移前**务必备份原始数据**,以避免信息丢失。
|
||||
@@ -63,8 +63,6 @@
|
||||
<td width="50%">
|
||||
|
||||
### 🔧 原版功能(全部保留)
|
||||
|
||||
- 🧠 **智能对话系统** - 基于 LLM 的自然语言交互,支持 normal 和 focus 统一化处理
|
||||
- 🔌 **强大插件系统** - 全面重构的插件架构,支持完整的管理 API 和权限控制
|
||||
- 💭 **实时思维系统** - 模拟人类思考过程
|
||||
- 📚 **表达学习功能** - 学习群友的说话风格和表达方式
|
||||
@@ -78,6 +76,7 @@
|
||||
|
||||
### 🚀 拓展功能
|
||||
|
||||
- 🧠 **AFC 智能对话** - 基于亲和力流,实现兴趣感知和动态关系构建
|
||||
- 🔄 **数据库切换** - 支持 SQLite 与 MySQL 自由切换,采用 SQLAlchemy 2.0 重新构建
|
||||
- 🛡️ **反注入集成** - 内置一整套回复前注入过滤系统,为人格保驾护航
|
||||
- 🎥 **视频分析** - 支持多种视频识别模式,拓展原版视觉
|
||||
|
||||
@@ -1,47 +1,22 @@
|
||||
services:
|
||||
adapters:
|
||||
container_name: maim-bot-adapters
|
||||
#### prod ####
|
||||
image: unclas/maimbot-adapter:latest
|
||||
# image: infinitycat/maimbot-adapter:latest
|
||||
#### dev ####
|
||||
# image: unclas/maimbot-adapter:dev
|
||||
# image: infinitycat/maimbot-adapter:dev
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
# ports:
|
||||
# - "8095:8095"
|
||||
volumes:
|
||||
- ./docker-config/adapters/config.toml:/adapters/config.toml # 持久化adapters配置文件
|
||||
- ./data/adapters:/adapters/data # adapters 数据持久化
|
||||
restart: always
|
||||
networks:
|
||||
- maim_bot
|
||||
core:
|
||||
container_name: maim-bot-core
|
||||
container_name: MoFox-Bot
|
||||
#### prod ####
|
||||
image: sengokucola/maibot:latest
|
||||
# image: infinitycat/maibot:latest
|
||||
image: hunuon/mofox:latest
|
||||
#### dev ####
|
||||
# image: sengokucola/maibot:dev
|
||||
# image: infinitycat/maibot:dev
|
||||
# image: hunuon/mofox:dev
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
# - EULA_AGREE=99f08e0cab0190de853cb6af7d64d4de # 同意EULA
|
||||
# - PRIVACY_AGREE=9943b855e72199d0f5016ea39052f1b6 # 同意EULA
|
||||
# ports:
|
||||
# - "8000:8000"
|
||||
volumes:
|
||||
- ./docker-config/mmc/.env:/MaiMBot/.env # 持久化env配置文件
|
||||
- ./docker-config/mmc:/MaiMBot/config # 持久化bot配置文件
|
||||
- ./data/MaiMBot/maibot_statistics.html:/MaiMBot/maibot_statistics.html #统计数据输出
|
||||
- ./data/MaiMBot:/MaiMBot/data # 共享目录
|
||||
- ./data/MaiMBot/plugins:/MaiMBot/plugins # 插件目录
|
||||
- ./data/MaiMBot/logs:/MaiMBot/logs # 日志目录
|
||||
- site-packages:/usr/local/lib/python3.13/site-packages # 持久化Python包
|
||||
- ./docker-config/core/.env:/app/.env # 持久化env配置文件
|
||||
- ./docker-config/core:/app/config # 持久化bot配置文件
|
||||
- ./data/core/maibot_statistics.html:/app/maibot_statistics.html #统计数据输出
|
||||
- ./data/app:/app/data # 共享目录
|
||||
- ./data/core/plugins:/app/plugins # 插件目录
|
||||
- ./data/core/logs:/app/logs # 日志目录
|
||||
restart: always
|
||||
networks:
|
||||
- maim_bot
|
||||
- mofox
|
||||
napcat:
|
||||
environment:
|
||||
- NAPCAT_UID=1000
|
||||
@@ -52,25 +27,12 @@ services:
|
||||
volumes:
|
||||
- ./docker-config/napcat:/app/napcat/config # 持久化napcat配置文件
|
||||
- ./data/qq:/app/.config/QQ # 持久化QQ本体
|
||||
- ./data/MaiMBot:/MaiMBot/data # 共享目录
|
||||
container_name: maim-bot-napcat
|
||||
- ./data/app:/app/data # 共享目录
|
||||
container_name: mofox-napcat
|
||||
restart: always
|
||||
image: mlikiowa/napcat-docker:latest
|
||||
networks:
|
||||
- maim_bot
|
||||
sqlite-web:
|
||||
# 注意:coleifer/sqlite-web 镜像不支持arm64
|
||||
image: coleifer/sqlite-web
|
||||
container_name: sqlite-web
|
||||
restart: always
|
||||
ports:
|
||||
- "8120:8080"
|
||||
volumes:
|
||||
- ./data/MaiMBot:/data/MaiMBot
|
||||
environment:
|
||||
- SQLITE_DATABASE=MaiMBot/MaiBot.db # 你的数据库文件
|
||||
networks:
|
||||
- maim_bot
|
||||
- mofox
|
||||
|
||||
# chat2db占用相对较高但是功能强大
|
||||
# 内存占用约600m,内存充足推荐选此
|
||||
@@ -81,11 +43,11 @@ services:
|
||||
# ports:
|
||||
# - "10824:10824"
|
||||
# volumes:
|
||||
# - ./data/MaiMBot:/data/MaiMBot
|
||||
# - ./data/chat2db:/data/app
|
||||
# networks:
|
||||
# - maim_bot
|
||||
# - mofox
|
||||
volumes:
|
||||
site-packages:
|
||||
networks:
|
||||
maim_bot:
|
||||
mofox:
|
||||
driver: bridge
|
||||
|
||||
@@ -3,10 +3,11 @@ import random
|
||||
from typing import Any
|
||||
|
||||
from src.plugin_system import (
|
||||
ActionActivationType,
|
||||
BaseAction,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BasePrompt,
|
||||
ToolParamType,
|
||||
BaseTool,
|
||||
ChatType,
|
||||
CommandArgs,
|
||||
@@ -37,7 +38,17 @@ class GetSystemInfoTool(BaseTool):
|
||||
name = "get_system_info"
|
||||
description = "获取当前系统的模拟版本和状态信息。"
|
||||
available_for_llm = True
|
||||
parameters = []
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
(
|
||||
"time_range",
|
||||
ToolParamType.STRING,
|
||||
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
|
||||
@@ -100,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction):
|
||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||
"""LLM 判断激活:判断用户是否情绪低落"""
|
||||
return await self._llm_judge_activation(
|
||||
chat_content=chat_content,
|
||||
judge_prompt="""
|
||||
判断用户是否表达了以下情绪或需求:
|
||||
1. 感到难过、沮丧或失落
|
||||
@@ -133,11 +143,11 @@ class CombinedActivationExampleAction(BaseAction):
|
||||
# 先尝试随机激活
|
||||
if await self._random_activation(0.2):
|
||||
return True
|
||||
|
||||
|
||||
# 如果随机未激活,尝试关键词匹配
|
||||
if await self._keyword_match(chat_content, ["表情", "emoji", "😊"], case_sensitive=False):
|
||||
return True
|
||||
|
||||
|
||||
# 都不满足则不激活
|
||||
return False
|
||||
|
||||
@@ -170,6 +180,19 @@ class RandomEmojiAction(BaseAction):
|
||||
return True, "成功发送了一个随机表情"
|
||||
|
||||
|
||||
class WeatherPrompt(BasePrompt):
    """Minimal example Prompt component that injects weather information into the Planner."""

    # Registration metadata read by the plugin system.
    prompt_name = "weather_info_prompt"
    prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。"
    injection_point = "planner_prompt"

    async def execute(self) -> str:
        """Return the text to inject.

        A real implementation could query a weather API here; for
        demonstration purposes a fixed weather report is returned.
        """
        weather_report = "当前天气:晴朗,温度25°C。"
        return weather_report
|
||||
|
||||
|
||||
@register_plugin
|
||||
class HelloWorldPlugin(BasePlugin):
|
||||
"""一个包含四大核心组件和高级配置功能的入门示例插件。"""
|
||||
@@ -179,7 +202,6 @@ class HelloWorldPlugin(BasePlugin):
|
||||
dependencies = []
|
||||
python_dependencies = []
|
||||
config_file_name = "config.toml"
|
||||
enable_plugin = False
|
||||
|
||||
config_schema = {
|
||||
"meta": {
|
||||
@@ -209,4 +231,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
if self.get_config("components.random_emoji_action_enabled", True):
|
||||
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
||||
|
||||
# 注册新的Prompt组件
|
||||
components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt))
|
||||
|
||||
return components
|
||||
|
||||
@@ -2,17 +2,16 @@
|
||||
name = "MoFox-Bot"
|
||||
version = "0.8.1"
|
||||
description = "MoFox-Bot 是一个基于大语言模型的可交互智能体"
|
||||
requires-python = ">=3.11"
|
||||
requires-python = ">=3.11,<=3.13"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.14",
|
||||
"aiohttp-cors>=0.8.1",
|
||||
"apscheduler>=3.11.0",
|
||||
"asyncddgs>=0.1.0a1",
|
||||
"asyncio>=4.0.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"chromadb>=0.5.0",
|
||||
"chromadb>=1.2.0",
|
||||
"colorama>=0.4.6",
|
||||
"cryptography>=45.0.5",
|
||||
"cryptography>=46.0.3",
|
||||
"customtkinter>=5.2.2",
|
||||
"dotenv>=0.9.9",
|
||||
"exa-py>=1.14.20",
|
||||
@@ -21,11 +20,10 @@ dependencies = [
|
||||
"google>=3.0.0",
|
||||
"google-genai>=1.29.0",
|
||||
"httpx>=0.28.1",
|
||||
"jieba>=0.1.13",
|
||||
"json-repair>=0.47.6",
|
||||
"json5>=0.12.1",
|
||||
"jsonlines>=4.0.0",
|
||||
"langfuse==2.46.2",
|
||||
"langfuse==3.7.0",
|
||||
"lunar-python>=1.4.4",
|
||||
"lxml>=6.0.0",
|
||||
"maim-message>=0.3.8",
|
||||
@@ -33,16 +31,16 @@ dependencies = [
|
||||
"networkx>=3.4.2",
|
||||
"orjson>=3.10",
|
||||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"openai>=2.5.0",
|
||||
"opencv-python>=4.11.0.86",
|
||||
"packaging>=23.2",
|
||||
"packaging>=25.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"pillow>=12.0.0",
|
||||
"pip-check-reqs>=2.5.5",
|
||||
"psutil>=7.0.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pyarrow>=21.0.0",
|
||||
"pydantic>=2.12.3",
|
||||
"pygments>=2.19.2",
|
||||
"pymongo>=4.13.2",
|
||||
"pymysql>=1.1.1",
|
||||
@@ -76,8 +74,8 @@ dependencies = [
|
||||
"aiosqlite>=0.21.0",
|
||||
"inkfox>=0.1.1",
|
||||
"rjieba>=0.1.13",
|
||||
"mcp>=0.9.0",
|
||||
"sse-starlette>=2.2.1",
|
||||
"mcp>=1.18.0",
|
||||
"sse-starlette>=3.0.2",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
|
||||
106
scripts/convert_manifest.py
Normal file
106
scripts/convert_manifest.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
# 将脚本所在的目录添加到系统路径中,以便导入项目模块
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("convert_manifest")
|
||||
|
||||
def convert_and_copy_plugin(plugin_dir: Path, output_dir: Path):
    """
    Convert a plugin's ``_manifest.json`` into a ``__init__.py`` metadata module
    and copy the whole plugin directory to the output location.

    Steps:
      1. Read and parse ``plugin_dir/_manifest.json``.
      2. Copy ``plugin_dir`` to ``output_dir/<plugin name>`` (replacing any
         existing directory of the same name).
      3. Write ``__init__.py`` in the copy, defining ``__plugin_meta__``.
      4. Delete the copied ``_manifest.json``.

    Args:
        plugin_dir: Directory of a single plugin to convert.
        output_dir: Directory that receives the converted copy.

    Returns:
        None. Failures are logged, not raised; a directory without a manifest
        is skipped with a warning.
    """
    manifest_path = plugin_dir / "_manifest.json"
    if not manifest_path.is_file():
        logger.warning(f"在目录 '{plugin_dir.name}' 中未找到 '_manifest.json',已跳过。")
        return

    try:
        # 1. Parse the manifest BEFORE touching the output directory, so that a
        #    broken manifest does not leave a half-converted copy behind.
        with open(manifest_path, "rb") as f:
            manifest = orjson.loads(f.read())

        plugin_name = manifest.get("name", "Unknown Plugin")
        description = manifest.get("description", "No description provided.")
        version = manifest.get("version", "1.0.0")

        # "author" is normally an object with a "name" key; tolerate a plain
        # string or a null value instead of crashing on `.get`.
        author_field = manifest.get("author")
        if isinstance(author_field, dict):
            author = author_field.get("name", "Unknown Author")
        elif isinstance(author_field, str) and author_field:
            author = author_field
        else:
            author = "Unknown Author"

        license_type = manifest.get("license")
        repository_url = manifest.get("repository_url")
        keywords = manifest.get("keywords", [])
        categories = manifest.get("categories", [])

        plugin_info = manifest.get("plugin_info")
        plugin_type = plugin_info.get("plugin_type") if isinstance(plugin_info, dict) else None

        # Every manifest value is emitted through repr() so that quotes,
        # backslashes or newlines in the manifest cannot break (or inject code
        # into) the generated Python source.
        meta_template = (
            "from src.plugin_system.base.plugin_metadata import PluginMetadata\n"
            "\n"
            "__plugin_meta__ = PluginMetadata(\n"
            f"    name={plugin_name!r},\n"
            f"    description={description!r},\n"
            '    usage="暂无说明",\n'
            f"    type={plugin_type!r},\n"
            f"    version={version!r},\n"
            f"    author={author!r},\n"
            f"    license={license_type!r},\n"
            f"    repository_url={repository_url!r},\n"
            f"    keywords={keywords!r},\n"
            f"    categories={categories!r},\n"
            ")\n"
        )

        # 2. Copy the whole plugin directory, replacing a previous conversion.
        target_plugin_dir = output_dir / plugin_dir.name
        if target_plugin_dir.exists():
            shutil.rmtree(target_plugin_dir)
        shutil.copytree(plugin_dir, target_plugin_dir)
        logger.info(f"已将插件 '{plugin_dir.name}' 完整复制到 '{target_plugin_dir}'")

        # 3. Create or overwrite __init__.py inside the copy.
        output_init_path = target_plugin_dir / "__init__.py"
        with open(output_init_path, "w", encoding="utf-8") as f:
            f.write(meta_template)

        # 4. The manifest is superseded by __init__.py; remove it from the copy.
        copied_manifest_path = target_plugin_dir / "_manifest.json"
        if copied_manifest_path.is_file():
            copied_manifest_path.unlink()

        logger.info(f"成功为 '{plugin_dir.name}' 创建元数据文件并清理清单。")

    except FileNotFoundError:
        logger.error(f"错误: 在 '{manifest_path}' 未找到清单文件")
    except orjson.JSONDecodeError:
        logger.error(f"错误: 无法解析 '{manifest_path}' 的 JSON 内容")
    except Exception as e:
        # Best-effort batch tool: log and let the caller continue with the next plugin.
        logger.error(f"处理 '{plugin_dir.name}' 时发生意外错误: {e}")
|
||||
|
||||
def main():
    """
    Entry point: scan the "pending_plugins" directory next to this script and
    convert/copy every plugin sub-directory into "completed_plugins".

    If "pending_plugins" does not exist yet, it is created and the function
    returns so that plugin folders can be dropped in before the next run.
    """
    # Both directories are resolved relative to the script location, not the CWD.
    script_dir = Path(__file__).parent
    input_path = script_dir / "pending_plugins"
    output_path = script_dir / "completed_plugins"

    if not input_path.is_dir():
        logger.error(f"错误: 输入目录 '{input_path}' 不存在。")
        input_path.mkdir(parents=True, exist_ok=True)
        logger.info("请在新建的文件夹里面投入插件文件夹并重新启动脚本")
        return

    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"正在扫描 '{input_path}' 中的插件...")
    # Sort the entries so plugins are processed (and logged) in a stable order;
    # Path.iterdir() yields entries in arbitrary order.
    for item in sorted(input_path.iterdir()):
        if item.is_dir():
            logger.info(f"发现插件目录: '{item.name}',开始处理...")
            convert_and_copy_plugin(item, output_path)

    logger.info("所有插件处理完成。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,218 +0,0 @@
|
||||
"""批量将经典 SQLAlchemy 模型字段写法
|
||||
|
||||
field = Column(Integer, nullable=False, default=0)
|
||||
|
||||
转换为 2.0 推荐的带类型注解写法:
|
||||
|
||||
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
脚本特点:
|
||||
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
|
||||
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
|
||||
3. 已经是 Mapped 写法的行会跳过。
|
||||
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
|
||||
5. nullable=True 时自动添加 "| None"。
|
||||
6. 保留 Column(...) 内的原始参数顺序与内容。
|
||||
7. 生成 .bak 备份文件,确保可回滚。
|
||||
8. 支持 --dry-run 查看差异,不写回文件。
|
||||
|
||||
局限/注意:
|
||||
- 简单基于正则/括号计数,不解析完整 AST;非常规写法(例如变量中构造 Column 再赋值)不会处理。
|
||||
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
|
||||
- 不自动添加 from __future__ import annotations;如需 Python 3.10 以下更先进类型表达式,请自行处理。
|
||||
|
||||
使用方式: (在项目根目录执行)
|
||||
|
||||
python scripts/convert_sqlalchemy_models.py \
|
||||
--file src/common/database/sqlalchemy_models.py --dry-run
|
||||
|
||||
确认无误后将 --dry-run 换成 --write 真实写入。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
TYPE_MAP = {
|
||||
"Integer": "int",
|
||||
"Float": "float",
|
||||
"Boolean": "bool",
|
||||
"Text": "str",
|
||||
"String": "str",
|
||||
"DateTime": "datetime.datetime",
|
||||
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
|
||||
"get_string_field": "str",
|
||||
}
|
||||
|
||||
|
||||
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
|
||||
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
|
||||
|
||||
|
||||
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
    """Find the line range covered by the ``Column(...)`` statement at ``start_index``.

    Multi-line statements are followed by counting parentheses until they
    balance. Parentheses inside string literals and ``#`` comments are ignored,
    so a comment such as ``# legacy :)`` does not end the block early.

    Args:
        lines: All lines of the file.
        start_index: Index of the line expected to contain ``Column(``.

    Returns:
        ``(start, end)`` with ``end`` inclusive, or ``None`` when the start line
        does not contain ``Column(``. If the parentheses never balance, ``end``
        is the last line of the file.
    """
    line = lines[start_index]
    if "Column(" not in line:
        return None
    open_parens = _paren_delta(line)
    i = start_index
    while open_parens > 0 and i + 1 < len(lines):
        i += 1
        open_parens += _paren_delta(lines[i])
    return (start_index, i)


def _paren_delta(line: str) -> int:
    """Return (opening - closing) parentheses on one line, skipping strings and comments."""
    delta = 0
    quote = None  # the quote character of the string literal currently open, if any
    escaped = False
    for ch in line:
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch == "#":
            # The rest of the line is a comment.
            break
        if ch == "'" or ch == '"':
            quote = ch
        elif ch == "(":
            delta += 1
        elif ch == ")":
            delta -= 1
    return delta
|
||||
|
||||
|
||||
def extract_column_body(block_lines: list[str]) -> str:
    """Return the argument text inside the first ``Column(...)`` of a block.

    The closing parenthesis is located by matching it against the opening one
    (ignoring parentheses inside string literals and ``#`` comments), so a
    trailing comment such as ``# note (old)`` is not pulled into the result.

    Args:
        block_lines: The source lines making up one ``Column(...)`` statement.

    Returns:
        The stripped text between ``Column(`` and its matching ``)``; an empty
        string when the block contains no ``Column(``. If no matching ``)`` is
        found, everything up to the last ``)`` (or to the end) is returned.
    """
    joined = "\n".join(block_lines)
    marker = "Column("
    start_pos = joined.find(marker)
    if start_pos == -1:
        return ""
    inner = joined[start_pos + len(marker) :]

    depth = 1  # we are already inside the parenthesis opened by "Column("
    quote = None  # quote character of the string literal currently open, if any
    escaped = False
    in_comment = False
    for pos, ch in enumerate(inner):
        if in_comment:
            if ch == "\n":
                in_comment = False
            continue
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch == "#":
            in_comment = True
        elif ch == "'" or ch == '"':
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return inner[:pos].strip()

    # Unbalanced input: fall back to cutting at the last ")" if there is one.
    last_paren = inner.rfind(")")
    if last_paren != -1:
        inner = inner[:last_paren]
    return inner.strip()
|
||||
|
||||
|
||||
def guess_python_type(column_body: str) -> str:
    """Infer the Python annotation for a column from its ``Column(...)`` arguments.

    The first identifier in ``column_body`` (e.g. ``Integer``, ``Text``,
    ``get_string_field``) is looked up in ``TYPE_MAP``; unknown identifiers map
    to ``Any``. When the arguments contain ``nullable=True`` (with any spacing
    around ``=``), ``" | None"`` is appended.

    Args:
        column_body: Argument text as returned by ``extract_column_body``.

    Returns:
        The annotation text, e.g. ``"int"``, ``"str | None"`` or ``"Any"``.
    """
    # NOTE(review): only an explicit nullable=True yields "| None". A Column
    # without a nullable= argument is nullable by default in SQLAlchemy (unless
    # it is a primary key) but is annotated as non-optional here - confirm this
    # is the intended mapping.
    m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
    if not m:
        return "Any"
    py_type = TYPE_MAP.get(m.group(1), "Any")
    # Tolerate "nullable = True", "nullable= True", etc.
    if re.search(r"\bnullable\s*=\s*True\b", column_body) and not py_type.endswith(" | None"):
        py_type = f"{py_type} | None"
    return py_type
|
||||
|
||||
|
||||
def convert_block(block_lines: list[str]) -> list[str]:
    """Rewrite one ``name = Column(...)`` block as ``name: Mapped[T] = mapped_column(...)``.

    Args:
        block_lines: The source lines of a single Column assignment.

    Returns:
        The replacement lines, each ending in a newline. The input is returned
        unchanged when the first line is not an indented ``name =`` assignment.
        A single-line argument list produces a single output line; otherwise
        each non-blank argument line is emitted on its own line, re-indented
        relative to the assignment.
    """
    first_line = block_lines[0]
    m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
    if not m_name:
        return block_lines
    indent = m_name.group("indent")
    name = m_name.group("name")
    body = extract_column_body(block_lines)
    py_type = guess_python_type(body)

    inner_lines = body.split("\n")
    if len(inner_lines) == 1:
        return [f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n"]

    # Multi-line case. The first argument line has already lost its leading
    # whitespace (the body is stripped), while the following lines still carry
    # their original indentation. Remove the indentation those following lines
    # share, so every argument starts at `indent + 4 spaces` and only nesting
    # relative to that is kept.
    ind2 = indent + "    "
    tail = [il.rstrip() for il in inner_lines[1:] if il.strip()]
    common = min((len(il) - len(il.lstrip()) for il in tail), default=0)

    rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(\n"]
    head = inner_lines[0].strip()
    if head:
        rebuilt.append(f"{ind2}{head}\n")
    rebuilt.extend(f"{ind2}{il[common:]}\n" for il in tail)
    rebuilt.append(f"{indent})\n")
    return rebuilt
|
||||
|
||||
|
||||
def ensure_imports(content: str) -> str:
    """Make sure ``Mapped`` / ``mapped_column`` are imported when the content uses them.

    If ``content`` mentions ``Mapped`` and does not already contain
    ``from sqlalchemy.orm import Mapped, mapped_column``, that line is inserted
    directly after the first ``from sqlalchemy ...`` import statement. A
    parenthesised multi-line import is skipped over as a whole, so the new line
    is never placed inside it.

    Args:
        content: Full text of the (converted) module.

    Returns:
        The content with the import added. The content is returned unchanged
        when the import is not needed, already present, or when there is no
        ``from sqlalchemy`` line to anchor the insertion.
    """
    import_line = "from sqlalchemy.orm import Mapped, mapped_column"
    uses_mapped = "Mapped," in content or "Mapped[" in content
    if not uses_mapped or import_line in content:
        return content

    lines = content.splitlines()
    for i, line in enumerate(lines):
        if not line.startswith("from sqlalchemy"):
            continue
        # Walk to the last line of this import statement: a
        # `from sqlalchemy import (` block continues until its parentheses close.
        end = i
        depth = line.count("(") - line.count(")")
        while depth > 0 and end + 1 < len(lines):
            end += 1
            depth += lines[end].count("(") - lines[end].count(")")
        lines.insert(end + 1, import_line)
        result = "\n".join(lines)
        # splitlines()/join drops the final newline; restore it if there was one.
        if content.endswith("\n"):
            result += "\n"
        return result
    return content
|
||||
|
||||
|
||||
def process_file(path: Path) -> tuple[str, str]:
    """Convert every classic ``name = Column(...)`` assignment in a file.

    Lines already written in ``name: Mapped[...]`` style are copied through.
    Column assignments are recognised with the module-level
    ``COLUMN_ASSIGN_RE``, so spacing variants such as ``x=Column(`` are handled
    as well as ``x = Column(``.

    Args:
        path: The model file to read (UTF-8). The file is not modified here.

    Returns:
        ``(original, converted)``. When nothing was converted, both elements
        are the original text.
    """
    original = path.read_text(encoding="utf-8")
    lines = original.splitlines(keepends=True)
    out: list[str] = []
    changed = False
    i = 0
    while i < len(lines):
        line = lines[i]
        if ALREADY_MAPPED_RE.match(line) or not COLUMN_ASSIGN_RE.match(line):
            # Not a classic Column assignment: copy the line through untouched.
            out.append(line)
            i += 1
            continue
        # A Column(...) statement may span several lines; convert it as a unit.
        start, end = detect_column_block(lines, i) or (i, i)
        block = lines[start : end + 1]
        converted = convert_block(block)
        out.extend(converted)
        if "".join(converted) != "".join(block):
            changed = True
        i = end + 1

    if not changed:
        return original, original
    return original, ensure_imports("".join(out))
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Without ``--write`` (the default, or with an explicit ``--dry-run``) a
    unified diff of the proposed conversion is printed and the file is left
    untouched. With ``--write`` a ``.bak`` copy of the target is created and
    the converted content is written back.

    ``--dry-run`` and ``--write`` are mutually exclusive; passing both is
    rejected by the argument parser, so a dry run can never write the file.

    Raises:
        SystemExit: When the target file does not exist.
    """
    parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
    parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
    mode.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
    args = parser.parse_args()

    target = Path(args.file)
    if not target.exists():
        raise SystemExit(f"文件不存在: {target}")

    original, new_content = process_file(target)

    if original == new_content:
        print("[INFO] 没有需要转换的内容或转换后无差异。")
        return

    if not args.write:
        # Preview mode: show a line-based unified diff and stop.
        print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
        import difflib

        diff = difflib.unified_diff(
            original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
        )
        count = 0
        for d in diff:
            print(d)
            count += 1
        if count == 0:
            print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
        print("\n未写回。若确认无误,添加 --write 执行替换。")
        return

    # Keep a backup next to the target (e.g. models.py -> models.py.bak) before overwriting.
    backup = target.with_suffix(target.suffix + ".bak")
    shutil.copyfile(target, backup)
    target.write_text(new_content, encoding="utf-8")
    print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
main()
|
||||
@@ -5,6 +5,7 @@ import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
@@ -191,43 +192,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
return None, pg_hash
|
||||
|
||||
|
||||
async def extract_information(paragraphs_dict, model_set):
|
||||
def extract_info_sync(pg_hash, paragraph, model_set):
    """Synchronous wrapper around ``extract_info_async``.

    Builds an ``LLMRequest`` from ``model_set``, then runs the coroutine to
    completion with ``asyncio.run`` and returns its result.
    """
    request = LLMRequest(model_set=model_set)
    coroutine = extract_info_async(pg_hash, paragraph, request)
    return asyncio.run(coroutine)
|
||||
|
||||
|
||||
def extract_information(paragraphs_dict, model_set):
|
||||
logger.info("--- 步骤 2: 开始信息提取 ---")
|
||||
os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
llm_api = LLMRequest(model_set=model_set)
|
||||
failed_hashes, open_ie_docs = [], []
|
||||
|
||||
tasks = [
|
||||
extract_info_async(p_hash, p, llm_api)
|
||||
for p_hash, p in paragraphs_dict.items()
|
||||
]
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks))
|
||||
for future in asyncio.as_completed(tasks):
|
||||
doc_item, failed_hash = await future
|
||||
if failed_hash:
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
progress.update(prog_task, advance=1)
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
f_to_hash = {
|
||||
executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
|
||||
for p_hash, p in paragraphs_dict.items()
|
||||
}
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
|
||||
for future in as_completed(f_to_hash):
|
||||
doc_item, failed_hash = future.result()
|
||||
if failed_hash:
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
progress.update(task, advance=1)
|
||||
|
||||
if open_ie_docs:
|
||||
all_entities = [
|
||||
e for doc in open_ie_docs for e in doc["extracted_entities"]
|
||||
]
|
||||
all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
|
||||
num_entities = len(all_entities)
|
||||
avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
|
||||
avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
|
||||
@@ -312,7 +315,7 @@ async def import_data(openie_obj: OpenIE | None = None):
|
||||
logger.info("--- 数据导入完成 ---")
|
||||
|
||||
|
||||
async def import_from_specific_file():
|
||||
def import_from_specific_file():
|
||||
"""从用户指定的 openie.json 文件导入数据"""
|
||||
file_path = input("请输入 openie.json 文件的完整路径: ").strip()
|
||||
|
||||
@@ -327,7 +330,7 @@ async def import_from_specific_file():
|
||||
try:
|
||||
logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
|
||||
openie_obj = OpenIE.load()
|
||||
await import_data(openie_obj=openie_obj)
|
||||
asyncio.run(import_data(openie_obj=openie_obj))
|
||||
except Exception as e:
|
||||
logger.error(f"从指定文件导入数据时发生错误: {e}")
|
||||
|
||||
@@ -335,20 +338,14 @@ async def import_from_specific_file():
|
||||
# --- 主函数 ---
|
||||
|
||||
|
||||
async def async_main():
|
||||
def main():
|
||||
# 使用 os.path.relpath 创建相对于项目根目录的友好路径
|
||||
raw_data_relpath = os.path.relpath(
|
||||
RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")
|
||||
)
|
||||
openie_output_relpath = os.path.relpath(
|
||||
OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
|
||||
)
|
||||
raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
|
||||
openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
|
||||
|
||||
print("=== LPMM 知识库学习工具 ===")
|
||||
print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
|
||||
print(
|
||||
f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)"
|
||||
)
|
||||
print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
|
||||
print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
|
||||
print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
|
||||
print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
|
||||
@@ -362,20 +359,16 @@ async def async_main():
|
||||
elif choice == "2":
|
||||
paragraphs = preprocess_raw_data()
|
||||
if paragraphs:
|
||||
await extract_information(
|
||||
paragraphs, model_config.model_task_config.lpmm_qa
|
||||
)
|
||||
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
|
||||
elif choice == "3":
|
||||
await import_data()
|
||||
asyncio.run(import_data())
|
||||
elif choice == "4":
|
||||
paragraphs = preprocess_raw_data()
|
||||
if paragraphs:
|
||||
await extract_information(
|
||||
paragraphs, model_config.model_task_config.lpmm_qa
|
||||
)
|
||||
await import_data()
|
||||
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
|
||||
asyncio.run(import_data())
|
||||
elif choice == "5":
|
||||
await import_from_specific_file()
|
||||
import_from_specific_file()
|
||||
elif choice == "6":
|
||||
clear_cache()
|
||||
elif choice == "0":
|
||||
@@ -385,4 +378,4 @@ async def async_main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(async_main())
|
||||
main()
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
更新Prompt类导入脚本
|
||||
将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# 需要更新的文件列表
|
||||
files_to_update = [
|
||||
"src/person_info/relationship_fetcher.py",
|
||||
"src/mood/mood_manager.py",
|
||||
"src/mais4u/mais4u_chat/body_emotion_action_manager.py",
|
||||
"src/chat/express/expression_learner.py",
|
||||
"src/chat/planner_actions/planner.py",
|
||||
"src/mais4u/mais4u_chat/s4u_prompt.py",
|
||||
"src/chat/message_receive/bot.py",
|
||||
"src/chat/replyer/default_generator.py",
|
||||
"src/chat/express/expression_selector.py",
|
||||
"src/mais4u/mai_think.py",
|
||||
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
||||
"src/plugin_system/core/tool_use.py",
|
||||
"src/chat/memory_system/memory_activator.py",
|
||||
"src/chat/utils/smart_prompt.py",
|
||||
]
|
||||
|
||||
|
||||
def update_prompt_imports(file_path):
|
||||
"""更新文件中的Prompt导入"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"文件不存在: {file_path}")
|
||||
return False
|
||||
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 替换导入语句
|
||||
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
|
||||
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
|
||||
|
||||
if old_import in content:
|
||||
new_content = content.replace(old_import, new_import)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
print(f"已更新: {file_path}")
|
||||
return True
|
||||
else:
|
||||
print(f"无需更新: {file_path}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
updated_count = 0
|
||||
for file_path in files_to_update:
|
||||
if update_prompt_imports(file_path):
|
||||
updated_count += 1
|
||||
|
||||
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -30,8 +30,8 @@ from .utils.hash import get_sha256
|
||||
install(extra_lines=3)
|
||||
|
||||
# 多线程embedding配置常量
|
||||
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
DEFAULT_MAX_WORKERS = 3 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
@@ -124,60 +124,124 @@ class EmbeddingStore:
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: list[str],
|
||||
main_loop: asyncio.AbstractEventLoop,
|
||||
chunk_size: int = 10,
|
||||
max_workers: int = 10,
|
||||
progress_callback=None,
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量, 并通过 run_coroutine_threadsafe 在主事件循环中运行异步任务"""
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
max_workers: 最大线程数
|
||||
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||
|
||||
Returns:
|
||||
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||
"""
|
||||
if not strs:
|
||||
return []
|
||||
|
||||
# 导入必要的模块
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 在主线程(即主事件循环所在的线程)中创建LLMRequest实例
|
||||
# 这样可以确保它绑定到正确的事件循环
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
# 分块
|
||||
chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)]
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunk = strs[i : i + chunk_size]
|
||||
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""在工作线程中运行的函数"""
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
embedding = []
|
||||
try:
|
||||
# 将异步的 get_embedding 调用提交到主事件循环
|
||||
future = asyncio.run_coroutine_threadsafe(llm.get_embedding(s), main_loop)
|
||||
# 同步等待结果,延长超时时间
|
||||
embedding_result, _ = future.result(timeout=60)
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if embedding_result and len(embedding_result) > 0:
|
||||
embedding = embedding_result
|
||||
else:
|
||||
logger.error(f"获取嵌入失败(返回为空): {s}")
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线程中获取嵌入时发生异常: {s}, 错误: {type(e).__name__}: {e}")
|
||||
finally:
|
||||
chunk_results.append((start_idx + i, s, embedding))
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
for i, s in enumerate(chunk_strs):
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
return chunk_results
|
||||
|
||||
# 使用线程池处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||
|
||||
# 收集结果(进度已在process_chunk中实时更新)
|
||||
for future in as_completed(future_to_chunk):
|
||||
try:
|
||||
chunk_results = future.result()
|
||||
@@ -185,14 +249,22 @@ class EmbeddingStore:
|
||||
results[idx] = (s, embedding)
|
||||
except Exception as e:
|
||||
chunk = future_to_chunk[future]
|
||||
logger.error(f"处理数据块时发生严重异常: {chunk}, 错误: {e}")
|
||||
logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
|
||||
# 为失败的块添加空结果
|
||||
start_idx, chunk_strs = chunk
|
||||
for i, s_item in enumerate(chunk_strs):
|
||||
if (start_idx + i) not in results:
|
||||
results[start_idx + i] = (s_item, [])
|
||||
for i, s in enumerate(chunk_strs):
|
||||
results[start_idx + i] = (s, [])
|
||||
|
||||
# 按原始顺序返回结果
|
||||
return [results.get(i, (strs[i], [])) for i in range(len(strs))]
|
||||
ordered_results = []
|
||||
for i in range(len(strs)):
|
||||
if i in results:
|
||||
ordered_results.append(results[i])
|
||||
else:
|
||||
# 防止遗漏
|
||||
ordered_results.append((strs[i], []))
|
||||
|
||||
return ordered_results
|
||||
|
||||
@staticmethod
|
||||
def get_test_file_path():
|
||||
@@ -202,17 +274,9 @@ class EmbeddingStore:
|
||||
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||
logger.info("开始保存测试字符串的嵌入向量...")
|
||||
|
||||
# 获取当前正在运行的事件循环
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
return
|
||||
|
||||
# 使用多线程批量获取测试字符串的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
main_loop,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
@@ -224,6 +288,8 @@ class EmbeddingStore:
|
||||
test_vectors[str(idx)] = embedding
|
||||
else:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
# 使用原始单线程方法作为后备
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
@@ -255,17 +321,9 @@ class EmbeddingStore:
|
||||
|
||||
logger.info("开始检验嵌入模型一致性...")
|
||||
|
||||
# 获取当前正在运行的事件循环
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
return False
|
||||
|
||||
# 使用多线程批量获取当前模型的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
main_loop,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
@@ -325,20 +383,11 @@ class EmbeddingStore:
|
||||
progress.update(task, advance=already_processed)
|
||||
|
||||
if new_strs:
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
# 更新进度条以反映未处理的项目
|
||||
progress.update(task, advance=len(new_strs))
|
||||
return
|
||||
|
||||
# 使用实例配置的参数,智能调整分块和线程数
|
||||
optimal_chunk_size = max(
|
||||
MIN_CHUNK_SIZE,
|
||||
min(
|
||||
self.chunk_size,
|
||||
len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size,
|
||||
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
|
||||
),
|
||||
)
|
||||
optimal_max_workers = min(
|
||||
@@ -355,13 +404,12 @@ class EmbeddingStore:
|
||||
# 批量获取嵌入,并实时更新进度
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
new_strs,
|
||||
main_loop,
|
||||
chunk_size=optimal_chunk_size,
|
||||
max_workers=optimal_max_workers,
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
# 存入结果
|
||||
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||
for s, embedding in embedding_results:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if embedding: # 只有成功获取到嵌入才存入
|
||||
|
||||
@@ -88,6 +88,8 @@ class QAManager:
|
||||
else:
|
||||
logger.info("未找到相关关系,将使用文段检索结果")
|
||||
result = paragraph_search_res
|
||||
if result and result[0][1] < global_config.lpmm_knowledge.qa_paragraph_threshold:
|
||||
result = []
|
||||
ppr_node_weights = None
|
||||
|
||||
# 过滤阈值
|
||||
|
||||
@@ -45,8 +45,8 @@ class MessageManager:
|
||||
self.chatter_manager = ChatterManager(self.action_manager)
|
||||
|
||||
# 消息缓存系统 - 直接集成到消息管理器
|
||||
self.message_caches: Dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
|
||||
self.stream_processing_status: Dict[str, bool] = defaultdict(bool) # 流的处理状态
|
||||
self.message_caches: dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
|
||||
self.stream_processing_status: dict[str, bool] = defaultdict(bool) # 流的处理状态
|
||||
self.cache_stats = {
|
||||
"total_cached_messages": 0,
|
||||
"total_flushed_messages": 0,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
from .state_manager import SleepState, sleep_state_manager
|
||||
|
||||
logger = get_logger("sleep_logic")
|
||||
@@ -77,7 +77,7 @@ class SleepLogic:
|
||||
logger.info(f"当前时间 {now.strftime('%H:%M')} 已到达或超过预定起床时间 {wake_up_time.strftime('%H:%M')}。")
|
||||
sleep_state_manager.set_state(SleepState.AWAKE)
|
||||
|
||||
def _should_be_sleeping(self, now: datetime) -> Tuple[bool, Optional[datetime]]:
|
||||
def _should_be_sleeping(self, now: datetime) -> tuple[bool, datetime | None]:
|
||||
"""
|
||||
判断在当前时刻,是否应该处于睡眠时间。
|
||||
|
||||
@@ -108,10 +108,10 @@ class SleepLogic:
|
||||
return True, wake_up_time
|
||||
# 如果当前时间大于入睡时间,说明已经进入睡眠窗口
|
||||
return True, wake_up_time
|
||||
|
||||
|
||||
return False, None
|
||||
|
||||
def _get_fixed_sleep_times(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
def _get_fixed_sleep_times(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“固定时间”模式时,从此方法计算睡眠和起床时间。
|
||||
会加入配置中的随机偏移量,让作息更自然。
|
||||
@@ -129,7 +129,7 @@ class SleepLogic:
|
||||
wake_up_t = datetime.strptime(sleep_config.fixed_wake_up_time, "%H:%M").time()
|
||||
|
||||
sleep_time = datetime.combine(now.date(), sleep_t) + timedelta(minutes=sleep_offset)
|
||||
|
||||
|
||||
# 如果起床时间比睡觉时间早,说明是第二天
|
||||
wake_up_day = now.date() + timedelta(days=1) if wake_up_t < sleep_t else now.date()
|
||||
wake_up_time = datetime.combine(wake_up_day, wake_up_t) + timedelta(minutes=wake_up_offset)
|
||||
@@ -139,7 +139,7 @@ class SleepLogic:
|
||||
logger.error(f"解析固定睡眠时间失败: {e}")
|
||||
return None, None
|
||||
|
||||
def _get_sleep_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
def _get_sleep_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“日程表”模式时,从此方法获取睡眠时间。
|
||||
实现了核心逻辑:
|
||||
@@ -164,8 +164,8 @@ class SleepLogic:
|
||||
wake_up_time = None
|
||||
|
||||
return sleep_time, wake_up_time
|
||||
|
||||
def _get_wakeup_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
|
||||
def _get_wakeup_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“日程表”模式时,从此方法获取睡眠时间。
|
||||
实现了核心逻辑:
|
||||
@@ -192,4 +192,4 @@ class SleepLogic:
|
||||
|
||||
|
||||
# 全局单例
|
||||
sleep_logic = SleepLogic()
|
||||
sleep_logic = SleepLogic()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -43,7 +43,7 @@ class SleepStateManager:
|
||||
"""
|
||||
初始化状态管理器,定义状态数据结构并从本地加载历史状态。
|
||||
"""
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
self._default_state()
|
||||
self.load_state()
|
||||
|
||||
@@ -115,9 +115,9 @@ class SleepStateManager:
|
||||
def set_state(
|
||||
self,
|
||||
new_state: SleepState,
|
||||
duration_seconds: Optional[float] = None,
|
||||
sleep_start: Optional[datetime] = None,
|
||||
wake_up: Optional[datetime] = None,
|
||||
duration_seconds: float | None = None,
|
||||
sleep_start: datetime | None = None,
|
||||
wake_up: datetime | None = None,
|
||||
):
|
||||
"""
|
||||
核心函数:切换到新的睡眠状态,并更新相关的状态数据。
|
||||
@@ -132,7 +132,7 @@ class SleepStateManager:
|
||||
if new_state == SleepState.AWAKE:
|
||||
self._default_state() # 醒来时重置所有状态
|
||||
self.state["state"] = SleepState.AWAKE # 确保状态正确
|
||||
|
||||
|
||||
elif new_state == SleepState.SLEEPING:
|
||||
self.state["sleep_start_time"] = (sleep_start or datetime.now()).isoformat()
|
||||
self.state["wake_up_time"] = wake_up.isoformat() if wake_up else None
|
||||
@@ -153,7 +153,7 @@ class SleepStateManager:
|
||||
self.state["last_checked"] = datetime.now().isoformat()
|
||||
self.save_state()
|
||||
|
||||
def get_wake_up_time(self) -> Optional[datetime]:
|
||||
def get_wake_up_time(self) -> datetime | None:
|
||||
"""获取预定的起床时间,如果已设置的话。"""
|
||||
wake_up_str = self.state.get("wake_up_time")
|
||||
if wake_up_str:
|
||||
@@ -163,7 +163,7 @@ class SleepStateManager:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_sleep_start_time(self) -> Optional[datetime]:
|
||||
def get_sleep_start_time(self) -> datetime | None:
|
||||
"""获取本次睡眠的开始时间,如果已设置的话。"""
|
||||
sleep_start_str = self.state.get("sleep_start_time")
|
||||
if sleep_start_str:
|
||||
@@ -187,4 +187,4 @@ class SleepStateManager:
|
||||
|
||||
|
||||
# 全局单例
|
||||
sleep_state_manager = SleepStateManager()
|
||||
sleep_state_manager = SleepStateManager()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
|
||||
from .sleep_logic import sleep_logic
|
||||
|
||||
logger = get_logger("sleep_tasks")
|
||||
|
||||
@@ -402,19 +402,31 @@ class ChatBot:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
platform = message_data["message_info"].get("platform")
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
if not isinstance(message_data, dict):
|
||||
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
|
||||
return
|
||||
message_info = message_data.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(message_data.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
platform = message_info.get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
message_data["message_info"]["group_info"]["group_id"]
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str(
|
||||
message_info["group_info"]["group_id"]
|
||||
)
|
||||
if message_data["message_info"].get("user_info") is not None:
|
||||
message_data["message_info"]["user_info"]["user_id"] = str(
|
||||
message_data["message_info"]["user_info"]["user_id"]
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str(
|
||||
message_info["user_info"]["user_id"]
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -207,18 +207,18 @@ class ActionModifier:
|
||||
List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表
|
||||
"""
|
||||
deactivated_actions = []
|
||||
|
||||
|
||||
# 获取 Action 类注册表
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
actions_to_check = list(actions_with_info.items())
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
|
||||
# 创建并行任务列表
|
||||
activation_tasks = []
|
||||
task_action_names = []
|
||||
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
# 获取 Action 类
|
||||
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
@@ -226,7 +226,7 @@ class ActionModifier:
|
||||
logger.warning(f"{self.log_prefix}未找到 Action 类: {action_name},默认不激活")
|
||||
deactivated_actions.append((action_name, "未找到 Action 类"))
|
||||
continue
|
||||
|
||||
|
||||
# 创建一个临时实例来调用 go_activate 方法
|
||||
# 注意:这里只是为了调用 go_activate,不需要完整的初始化
|
||||
try:
|
||||
@@ -237,24 +237,24 @@ class ActionModifier:
|
||||
action_instance.log_prefix = self.log_prefix
|
||||
# 设置聊天内容,用于激活判断
|
||||
action_instance._activation_chat_content = chat_content
|
||||
|
||||
|
||||
# 调用 go_activate 方法(不再需要传入 chat_content)
|
||||
task = action_instance.go_activate(
|
||||
llm_judge_model=self.llm_judge,
|
||||
)
|
||||
activation_tasks.append(task)
|
||||
task_action_names.append(action_name)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}创建 Action 实例 {action_name} 失败: {e}")
|
||||
deactivated_actions.append((action_name, f"创建实例失败: {e}"))
|
||||
|
||||
|
||||
# 并行执行所有激活判断
|
||||
if activation_tasks:
|
||||
logger.debug(f"{self.log_prefix}并行执行激活判断,任务数: {len(activation_tasks)}")
|
||||
try:
|
||||
task_results = await asyncio.gather(*activation_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理结果
|
||||
for action_name, result in zip(task_action_names, task_results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
@@ -267,7 +267,7 @@ class ActionModifier:
|
||||
else:
|
||||
# go_activate 返回 True,激活
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行激活判断失败: {e}")
|
||||
# 如果并行执行失败,为所有任务默认不激活
|
||||
|
||||
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
@@ -1312,7 +1313,7 @@ class DefaultReplyer:
|
||||
}
|
||||
|
||||
# 设置超时
|
||||
timeout = 15.0 # 秒
|
||||
timeout = 45.0 # 秒
|
||||
|
||||
async def get_task_result(task_name, task):
|
||||
try:
|
||||
|
||||
@@ -8,13 +8,14 @@ import contextvars
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -23,81 +24,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
notice_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
@@ -132,7 +58,7 @@ class PromptContext:
|
||||
context_id = None
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
@@ -185,16 +111,42 @@ class PromptManager:
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
|
||||
"""
|
||||
异步获取提示模板,并动态注入插件内容
|
||||
"""
|
||||
original_prompt = None
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
if name not in self._prompts:
|
||||
original_prompt = context_prompt
|
||||
elif name in self._prompts:
|
||||
original_prompt = self._prompts[name]
|
||||
else:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
# 动态注入插件内容
|
||||
if original_prompt.name:
|
||||
# 确保我们有有效的parameters实例
|
||||
params_for_injection = parameters or original_prompt.parameters
|
||||
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=original_prompt.name, params=params_for_injection
|
||||
)
|
||||
logger.info(components_prefix)
|
||||
if components_prefix:
|
||||
logger.info(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
# 创建一个新的临时Prompt实例,不进行注册
|
||||
new_template = f"{components_prefix}\n\n{original_prompt.template}"
|
||||
temp_prompt = Prompt(
|
||||
template=new_template,
|
||||
name=original_prompt.name,
|
||||
parameters=original_prompt.parameters,
|
||||
should_register=False, # 确保不重新注册
|
||||
)
|
||||
return temp_prompt
|
||||
|
||||
return original_prompt
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
@@ -216,7 +168,9 @@ class PromptManager:
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 提取parameters用于注入
|
||||
parameters = kwargs.get("parameters")
|
||||
prompt = await self.get_prompt_async(name, parameters=parameters)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
@@ -304,11 +258,14 @@ class Prompt:
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
# 1. 构建核心上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
# 2. 格式化主模板
|
||||
main_formatted_prompt = await self._format_with_context(context_data)
|
||||
|
||||
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
|
||||
result = main_formatted_prompt
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
@@ -470,9 +427,13 @@ class Prompt:
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
target_user_id = ""
|
||||
if self.parameters.target_user_info:
|
||||
target_user_id = self.parameters.target_user_info.get("user_id") or ""
|
||||
|
||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
target_user_id,
|
||||
self.parameters.sender,
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
@@ -498,11 +459,14 @@ class Prompt:
|
||||
|
||||
# 创建临时生成器实例来使用其方法
|
||||
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
if temp_generator:
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
@@ -589,10 +553,10 @@ class Prompt:
|
||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
if isinstance(running_memories, BaseException):
|
||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||
running_memories = []
|
||||
if isinstance(instant_memory, Exception):
|
||||
if isinstance(instant_memory, BaseException):
|
||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||
instant_memory = None
|
||||
|
||||
@@ -763,20 +727,15 @@ class Prompt:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import QAManager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
# 获取问题文本(当前消息)
|
||||
question = self.parameters.target or ""
|
||||
if not question:
|
||||
if not question or not qa_manager:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
# 创建QA管理器
|
||||
qa_manager = QAManager()
|
||||
|
||||
# 搜索相关知识
|
||||
knowledge_results = await qa_manager.get_knowledge(
|
||||
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
|
||||
)
|
||||
knowledge_results = await qa_manager.get_knowledge(question=question)
|
||||
|
||||
# 构建知识块
|
||||
if knowledge_results and knowledge_results.get("knowledge_items"):
|
||||
@@ -786,12 +745,17 @@ class Prompt:
|
||||
content = item.get("content", "")
|
||||
source = item.get("source", "")
|
||||
relevance = item.get("relevance", 0.0)
|
||||
|
||||
if content:
|
||||
try:
|
||||
relevance_float = float(relevance)
|
||||
relevance_str = f"{relevance_float:.2f}"
|
||||
except (ValueError, TypeError):
|
||||
relevance_str = str(relevance)
|
||||
|
||||
if source:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||
knowledge_parts.append(f"- [{relevance_str}] {content} (来源: {source})")
|
||||
else:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||
knowledge_parts.append(f"- [{relevance_str}] {content}")
|
||||
|
||||
if knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||
@@ -1108,8 +1072,24 @@ def create_prompt(
|
||||
async def create_prompt_async(
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
"""异步创建Prompt实例,并动态注入插件内容"""
|
||||
# 确保有可用的parameters实例
|
||||
final_params = parameters or PromptParameters(**kwargs)
|
||||
|
||||
# 动态注入插件内容
|
||||
if name:
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=name, params=final_params
|
||||
)
|
||||
if components_prefix:
|
||||
logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
template = f"{components_prefix}\n\n{template}"
|
||||
|
||||
# 使用可能已修改的模板创建实例
|
||||
prompt = create_prompt(template, name, final_params)
|
||||
|
||||
# 如果在特定上下文中,则异步注册
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
109
src/chat/utils/prompt_component_manager.py
Normal file
109
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import asyncio
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("prompt_component_manager")
|
||||
|
||||
|
||||
class PromptComponentManager:
    """
    Manager for `BasePrompt` plugin components.

    Responsibilities:
    1. Query `component_registry` for enabled Prompt components.
    2. Filter them by injection point (the name of the target Prompt).
    3. Instantiate and execute the matching components so that their output
       can be injected while a core Prompt is being built.

    The class holds no state of its own; every call reads the registry afresh.
    """

    def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
        """
        Return every enabled component class registered for an injection point.

        Args:
            injection_point: Name of the target Prompt.

        Returns:
            list[Type[BasePrompt]]: Component classes whose declared injection
            point(s) include `injection_point`, in registry iteration order.
        """
        # Fetch all enabled Prompt components from the component registry.
        enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)

        matching_components: list[Type[BasePrompt]] = []

        for prompt_name, prompt_info in enabled_prompts.items():
            # Ignore registry entries that are not PromptInfo.
            if not isinstance(prompt_info, PromptInfo):
                continue

            # A component may declare a single target (str) or several (list).
            injection_points = prompt_info.injection_point
            if isinstance(injection_points, str):
                injection_points = [injection_points]

            if injection_point in injection_points:
                component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
                if component_class and issubclass(component_class, BasePrompt):
                    matching_components.append(component_class)

        return matching_components

    async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
        """
        Instantiate and run every component for an injection point and join their output.

        Components that fail to instantiate or whose `execute()` fails are logged
        and skipped; they never abort the remaining components.

        Args:
            injection_point: Name of the target Prompt.
            params: PromptParameters object passed to each component's constructor.

        Returns:
            str: The stripped, non-empty string results of all components joined
            with newlines, or "" when nothing matched or nothing produced text.
        """
        component_classes = self.get_components_for(injection_point)
        if not component_classes:
            return ""

        tasks = []
        # Names are tracked in parallel with `tasks`: a component that fails to
        # instantiate is absent from `tasks`, so indexing `component_classes`
        # with a result index would blame the wrong component.
        task_names: list[str] = []
        for component_class in component_classes:
            try:
                # Look up the component info to find the owning plugin's config.
                prompt_info = component_registry.get_component_info(
                    component_class.prompt_name, ComponentType.PROMPT
                )
                if not isinstance(prompt_info, PromptInfo):
                    logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
                    plugin_config = {}
                else:
                    plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)

                instance = component_class(params=params, plugin_config=plugin_config)
                tasks.append(instance.execute())
                task_names.append(component_class.prompt_name)
            except Exception as e:
                logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")

        if not tasks:
            return ""

        # Run all components concurrently; failures come back as result values.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep only successful, non-blank string results.
        valid_results = []
        for component_name, result in zip(task_names, results):
            # BaseException (not Exception) so that a cancelled child, returned
            # as CancelledError, is reported instead of silently dropped.
            if isinstance(result, BaseException):
                logger.error(f"执行 Prompt 组件 '{component_name}' 失败: {result}")
            elif result and isinstance(result, str) and result.strip():
                valid_results.append(result.strip())

        # Join all valid results with newlines.
        return "\n".join(valid_results)


# Module-level shared instance.
prompt_component_manager = PromptComponentManager()
|
||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
@@ -298,14 +298,14 @@ def random_remove_punctuation(text: str) -> str:
|
||||
def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
"""识别并保护数学公式和代码块,返回处理后的文本和映射"""
|
||||
placeholder_map = {}
|
||||
|
||||
|
||||
# 第一层防护:优先保护标准Markdown格式
|
||||
# 使用 re.S 来让 . 匹配换行符
|
||||
markdown_patterns = {
|
||||
'code': r"```.*?```",
|
||||
'math': r"\$\$.*?\$\$",
|
||||
"code": r"```.*?```",
|
||||
"math": r"\$\$.*?\$\$",
|
||||
}
|
||||
|
||||
|
||||
placeholder_idx = 0
|
||||
for block_type, pattern in markdown_patterns.items():
|
||||
matches = re.findall(pattern, text, re.S)
|
||||
@@ -318,7 +318,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
# 第二层防护:保护非标准的、可能是公式或代码的片段
|
||||
# 这个正则表达式寻找连续5个以上的、主要由非中文字符组成的片段
|
||||
general_pattern = r"(?:[a-zA-Z0-9\s.,;:(){}\[\]_+\-*/=<>^|&%?!'\"√²³ⁿ∑∫≠≥≤]){5,}"
|
||||
|
||||
|
||||
# 为了避免与已保护的占位符冲突,我们在剩余的文本上进行查找
|
||||
# 这是一个简化的处理,更稳妥的方式是分段查找,但目前这样足以应对多数情况
|
||||
try:
|
||||
@@ -327,7 +327,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
# 避免将包含占位符的片段再次保护
|
||||
if "__SPECIAL_" in match:
|
||||
continue
|
||||
|
||||
|
||||
placeholder = f"__SPECIAL_GENERAL_{placeholder_idx}__"
|
||||
text = text.replace(match, placeholder, 1)
|
||||
placeholder_map[placeholder] = match
|
||||
@@ -352,23 +352,23 @@ def protect_quoted_content(text: str) -> tuple[str, dict[str, str]]:
|
||||
placeholder_map = {}
|
||||
# 匹配中英文单双引号,使用非贪婪模式
|
||||
quote_pattern = re.compile(r'(".*?")|(\'.*?\')|(“.*?”)|(‘.*?’)')
|
||||
|
||||
|
||||
matches = quote_pattern.finditer(text)
|
||||
|
||||
|
||||
# 为了避免替换时索引错乱,我们从后往前替换
|
||||
# finditer 找到的是 match 对象,我们需要转换为 list 来反转
|
||||
match_list = list(matches)
|
||||
|
||||
|
||||
for idx, match in enumerate(reversed(match_list)):
|
||||
original_quoted_text = match.group(0)
|
||||
placeholder = f"__QUOTE_{len(match_list) - 1 - idx}__"
|
||||
|
||||
|
||||
# 直接在原始文本上操作,替换 match 对象的 span
|
||||
start, end = match.span()
|
||||
text = text[:start] + placeholder + text[end:]
|
||||
|
||||
|
||||
placeholder_map[placeholder] = original_quoted_text
|
||||
|
||||
|
||||
return text, placeholder_map
|
||||
|
||||
|
||||
@@ -389,13 +389,13 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
# --- 三层防护系统 ---
|
||||
# 第一层:保护颜文字
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})
|
||||
|
||||
|
||||
# 第二层:保护引号内容
|
||||
protected_text, quote_mapping = protect_quoted_content(protected_text)
|
||||
|
||||
# 第三层:保护数学公式和代码块
|
||||
protected_text, special_blocks_mapping = protect_special_blocks(protected_text)
|
||||
|
||||
|
||||
# 提取被 () 或 [] 或 ()包裹且包含中文的内容
|
||||
pattern = re.compile(r"[(\[(](?=.*[一-鿿]).*?[)\])]")
|
||||
_extracted_contents = pattern.findall(protected_text)
|
||||
@@ -412,7 +412,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||
|
||||
|
||||
# --- 移除总长度检查 ---
|
||||
# 原有的总长度检查会导致长回复被直接丢弃,现已移除,由后续的智能合并逻辑处理。
|
||||
# max_length = global_config.response_splitter.max_length * 2
|
||||
@@ -472,7 +472,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
break
|
||||
|
||||
# 寻找最短的相邻句子对
|
||||
min_len = float('inf')
|
||||
min_len = float("inf")
|
||||
merge_idx = -1
|
||||
for i in range(len(sentences) - 1):
|
||||
combined_len = len(sentences[i]) + len(sentences[i+1])
|
||||
@@ -488,7 +488,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
sentences[merge_idx] = merged_sentence
|
||||
# 删除后一个句子
|
||||
del sentences[merge_idx + 1]
|
||||
|
||||
|
||||
logger.info(f"智能合并完成,最终消息数量: {len(sentences)}")
|
||||
|
||||
# if extracted_contents:
|
||||
|
||||
@@ -79,7 +79,7 @@ class Server:
|
||||
logger.warning(f"端口 {self.port} 已被占用,正在尝试下一个端口...")
|
||||
self.port += 1
|
||||
|
||||
logger.info(f"将在 http://{self.host}:{self.port} 上启动服务器")
|
||||
logger.info(f"将在 {self.host}:{self.port} 上启动服务器")
|
||||
# 禁用 uvicorn 默认日志和访问日志
|
||||
config = Config(app=self.app, host=self.host, port=self.port, log_config=None, access_log=False)
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.config.config_base import ValidatedConfigBase
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 重要的配置类继承自ValidatedConfigBase进行Pydantic验证
|
||||
2. 所有配置类必须继承自ValidatedConfigBase进行Pydantic验证
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
@@ -492,6 +492,7 @@ class LPMMKnowledgeConfig(ValidatedConfigBase):
|
||||
info_extraction_workers: int = Field(default=3, description="信息提取工作线程数")
|
||||
qa_relation_search_top_k: int = Field(default=10, description="QA关系搜索Top K")
|
||||
qa_relation_threshold: float = Field(default=0.75, description="QA关系阈值")
|
||||
qa_paragraph_threshold: float = Field(default=0.3, description="QA段落阈值")
|
||||
qa_paragraph_search_top_k: int = Field(default=1000, description="QA段落搜索Top K")
|
||||
qa_paragraph_node_weight: float = Field(default=0.05, description="QA段落节点权重")
|
||||
qa_ent_filter_top_k: int = Field(default=10, description="QA实体过滤Top K")
|
||||
|
||||
@@ -13,6 +13,7 @@ from rich.traceback import install
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.memory_system.memory_manager import memory_manager
|
||||
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
@@ -29,7 +30,6 @@ from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -26,6 +26,7 @@ from .base import (
|
||||
ActionInfo,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BasePrompt,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BaseTool,
|
||||
@@ -64,6 +65,7 @@ __all__ = [
|
||||
"BaseEventHandler",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -8,6 +8,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .command_args import CommandArgs
|
||||
from .component_types import (
|
||||
@@ -37,6 +38,7 @@ __all__ = [
|
||||
"BaseCommand",
|
||||
"BaseEventHandler",
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -615,15 +615,15 @@ class BaseAction(ABC):
|
||||
"""
|
||||
# 尝试从不同的实例属性中获取聊天内容
|
||||
# 优先级:_activation_chat_content > action_data['chat_content'] > ""
|
||||
|
||||
|
||||
# 1. 如果有专门设置的激活用聊天内容(由 ActionModifier 设置)
|
||||
if hasattr(self, '_activation_chat_content'):
|
||||
return getattr(self, '_activation_chat_content', "")
|
||||
|
||||
if hasattr(self, "_activation_chat_content"):
|
||||
return getattr(self, "_activation_chat_content", "")
|
||||
|
||||
# 2. 尝试从 action_data 中获取
|
||||
if hasattr(self, 'action_data') and isinstance(self.action_data, dict):
|
||||
return self.action_data.get('chat_content', "")
|
||||
|
||||
if hasattr(self, "action_data") and isinstance(self.action_data, dict):
|
||||
return self.action_data.get("chat_content", "")
|
||||
|
||||
# 3. 默认返回空字符串
|
||||
return ""
|
||||
|
||||
@@ -729,7 +729,7 @@ class BaseAction(ABC):
|
||||
|
||||
# 自动获取聊天内容
|
||||
chat_content = self._get_chat_content()
|
||||
|
||||
|
||||
search_text = chat_content
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
@@ -786,7 +786,7 @@ class BaseAction(ABC):
|
||||
try:
|
||||
# 自动获取聊天内容
|
||||
chat_content = self._get_chat_content()
|
||||
|
||||
|
||||
# 如果没有提供 LLM 模型,创建一个默认的
|
||||
if llm_judge_model is None:
|
||||
from src.config.config import model_config
|
||||
|
||||
@@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import (
|
||||
EventHandlerInfo,
|
||||
InterestCalculatorInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
@@ -15,6 +16,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_interest_calculator import BaseInterestCalculator
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .plugin_base import PluginBase
|
||||
from .plus_command import PlusCommand
|
||||
@@ -80,6 +82,13 @@ class BasePlugin(PluginBase):
|
||||
logger.warning("EventHandler的get_info逻辑尚未实现")
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.PROMPT:
|
||||
if hasattr(component_class, "get_prompt_info"):
|
||||
return component_class.get_prompt_info()
|
||||
else:
|
||||
logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法")
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的组件类型: {component_type}")
|
||||
return None
|
||||
@@ -109,6 +118,7 @@ class BasePlugin(PluginBase):
|
||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||
| tuple[ToolInfo, type[BaseTool]]
|
||||
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
||||
| tuple[PromptInfo, type[BasePrompt]]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
|
||||
95
src/plugin_system/base/base_prompt.py
Normal file
95
src/plugin_system/base/base_prompt.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
|
||||
logger = get_logger("base_prompt")
|
||||
|
||||
|
||||
class BasePrompt(ABC):
    """Base class for Prompt components.

    A Prompt is a plugin component type used to inject extra context into an
    existing core Prompt template, so model behaviour can be extended without
    modifying core code.

    Subclasses configure themselves through class attributes:
    - prompt_name: unique name of the Prompt component.
    - injection_point: name (or list of names) of the target Prompt(s).
    """

    # Name of the Prompt component.
    prompt_name: str = ""
    # Description of the Prompt component.
    prompt_description: str = ""

    # Core Prompt(s) this component wants to be injected into: either one
    # string (single target) or a list of strings (several targets), e.g.
    # "planner_prompt" or ["s4u_style_prompt", "normal_style_prompt"].
    injection_point: str | list[str] = ""

    def __init__(self, params: PromptParameters, plugin_config: dict | None = None):
        """Initialise the Prompt component.

        Args:
            params: Unified prompt parameters carrying the build context.
            plugin_config: Plugin configuration dict; a falsy value becomes {}.
        """
        self.params = params
        self.plugin_config = plugin_config if plugin_config else {}
        self.log_prefix = "[PromptComponent]"

        logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成")

    @abstractmethod
    async def execute(self) -> str:
        """Generate this component's text; subclasses must implement it.

        Implementations should build the string from `self.params`.

        Returns:
            str: The generated text content.
        """

    def get_config(self, key: str, default: Any = None) -> Any:
        """Read a plugin configuration value, with dotted nested-key access.

        Args:
            key: Configuration key; dots descend into nested dicts,
                e.g. "section.subsection.key".
            default: Value returned when the key path cannot be resolved.

        Returns:
            Any: The configured value, or `default`.
        """
        node = self.plugin_config
        if not node:
            return default

        for part in key.split("."):
            # Stop as soon as the path leaves dict territory or a key is missing.
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    @classmethod
    def get_prompt_info(cls) -> "PromptInfo":
        """Build a PromptInfo from the class attributes.

        Returns:
            PromptInfo: Info object describing this Prompt component.

        Raises:
            ValueError: If the class does not define a non-empty `prompt_name`.
        """
        name = cls.prompt_name
        if name:
            return PromptInfo(
                name=name,
                component_type=ComponentType.PROMPT,
                description=cls.prompt_description,
                injection_point=cls.injection_point,
            )
        raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。")
|
||||
@@ -20,6 +20,7 @@ class ComponentType(Enum):
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||
CHATTER = "chatter" # 聊天处理器组件
|
||||
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
||||
PROMPT = "prompt" # Prompt组件
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
@@ -143,7 +144,7 @@ class ActionInfo(ComponentInfo):
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: list[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: list[str] = field(default_factory=list) # 关联的消息类型
|
||||
|
||||
|
||||
# ==================================================================================
|
||||
# 激活类型相关字段(已废弃,建议使用 go_activate() 方法)
|
||||
# 保留这些字段是为了向后兼容,BaseAction.go_activate() 的默认实现会使用这些字段
|
||||
@@ -155,7 +156,7 @@ class ActionInfo(ComponentInfo):
|
||||
llm_judge_prompt: str = "" # 已废弃,建议在 go_activate() 中使用 _llm_judge_activation()
|
||||
activation_keywords: list[str] = field(default_factory=list) # 已废弃,建议在 go_activate() 中使用 _keyword_match()
|
||||
keyword_case_sensitive: bool = False # 已废弃
|
||||
|
||||
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
parallel_action: bool = False
|
||||
@@ -266,6 +267,18 @@ class EventInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
class PromptInfo(ComponentInfo):
    """Component info record for Prompt components."""

    # Name, or list of names, of the target Prompt(s) to inject into.
    injection_point: str | list[str] = ""

    def __post_init__(self):
        # Run the base-class post-init first, then pin the component type.
        super().__post_init__()
        self.component_type = ComponentType.PROMPT
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
@@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import (
|
||||
InterestCalculatorInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
@@ -37,6 +39,7 @@ ComponentClassType = (
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| type[BasePrompt]
|
||||
)
|
||||
|
||||
|
||||
@@ -183,6 +186,10 @@ class ComponentRegistry:
|
||||
assert isinstance(component_info, InterestCalculatorInfo)
|
||||
assert issubclass(component_class, BaseInterestCalculator)
|
||||
ret = self._register_interest_calculator_component(component_info, component_class)
|
||||
case ComponentType.PROMPT:
|
||||
assert isinstance(component_info, PromptInfo)
|
||||
assert issubclass(component_class, BasePrompt)
|
||||
ret = self._register_prompt_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
ret = False
|
||||
@@ -346,6 +353,31 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
|
||||
return True
|
||||
|
||||
def _register_prompt_component(
    self, prompt_info: PromptInfo, prompt_class: "ComponentClassType"
) -> bool:
    """Register a Prompt component in the Prompt-specific registries.

    Args:
        prompt_info: Info record of the component; its `name` is the registry key.
        prompt_class: The component class being registered.

    Returns:
        bool: False when `prompt_info.name` is empty, True once registered.
    """
    name = prompt_info.name
    if not name:
        logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称")
        return False

    # Create the Prompt registries on first use if this instance lacks them.
    for registry_attr in ("_prompt_registry", "_enabled_prompt_registry"):
        if not hasattr(self, registry_attr):
            setattr(self, registry_attr, {})

    plugin_config = self.get_plugin_config(prompt_info.plugin_name) or {}
    _assign_plugin_attrs(prompt_class, prompt_info.plugin_name, plugin_config)

    self._prompt_registry[name] = prompt_class  # type: ignore
    # Enabled components are additionally tracked in their own registry.
    if prompt_info.enabled:
        self._enabled_prompt_registry[name] = prompt_class  # type: ignore

    logger.debug(f"已注册Prompt组件: {name}")
    return True
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
@@ -580,7 +612,17 @@ class ComponentRegistry:
|
||||
component_name: str,
|
||||
component_type: ComponentType | None = None,
|
||||
) -> (
|
||||
type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None
|
||||
type[
|
||||
BaseCommand
|
||||
| BaseAction
|
||||
| BaseEventHandler
|
||||
| BaseTool
|
||||
| PlusCommand
|
||||
| BaseChatter
|
||||
| BaseInterestCalculator
|
||||
| BasePrompt
|
||||
]
|
||||
| None
|
||||
):
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
@@ -829,6 +871,7 @@ class ComponentRegistry:
|
||||
events_handlers: int = 0
|
||||
plus_command_components: int = 0
|
||||
chatter_components: int = 0
|
||||
prompt_components: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
@@ -842,6 +885,8 @@ class ComponentRegistry:
|
||||
plus_command_components += 1
|
||||
elif component.component_type == ComponentType.CHATTER:
|
||||
chatter_components += 1
|
||||
elif component.component_type == ComponentType.PROMPT:
|
||||
prompt_components += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
@@ -849,6 +894,7 @@ class ComponentRegistry:
|
||||
"event_handlers": events_handlers,
|
||||
"plus_command_components": plus_command_components,
|
||||
"chatter_components": chatter_components,
|
||||
"prompt_components": prompt_components,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
|
||||
@@ -358,13 +358,14 @@ class PluginManager:
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
plus_command_count = stats.get("plus_command_components", 0)
|
||||
chatter_count = stats.get("chatter_components", 0)
|
||||
prompt_count = stats.get("prompt_components", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -382,6 +383,13 @@ class PluginManager:
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
|
||||
def format_component(c):
|
||||
desc = c.description
|
||||
if len(desc) > 15:
|
||||
desc = desc[:15] + "..."
|
||||
return f"{c.name} ({desc})" if desc else c.name
|
||||
|
||||
action_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||
]
|
||||
@@ -395,29 +403,35 @@ class PluginManager:
|
||||
plus_command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
|
||||
]
|
||||
prompt_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||
action_details = [format_component(c) for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_details)}")
|
||||
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
command_details = [format_component(c) for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_details)}")
|
||||
if tool_components:
|
||||
tool_names = [c.name for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
||||
tool_details = [format_component(c) for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_details)}")
|
||||
if plus_command_components:
|
||||
plus_command_names = [c.name for c in plus_command_components]
|
||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}")
|
||||
plus_command_details = [format_component(c) for c in plus_command_components]
|
||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_details)}")
|
||||
chatter_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.CHATTER
|
||||
]
|
||||
if chatter_components:
|
||||
chatter_names = [c.name for c in chatter_components]
|
||||
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}")
|
||||
chatter_details = [format_component(c) for c in chatter_components]
|
||||
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_details)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
event_handler_details = [format_component(c) for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
|
||||
if prompt_components:
|
||||
prompt_details = [format_component(c) for c in prompt_components]
|
||||
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
|
||||
|
||||
# 权限节点信息
|
||||
if plugin_instance := self.loaded_plugins.get(plugin_name):
|
||||
|
||||
@@ -155,88 +155,22 @@ class ChatterPlanFilter:
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
angry_prompt_addition = ""
|
||||
try:
|
||||
from src.plugins.built_in.sleep_system.api import get_wakeup_manager
|
||||
wakeup_mgr = get_wakeup_manager()
|
||||
except ImportError:
|
||||
logger.debug("无法导入睡眠系统API,将跳过相关检查。")
|
||||
wakeup_mgr = None
|
||||
|
||||
if wakeup_mgr:
|
||||
|
||||
# 双重检查确保愤怒状态不会丢失
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
if chat_mood_for_check.is_angry_from_wakeup:
|
||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||
|
||||
if angry_prompt_addition:
|
||||
schedule_block = angry_prompt_addition
|
||||
elif global_config.planning_system.schedule_enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if activity_info := schedule_manager.get_current_activity():
|
||||
activity = activity_info.get("activity", "未知活动")
|
||||
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||
# 需要情绪模块打开才能获得情绪,否则会引发报错
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
# 构建已读/未读历史消息
|
||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||
plan
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
@@ -286,7 +220,7 @@ class ChatterPlanFilter:
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取聊天类型和目标信息
|
||||
chat_type, target_info = await get_chat_type_and_target_info(self.chat_id)
|
||||
if chat_type:
|
||||
chat_type = ChatType.GROUP
|
||||
else:
|
||||
#遇到未知类型也当私聊处理
|
||||
chat_type = ChatType.PRIVATE
|
||||
|
||||
# 获取可用动作列表
|
||||
available_actions = await self._get_available_actions(chat_type, mode)
|
||||
@@ -62,12 +67,16 @@ class ChatterPlanGenerator:
|
||||
# 获取聊天历史记录
|
||||
recent_messages = await self._get_recent_messages()
|
||||
|
||||
# 构建计划对象
|
||||
# 使用 target_info 字典创建 TargetPersonInfo 实例
|
||||
target_person_info = TargetPersonInfo(**target_info) if target_info else TargetPersonInfo()
|
||||
|
||||
# 构建计划对象
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
chat_type=chat_type,
|
||||
mode=mode,
|
||||
target_info=target_info,
|
||||
target_info=target_person_info,
|
||||
available_actions=available_actions,
|
||||
chat_history=recent_messages,
|
||||
)
|
||||
@@ -77,6 +86,7 @@ class ChatterPlanGenerator:
|
||||
except Exception:
|
||||
# 如果生成失败,返回一个基本的空计划
|
||||
return Plan(
|
||||
chat_type = ChatType.PRIVATE,#空计划默认当成私聊
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
target_info=TargetPersonInfo(),
|
||||
@@ -124,7 +134,7 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取最近的消息记录
|
||||
raw_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size
|
||||
)
|
||||
|
||||
# 转换为 DatabaseMessages 对象
|
||||
|
||||
@@ -70,6 +70,7 @@ class ChatterActionPlanner:
|
||||
"replies_generated": 0,
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]:
|
||||
"""
|
||||
@@ -157,7 +158,9 @@ class ChatterActionPlanner:
|
||||
)
|
||||
|
||||
if interest_updates:
|
||||
asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
task = asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._handle_task_result)
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
@@ -266,6 +269,17 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def _handle_task_result(self, task: asyncio.Task) -> None:
    """Done-callback for background tasks: report failures, then drop the task reference.

    Args:
        task: The finished task this callback was attached to.
    """
    try:
        # result() re-raises whatever the task raised (or CancelledError).
        task.result()
    except asyncio.CancelledError:
        # Cancellation is an expected outcome; there is nothing to report.
        return
    except Exception as err:
        logger.error(f"后台任务执行失败: {err}", exc_info=True)
    finally:
        # Runs on every path (including the early return above) so the
        # tracking set never keeps a finished task alive.
        self._background_tasks.discard(task)
|
||||
|
||||
def get_planner_stats(self) -> dict[str, Any]:
    """Return a shallow copy of the planner statistics.

    A copy is handed out so that callers adding or replacing keys do not
    modify the planner's own ``planner_stats`` mapping.
    """
    stats_snapshot = self.planner_stats.copy()
    return stats_snapshot
|
||||
|
||||
@@ -15,7 +15,7 @@ logger = get_logger(__name__)
|
||||
|
||||
@register_plugin
|
||||
class ProactiveThinkerPlugin(BasePlugin):
|
||||
"""一个主动思考的插件,但现在还只是个空壳子"""
|
||||
"""一个主动思考的插件"""
|
||||
|
||||
plugin_name: str = "proactive_thinker"
|
||||
enable_plugin: bool = True
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -13,7 +14,6 @@ from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system import BaseEventHandler, EventType
|
||||
from src.plugin_system.apis import chat_api, message_api, person_api
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
|
||||
|
||||
from .proactive_thinker_executor import ProactiveThinkerExecutor
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseSearchEngine(ABC):
|
||||
@@ -24,6 +24,12 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def read_url(self, url: str) -> Optional[str]:
    """
    Read the content behind ``url``.

    This default implementation performs no request and always returns
    ``None``, which signals that the engine does not support URL reading.

    Args:
        url: The address whose content is requested.

    Returns:
        ``None`` in this base implementation.
    """
    return None
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
|
||||
107
src/plugins/built_in/web_search_tool/engines/metaso_engine.py
Normal file
107
src/plugins/built_in/web_search_tool/engines/metaso_engine.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Metaso Search Engine (Chat Completions Mode)
|
||||
"""
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
import httpx
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MetasoClient:
    """Async client for the Metaso chat-completions endpoint.

    The endpoint is called in streaming mode; the streamed ``content`` deltas
    are concatenated and returned as a single search result.
    """

    def __init__(self, api_key: str):
        """Store the API key and prepare the request headers.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
        """
        self.api_key = api_key
        self.base_url = "https://metaso.cn/api/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _extract_content(data_str: str) -> str:
        """Return the text delta carried by one ``data:`` payload, or "" if it has none.

        Malformed JSON is logged and skipped. Payloads without a usable
        ``choices[0].delta.content`` string (empty ``choices`` list, null
        ``delta``, non-object payload) yield "" instead of raising, so a single
        odd chunk cannot abort the whole stream.
        """
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
            return ""

        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            return ""

        delta = choices[0].get("delta")
        content = delta.get("content") if isinstance(delta, dict) else None
        return content if isinstance(content, str) else ""

    async def search(self, query: str, **kwargs) -> List[dict[str, Any]]:
        """Perform a search using the Metaso Chat Completions API.

        Args:
            query: The user query, sent as a single ``user`` message.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            A one-element list holding the concatenated answer under
            ``snippet``, or an empty list when the stream carried no content
            or any error occurred (errors are logged, never raised).
        """
        payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]}
        search_url = f"{self.base_url}/chat/completions"
        chunks: list[str] = []

        async with httpx.AsyncClient(timeout=90.0) as client:
            try:
                async with client.stream("POST", search_url, headers=self.headers, json=payload) as response:
                    if response.is_error:
                        # A streamed body is not loaded automatically. Read it while
                        # the stream is still open so that `e.response.text` in the
                        # handler below works instead of raising ResponseNotRead.
                        await response.aread()
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:") :].strip()
                        if data_str == "[DONE]":
                            break
                        chunks.append(self._extract_content(data_str))

                full_response_content = "".join(chunks)
                if not full_response_content:
                    logger.warning("Metaso search returned an empty stream.")
                    return []

                return [
                    {
                        "title": query,
                        "url": "https://metaso.cn/",
                        "snippet": full_response_content,
                        "provider": "Metaso (Chat)",
                    }
                ]
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error occurred while searching with Metaso Chat: {e.response.text}")
                return []
            except Exception as e:
                logger.error(f"An error occurred while searching with Metaso Chat: {e}", exc_info=True)
                return []
|
||||
|
||||
|
||||
class MetasoSearchEngine(BaseSearchEngine):
    """Search engine backed by the Metaso chat-completions API."""

    def __init__(self):
        self._initialize_clients()

    def _initialize_clients(self):
        """Build the API-key manager from the ``web_search.metaso_api_keys`` setting."""
        metaso_api_keys = config_api.get_global_config("web_search.metaso_api_keys", None)
        # One MetasoClient is created per configured key via the factory lambda.
        self.api_manager = create_api_key_manager_from_config(
            metaso_api_keys, lambda key: MetasoClient(api_key=key), "Metaso"
        )

    def is_available(self) -> bool:
        """Report availability as determined by the API-key manager."""
        return self.api_manager.is_available()

    async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
        """Execute a Metaso search.

        Args:
            args: Tool arguments; must contain a non-blank string under ``"query"``.

        Returns:
            The client's result list, or an empty list when the engine is
            unavailable, the query is missing or blank, no client could be
            obtained, or the request failed (failures are logged, not raised).
        """
        if not self.is_available():
            return []

        # A missing or blank query returns [] like every other failure path,
        # instead of raising KeyError or sending an empty prompt to the API.
        query = args.get("query")
        if not isinstance(query, str) or not query.strip():
            logger.warning("Metaso search called without a usable 'query' argument.")
            return []

        try:
            metaso_client = self.api_manager.get_next_client()
            if not metaso_client:
                logger.error("Could not get Metaso client.")
                return []

            return await metaso_client.search(query)
        except Exception as e:
            logger.error(f"Metaso search failed: {e}", exc_info=True)
            return []
|
||||
@@ -22,6 +22,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
提供网络搜索和URL解析功能,支持多种搜索引擎:
|
||||
- Exa (需要API密钥)
|
||||
- Tavily (需要API密钥)
|
||||
- Metaso (需要API密钥)
|
||||
- DuckDuckGo (免费)
|
||||
- Bing (免费)
|
||||
"""
|
||||
@@ -43,6 +44,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
from .engines.exa_engine import ExaSearchEngine
|
||||
from .engines.searxng_engine import SearXNGSearchEngine
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.metaso_engine import MetasoSearchEngine
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
@@ -50,14 +52,16 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
ddg_engine = DDGSearchEngine()
|
||||
bing_engine = BingSearchEngine()
|
||||
searxng_engine = SearXNGSearchEngine()
|
||||
|
||||
# 报告每个引擎的状态
|
||||
metaso_engine = MetasoSearchEngine()
|
||||
|
||||
# 报告每个引擎的状态
|
||||
engines_status = {
|
||||
"Exa": exa_engine.is_available(),
|
||||
"Tavily": tavily_engine.is_available(),
|
||||
"DuckDuckGo": ddg_engine.is_available(),
|
||||
"Bing": bing_engine.is_available(),
|
||||
"SearXNG": searxng_engine.is_available(),
|
||||
"Metaso": metaso_engine.is_available(),
|
||||
}
|
||||
|
||||
available_engines = [name for name, available in engines_status.items() if available]
|
||||
|
||||
@@ -15,6 +15,7 @@ from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.exa_engine import ExaSearchEngine
|
||||
from ..engines.searxng_engine import SearXNGSearchEngine
|
||||
from ..engines.tavily_engine import TavilySearchEngine
|
||||
from ..engines.metaso_engine import MetasoSearchEngine
|
||||
from ..utils.formatters import deduplicate_results, format_search_results
|
||||
|
||||
logger = get_logger("web_search_tool")
|
||||
@@ -51,6 +52,7 @@ class WebSurfingTool(BaseTool):
|
||||
"ddg": DDGSearchEngine(),
|
||||
"bing": BingSearchEngine(),
|
||||
"searxng": SearXNGSearchEngine(),
|
||||
"metaso": MetasoSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "7.3.2"
|
||||
version = "7.3.3"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -326,6 +326,7 @@ info_extraction_workers = 3 # 实体提取同时执行线程数,非Pro模型
|
||||
qa_relation_search_top_k = 10 # 关系搜索TopK
|
||||
qa_relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系)
|
||||
qa_paragraph_search_top_k = 1000 # 段落搜索TopK(不能过小,可能影响搜索结果)
|
||||
qa_paragraph_threshold = 0.4 # 段落阈值(相似度高于此阈值的段落才会被认为是相关的)
|
||||
qa_paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)
|
||||
qa_ent_filter_top_k = 10 # 实体过滤TopK
|
||||
qa_ppr_damping = 0.8 # PPR阻尼系数
|
||||
@@ -473,11 +474,12 @@ enable_web_search_tool = true # 是否启用联网搜索tool
|
||||
enable_url_tool = true # 是否启用URL解析tool
|
||||
tavily_api_keys = ["None"]# Tavily API密钥列表,支持轮询机制
|
||||
exa_api_keys = ["None"]# EXA API密钥列表,支持轮询机制
|
||||
metaso_api_keys = ["None"]# Metaso API密钥列表,支持轮询机制
|
||||
searxng_instances = [] # SearXNG 实例 URL 列表
|
||||
searxng_api_keys = []# SearXNG 实例 API 密钥列表
|
||||
|
||||
# 搜索引擎配置
|
||||
enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg","bing"
|
||||
enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg", "bing", "searxng", "metaso"
|
||||
search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个)
|
||||
|
||||
[sleep_system]
|
||||
|
||||
@@ -107,7 +107,7 @@ class UILogHandler(logging.Handler):
|
||||
# if not success:
|
||||
# print(f"[UI日志适配器] 日志发送失败: {ui_level} - {formatted_msg[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# 静默失败,不影响主程序
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user