Compare commits
7 Commits
767aad407a
...
0050bfff09
| Author | SHA1 | Date | |
|---|---|---|---|
|
0050bfff09
|
|||
|
e4192bb47c
|
|||
|
c9479bd7f4
|
|||
|
6432f339b4
|
|||
|
|
43dbfb2a1e | ||
|
|
9f666b580e | ||
|
|
fbc37bbcaf |
32
.gitea/workflows/build.yaml
Normal file
32
.gitea/workflows/build.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
# Builds the project's Docker image and pushes it to the private registry
# whenever the `dev` or `gitea` branch receives a push.
name: Build and Push Docker Image

on:
  push:
    branches:
      - dev
      - gitea

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      # `with:` inputs are handed to the action verbatim and never pass through
      # a shell, so `$(date ...)` written directly in `build-args` would reach
      # the Dockerfile as literal text. Compute the timestamp in a `run` step
      # and expose it as a step output instead.
      - name: Compute build metadata
        id: build_meta
        run: echo "build_date=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> "$GITHUB_OUTPUT"

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Registry
        uses: docker/login-action@v3
        with:
          registry: docker.gardel.top
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}

      - name: Build and Push Docker Image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
          # NOTE(review): both trigger branches publish to the same `dev` tag,
          # so a push to `gitea` overwrites the image built from `dev` — confirm
          # this is intended.
          tags: docker.gardel.top/gardel/mofox:dev
          build-args: |
            BUILD_DATE=${{ steps.build_meta.outputs.build_date }}
            VCS_REF=${{ github.sha }}
|
||||
149
.github/workflows/docker-image.yml
vendored
149
.github/workflows/docker-image.yml
vendored
@@ -1,149 +0,0 @@
|
||||
# Builds the image natively on AMD64 and ARM64 runners, pushes each build by
# digest, then stitches both digests into one multi-arch manifest per tag.
name: Docker Build and Push

on:
  push:
    branches:
      - master
      - dev
    tags:
      - "v*.*.*"
      - "v*"
      - "*.*.*"
      - "*.*.*-*"
  workflow_dispatch: # allow the workflow to be triggered manually

# Workflow's jobs
jobs:
  build-amd64:
    name: Build AMD64 Image
    runs-on: ubuntu-24.04
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    steps:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      # Action inputs are not shell-evaluated, so the timestamp must be
      # computed in a `run` step and passed on as a step output.
      - name: Compute build metadata
        id: build_meta
        run: echo "build_date=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> "$GITHUB_OUTPUT"

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          buildkitd-flags: --debug

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox

      # Build and push AMD64 image by digest
      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          file: ./Dockerfile
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
          outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
          build-args: |
            BUILD_DATE=${{ steps.build_meta.outputs.build_date }}
            VCS_REF=${{ github.sha }}

  build-arm64:
    name: Build ARM64 Image
    runs-on: ubuntu-24.04-arm
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    steps:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      # Action inputs are not shell-evaluated, so the timestamp must be
      # computed in a `run` step and passed on as a step output.
      - name: Compute build metadata
        id: build_meta
        run: echo "build_date=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> "$GITHUB_OUTPUT"

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          buildkitd-flags: --debug

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox

      # Build and push ARM64 image by digest
      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/arm64/v8
          labels: ${{ steps.meta.outputs.labels }}
          file: ./Dockerfile
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
          outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
          build-args: |
            BUILD_DATE=${{ steps.build_meta.outputs.build_date }}
            VCS_REF=${{ github.sha }}

  create-manifest:
    name: Create Multi-Arch Manifest
    runs-on: ubuntu-24.04
    needs:
      - build-amd64
      - build-arm64
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
          # `latest` was previously gated on `refs/heads/main` only, but the
          # push trigger above lists `master`, so no branch push could ever
          # produce it. `main` is kept for manual runs from that branch.
          tags: |
            type=ref,event=branch
            type=ref,event=tag
            type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' }}
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=semver,pattern={{major}}
            type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}

      - name: Create and Push Manifest
        run: |
          # Create one multi-arch image per generated tag
          for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
            echo "Creating manifest for $tag"
            docker buildx imagetools create -t "$tag" \
              ${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
              ${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
          done
|
||||
@@ -4,6 +4,7 @@ import binascii
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import json_repair
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@@ -1023,6 +1024,15 @@ class EmojiManager:
|
||||
- 必须是表情包,非普通截图。
|
||||
- 图中文字不超过5个。
|
||||
请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。
|
||||
输出格式:
|
||||
```json
|
||||
{{
|
||||
"detailed_description": "",
|
||||
"keywords": [],
|
||||
"refined_sentence": "",
|
||||
"is_compliant": true
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
image_data_for_vlm, image_format_for_vlm = image_base64, image_format
|
||||
@@ -1042,16 +1052,14 @@ class EmojiManager:
|
||||
if not vlm_response_str:
|
||||
continue
|
||||
|
||||
match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL)
|
||||
if match:
|
||||
vlm_response_json = json.loads(match.group(0))
|
||||
description = vlm_response_json.get("detailed_description", "")
|
||||
emotions = vlm_response_json.get("keywords", [])
|
||||
refined_description = vlm_response_json.get("refined_sentence", "")
|
||||
is_compliant = vlm_response_json.get("is_compliant", False)
|
||||
if description and emotions and refined_description:
|
||||
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
||||
break
|
||||
vlm_response_json = self._parse_json_response(vlm_response_str)
|
||||
description = vlm_response_json.get("detailed_description", "")
|
||||
emotions = vlm_response_json.get("keywords", [])
|
||||
refined_description = vlm_response_json.get("refined_sentence", "")
|
||||
is_compliant = vlm_response_json.get("is_compliant", False)
|
||||
if description and emotions and refined_description:
|
||||
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
||||
break
|
||||
logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。")
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
|
||||
@@ -1122,7 +1130,7 @@ class EmojiManager:
|
||||
if emoji_base64 is None: # 再次检查读取
|
||||
logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}")
|
||||
return False
|
||||
|
||||
|
||||
# 等待描述生成完成
|
||||
description, emotions = await self.build_emoji_description(emoji_base64)
|
||||
|
||||
@@ -1135,7 +1143,7 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
new_emoji.description = description
|
||||
new_emoji.emotion = emotions
|
||||
except Exception as build_desc_error:
|
||||
@@ -1196,6 +1204,29 @@ class EmojiManager:
|
||||
logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
|
||||
return False
|
||||
|
||||
@classmethod
def _parse_json_response(cls, response: str) -> dict[str, Any] | None:
    """Parse the JSON object contained in an LLM response.

    The payload is taken from the first ```json fenced block when one is
    present; otherwise the whole stripped response is used. ``//`` and
    ``/* */`` comments are removed before the text is handed to
    ``json_repair.loads``.

    Args:
        response: Raw text returned by the model.

    Returns:
        The decoded object, or ``None`` when the response is empty, cannot
        be decoded, or decodes to something other than a JSON object.
    """
    if not isinstance(response, str) or not response.strip():
        return None

    # Prefer an explicit ```json fenced block, fall back to the raw text.
    fenced = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
    json_str = fenced.group(1) if fenced else response.strip()

    # Remove comments, but match double-quoted strings first and put them
    # back unchanged so that values such as "https://example.com" are not cut
    # off at the "//".
    json_str = re.sub(
        r'"(?:\\.|[^"\\])*"|//[^\n]*|/\*.*?\*/',
        lambda m: m.group(0) if m.group(0).startswith('"') else "",
        json_str,
        flags=re.DOTALL,
    )

    try:
        data = json_repair.loads(json_str)
    except (ValueError, TypeError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError subclass, so the original
        # failure mode is still covered.
        logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}")
        return None

    # NOTE(review): json_repair presumably returns a best-effort value (possibly
    # a str or list) instead of raising on malformed input — confirm. Enforce
    # the declared return type so callers can rely on ``.get``.
    if not isinstance(data, dict):
        logger.warning(f"JSON 解析结果不是对象: {type(data).__name__}, 响应: {response[:200]}")
        return None

    return data
|
||||
|
||||
|
||||
emoji_manager = None
|
||||
|
||||
|
||||
@@ -129,16 +129,6 @@ class ChatStream:
|
||||
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||
self.context.set_current_message(message)
|
||||
|
||||
# 调试日志
|
||||
logger.debug(
|
||||
f"消息上下文已设置 - message_id: {message.message_id}, "
|
||||
f"chat_id: {message.chat_id}, "
|
||||
f"is_mentioned: {message.is_mentioned}, "
|
||||
f"is_emoji: {message.is_emoji}, "
|
||||
f"is_picid: {message.is_picid}, "
|
||||
f"interest_value: {message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
import json
|
||||
|
||||
@@ -70,8 +70,6 @@ def init_prompt():
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要复读你前面发过的内容,意思相近也不行。
|
||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包),只输出一条回复就好。
|
||||
⛔ 绝对禁止输出任何艾特:不要输出@、@xxx等格式。你看到的聊天记录中的艾特是系统显示格式,你无法通过模仿来实现真正的艾特。想称呼某人直接写名字。
|
||||
|
||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
||||
|
||||
@@ -140,11 +138,15 @@ def init_prompt():
|
||||
{time_block}
|
||||
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
||||
⛔ 绝对禁止输出任何形式的艾特:不要输出@、@xxx等。你看到的聊天记录中的艾特格式是系统显示用的,你无法通过模仿它来实现真正的艾特功能,只会输出一串无意义的假文本。想称呼某人直接写名字即可。
|
||||
不要模仿任何系统消息的格式,你的回复应该是自然的对话内容,例如:
|
||||
- 当你想要打招呼时,直接输出“你好!”而不是“[回复<xxx>]: 用户你好!”
|
||||
- 当你想要提及某人时,直接叫对方名字,而不是“@xxx”
|
||||
|
||||
你只能输出文字,不能输出任何表情包、图片、文件等内容!如果用户要求你发送非文字内容,请输出"PASS",而不是[表情包:xxx]
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
||||
*你叫{bot_name},也有人叫你{bot_nickname},请你清楚你的身份,分清对方到底有没有叫你*
|
||||
|
||||
现在,你说:
|
||||
""",
|
||||
@@ -211,8 +213,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
||||
*{chat_scene}*
|
||||
|
||||
### 核心任务
|
||||
- 你需要对以上未读历史消息进行统一回应。这些消息可能来自不同的参与者,你需要理解整体对话动态,生成一段自然、连贯的回复。
|
||||
- 你的回复应该能够推动对话继续,可以回应其中一个或多个话题,也可以提出新的观点。
|
||||
- 你需要对以上未读历史消息用一句简单的话统一回应。这些消息可能来自不同的参与者,你需要理解整体对话动态,生成一段自然、连贯的回复。
|
||||
|
||||
## 规则
|
||||
{safety_guidelines_block}
|
||||
@@ -224,11 +225,15 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
||||
{time_block}
|
||||
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
||||
⛔ 绝对禁止输出任何形式的艾特:不要输出@、@xxx等。你看到的聊天记录中的艾特格式是系统显示用的,你无法通过模仿它来实现真正的艾特功能,只会输出一串无意义的假文本。想称呼某人直接写名字即可。
|
||||
不要模仿任何系统消息的格式,你的回复应该是自然的对话内容,例如:
|
||||
- 当你想要打招呼时,直接输出“你好!”而不是“[回复<xxx>]: 用户你好!”
|
||||
- 当你想要提及某人时,直接叫对方名字,而不是“@xxx”
|
||||
|
||||
你只能输出文字,不能输出任何表情包、图片、文件等内容!如果用户要求你发送非文字内容,请输出"PASS",而不是[表情包:xxx]
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
||||
*你叫{bot_name},也有人叫你{bot_nickname},请你清楚你的身份,分清对方到底有没有叫你*
|
||||
|
||||
现在,你说:
|
||||
""",
|
||||
|
||||
@@ -405,6 +405,12 @@ def recover_quoted_content(sentences: list[str], placeholder_map: dict[str, str]
|
||||
|
||||
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
||||
assert global_config is not None
|
||||
|
||||
normalized_text = text.strip() if isinstance(text, str) else ""
|
||||
if normalized_text.upper() == "PASS":
|
||||
logger.info("[回复内容过滤器] 检测到PASS信号,跳过发送。")
|
||||
return []
|
||||
|
||||
if not global_config.response_post_process.enable_response_post_process:
|
||||
return [text]
|
||||
|
||||
|
||||
@@ -616,20 +616,20 @@ class StreamContext(BaseDataModel):
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
logger.debug(f"[check_types] 检查消息是否支持类型: {types}")
|
||||
# logger.debug(f"[check_types] 检查消息是否支持类型: {types}") # 简化日志,避免冗余
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
import orjson
|
||||
try:
|
||||
logger.debug(f"[check_types] additional_config 类型: {type(self.current_message.additional_config)}")
|
||||
# logger.debug(f"[check_types] additional_config 类型: {type(self.current_message.additional_config)}") # 简化日志
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
logger.debug(f"[check_types] 解析后的 config 键: {config.keys() if isinstance(config, dict) else 'N/A'}")
|
||||
# logger.debug(f"[check_types] 解析后的 config 键: {config.keys() if isinstance(config, dict) else 'N/A'}") # 简化日志
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
logger.debug(f"[check_types] 找到 format_info: {format_info}")
|
||||
# logger.debug(f"[check_types] 找到 format_info: {format_info}") # 简化日志
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
@@ -646,9 +646,9 @@ class StreamContext(BaseDataModel):
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
# logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}") # 简化日志
|
||||
return False
|
||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
||||
# logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") # 简化日志
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
@@ -665,9 +665,9 @@ class StreamContext(BaseDataModel):
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
# logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") # 简化日志
|
||||
return False
|
||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
||||
# logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") # 简化日志
|
||||
return True
|
||||
else:
|
||||
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
||||
@@ -679,16 +679,16 @@ class StreamContext(BaseDataModel):
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
logger.debug("[check_types] 使用备用方案:默认支持类型检查")
|
||||
# logger.debug("[check_types] 使用备用方案:默认支持类型检查") # 简化日志
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"[check_types] 使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# logger.debug(f"[check_types] 使用默认类型检查,消息可能不支持类型 '{requested_type}'") # 简化日志
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
logger.warning(f"[check_types] ❌ 备用方案拒绝类型 '{requested_type}'")
|
||||
return False
|
||||
logger.debug("[check_types] ✅ 备用方案通过所有类型检查")
|
||||
# logger.debug("[check_types] ✅ 备用方案通过所有类型检查") # 简化日志
|
||||
return True
|
||||
|
||||
# ==================== 消息缓存系统方法 ====================
|
||||
@@ -736,7 +736,7 @@ class StreamContext(BaseDataModel):
|
||||
list[DatabaseMessages]: 刷新的消息列表
|
||||
"""
|
||||
if not self.message_cache:
|
||||
logger.debug(f"StreamContext {self.stream_id} 缓存为空,无需刷新")
|
||||
# 缓存为空是正常情况,不需要记录日志
|
||||
return []
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
"""
|
||||
透明连接复用管理器
|
||||
在不改变原有API的情况下,实现数据库连接的智能复用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("connection_pool_manager")
|
||||
|
||||
|
||||
class ConnectionInfo:
    """Bookkeeping wrapper around one pooled session.

    Records when the session was created and last checked out, whether it is
    currently checked out, and how many outstanding checkouts reference it.
    """

    def __init__(self, session: AsyncSession, created_at: float):
        """Wrap ``session``; ``created_at`` is a ``time.time()`` style timestamp."""
        self.session = session
        self.created_at = created_at
        self.last_used = created_at
        self.in_use = False
        self.ref_count = 0

    def mark_used(self):
        """Record a checkout: refresh ``last_used`` and bump the reference count."""
        self.last_used = time.time()
        self.in_use = True
        self.ref_count += 1

    def mark_released(self):
        """Record a release.

        The reference count never drops below zero, and the session only
        becomes idle once the last outstanding reference is gone (previously
        ``in_use`` was cleared even while ``ref_count`` was still positive).
        """
        self.ref_count = max(0, self.ref_count - 1)
        self.in_use = self.ref_count > 0

    def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
        """Return True when the session outlived ``max_lifetime`` seconds, or
        is idle and has not been used for more than ``max_idle`` seconds."""
        now = time.time()

        # Total lifetime applies regardless of whether the session is in use.
        if now - self.created_at > max_lifetime:
            return True

        # Idle timeout only applies to sessions that are not checked out.
        return not self.in_use and now - self.last_used > max_idle

    async def close(self):
        """Close the wrapped session.

        Cancellation is re-raised; any other error is logged and swallowed.
        """
        try:
            # shield() keeps the close running even if the awaiting task is
            # cancelled; cast() only tells the type checker what shield returns.
            from typing import cast

            await cast(asyncio.Future, asyncio.shield(self.session.close()))
            logger.debug("连接已关闭")
        except asyncio.CancelledError:
            # Expected, e.g. when a streaming chat is interrupted.
            logger.debug("关闭连接时任务被取消")
            # Re-raise so the task's cancelled state stays correct.
            raise
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")
|
||||
|
||||
|
||||
class ConnectionPoolManager:
    """Session pool that hands out idle sessions again instead of creating new ones.

    Sessions are tracked as ``ConnectionInfo`` entries. A background task
    periodically closes entries that exceeded their lifetime or idle timeout.
    """

    def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
        """Configure the pool.

        Args:
            max_pool_size: Maximum number of sessions kept in the pool.
            max_lifetime: Seconds after which a session is considered expired.
            max_idle: Seconds an unused session may stay in the pool.
        """
        self.max_pool_size = max_pool_size
        self.max_lifetime = max_lifetime
        self.max_idle = max_idle

        # Pooled sessions; every access goes through self._lock.
        self._connections: set[ConnectionInfo] = set()
        self._lock = asyncio.Lock()

        # Usage counters reported by get_stats().
        self._stats = {
            "total_created": 0,
            "total_reused": 0,
            "total_expired": 0,
            "active_connections": 0,
            "pool_hits": 0,
            "pool_misses": 0,
        }

        # Background cleanup task and its run flag.
        self._cleanup_task: asyncio.Task | None = None
        self._should_cleanup = False

        logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")

    async def start(self):
        """Start the background cleanup task if it is not already running."""
        if self._cleanup_task is None:
            self._should_cleanup = True
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("连接池管理器已启动")

    async def stop(self):
        """Cancel the cleanup task and close every pooled session."""
        self._should_cleanup = False

        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        await self._close_all_connections()
        logger.info("连接池管理器已停止")

    @asynccontextmanager
    async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
        """Yield a session, reusing an idle pooled one when possible.

        A new session is created when nothing can be reused. It joins the pool
        only while the pool is below ``max_pool_size``; otherwise it is an
        overflow session that is closed on release, so the pool never grows
        past its limit. If the caller's block raises, the session is rolled
        back and the exception propagates.
        """
        connection_info = None
        # False until the session is known to be tracked by the pool; sessions
        # that never made it into the pool are closed when released.
        pooled = False

        try:
            connection_info = await self._get_reusable_connection(session_factory)

            if connection_info:
                pooled = True
                connection_info.mark_used()
                self._stats["total_reused"] += 1
                self._stats["pool_hits"] += 1
                logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
            else:
                session = session_factory()
                connection_info = ConnectionInfo(session, time.time())

                async with self._lock:
                    # Checked under the lock so concurrent callers cannot push
                    # the pool past its configured size.
                    if len(self._connections) < self.max_pool_size:
                        self._connections.add(connection_info)
                        pooled = True

                connection_info.mark_used()
                self._stats["total_created"] += 1
                self._stats["pool_misses"] += 1
                logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")

            yield connection_info.session

        except Exception:
            # Roll the session back so it is not left mid-transaction.
            if connection_info and connection_info.session:
                try:
                    await connection_info.session.rollback()
                except Exception as rollback_error:
                    logger.warning(f"回滚连接时出错: {rollback_error}")
            raise
        finally:
            if connection_info:
                connection_info.mark_released()
                if not pooled:
                    # Overflow session: nothing else tracks it, so close it now.
                    await connection_info.close()

    async def _get_reusable_connection(
        self, session_factory: async_sessionmaker[AsyncSession]
    ) -> ConnectionInfo | None:
        """Return an idle, still-working pooled session, or None if there is none.

        ``session_factory`` is accepted for interface compatibility and is not
        used here.
        """
        # Imported here to obtain the dialect-specific ping query.
        from src.common.database.core.dialect_adapter import DialectAdapter

        ping_query = DialectAdapter.get_ping_query()

        async with self._lock:
            await self._cleanup_expired_connections_locked()

            for connection_info in list(self._connections):
                if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
                    # Ping the session; drop it from the pool if the ping fails.
                    try:
                        await connection_info.session.execute(text(ping_query))
                        return connection_info
                    except Exception as e:
                        logger.debug(f"连接验证失败,将移除: {e}")
                        await connection_info.close()
                        self._connections.remove(connection_info)
                        self._stats["total_expired"] += 1

            # Nothing reusable; note when the pool is already at capacity.
            if len(self._connections) >= self.max_pool_size:
                logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
                return None

            return None

    async def _cleanup_expired_connections_locked(self):
        """Close and drop expired idle sessions. Must be called with the lock held."""
        expired_connections = [
            connection_info
            for connection_info in list(self._connections)
            if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
        ]

        for connection_info in expired_connections:
            await connection_info.close()
            self._connections.remove(connection_info)
            self._stats["total_expired"] += 1

        if expired_connections:
            logger.debug(f"清理了 {len(expired_connections)} 个过期连接")

    async def _cleanup_loop(self):
        """Background loop: purge expired sessions every 30 seconds."""
        while self._should_cleanup:
            try:
                await asyncio.sleep(30.0)

                async with self._lock:
                    await self._cleanup_expired_connections_locked()

                self._stats["active_connections"] = len(self._connections)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"连接池清理循环出错: {e}")
                # Back off briefly before the next attempt.
                await asyncio.sleep(10.0)

    async def _close_all_connections(self):
        """Close every pooled session and empty the pool."""
        async with self._lock:
            for connection_info in list(self._connections):
                await connection_info.close()

            self._connections.clear()
            logger.info("所有连接已关闭")

    def get_stats(self) -> dict[str, Any]:
        """Return the counters plus the current pool size, the size limit and
        the hit ratio as a percentage."""
        return {
            **self._stats,
            "active_connections": len(self._connections),
            "max_pool_size": self.max_pool_size,
            "pool_efficiency": (
                self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
            )
            * 100,
        }
|
||||
|
||||
|
||||
# 全局连接池管理器实例
|
||||
_connection_pool_manager: ConnectionPoolManager | None = None
|
||||
|
||||
|
||||
def get_connection_pool_manager() -> ConnectionPoolManager:
    """Return the process-wide pool manager, creating it on first use."""
    global _connection_pool_manager
    manager = _connection_pool_manager
    if manager is None:
        manager = _connection_pool_manager = ConnectionPoolManager()
    return manager
|
||||
|
||||
|
||||
async def start_connection_pool():
    """Start the global pool manager, creating it first if necessary."""
    await get_connection_pool_manager().start()
|
||||
|
||||
|
||||
async def stop_connection_pool():
    """Stop the global pool manager, if one exists, and forget the instance."""
    global _connection_pool_manager
    manager = _connection_pool_manager
    if not manager:
        return
    await manager.stop()
    _connection_pool_manager = None
|
||||
@@ -87,7 +87,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话上下文管理器
|
||||
|
||||
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
||||
这是数据库操作的主要入口点,直接从会话工厂获取独立会话。
|
||||
|
||||
支持的数据库:
|
||||
- SQLite: 自动设置 busy_timeout 和外键约束
|
||||
@@ -101,20 +101,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
Yields:
|
||||
AsyncSession: SQLAlchemy异步会话对象
|
||||
"""
|
||||
# 延迟导入避免循环依赖
|
||||
from ..optimization.connection_pool import get_connection_pool_manager
|
||||
|
||||
session_factory = await get_session_factory()
|
||||
pool_manager = get_connection_pool_manager()
|
||||
|
||||
# 使用连接池管理器(透明复用连接)
|
||||
async with pool_manager.get_session(session_factory) as session:
|
||||
# 获取数据库类型并应用特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
async with get_db_session_direct() as session:
|
||||
yield session
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""数据库优化层
|
||||
|
||||
职责:
|
||||
- 连接池管理
|
||||
- 批量调度
|
||||
- 多级缓存
|
||||
- 数据预加载
|
||||
@@ -23,12 +22,6 @@ from .cache_manager import (
|
||||
close_cache,
|
||||
get_cache,
|
||||
)
|
||||
from .connection_pool import (
|
||||
ConnectionPoolManager,
|
||||
get_connection_pool_manager,
|
||||
start_connection_pool,
|
||||
stop_connection_pool,
|
||||
)
|
||||
from .preloader import (
|
||||
AccessPattern,
|
||||
CommonDataPreloader,
|
||||
@@ -46,8 +39,6 @@ __all__ = [
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"CommonDataPreloader",
|
||||
# Connection Pool
|
||||
"ConnectionPoolManager",
|
||||
# Preloader
|
||||
"DataPreloader",
|
||||
"LRUCache",
|
||||
@@ -59,8 +50,5 @@ __all__ = [
|
||||
"close_preloader",
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_connection_pool_manager",
|
||||
"get_preloader",
|
||||
"start_connection_pool",
|
||||
"stop_connection_pool",
|
||||
]
|
||||
|
||||
@@ -304,13 +304,11 @@ class MultiLevelCache:
|
||||
# 1. 尝试从L1获取
|
||||
value = await self.l1_cache.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"L1缓存命中: {key}")
|
||||
return value
|
||||
|
||||
# 2. 尝试从L2获取
|
||||
value = await self.l2_cache.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"L2缓存命中: {key}")
|
||||
# 提升到L1
|
||||
await self.l1_cache.set(key, value)
|
||||
return value
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
"""
|
||||
透明连接复用管理器
|
||||
|
||||
在不改变原有API的情况下,实现数据库连接的智能复用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.connection_pool")
|
||||
|
||||
|
||||
class ConnectionInfo:
    """Bookkeeping wrapper around one pooled session.

    Records when the session was created and last checked out, whether it is
    currently checked out, and how many outstanding checkouts reference it.
    """

    def __init__(self, session: AsyncSession, created_at: float):
        """Wrap ``session``; ``created_at`` is a ``time.time()`` style timestamp."""
        self.session = session
        self.created_at = created_at
        self.last_used = created_at
        self.in_use = False
        self.ref_count = 0

    def mark_used(self):
        """Record a checkout: refresh ``last_used`` and bump the reference count."""
        self.last_used = time.time()
        self.in_use = True
        self.ref_count += 1

    def mark_released(self):
        """Record a release.

        The reference count never drops below zero, and the session only
        becomes idle once the last outstanding reference is gone (previously
        ``in_use`` was cleared even while ``ref_count`` was still positive).
        """
        self.ref_count = max(0, self.ref_count - 1)
        self.in_use = self.ref_count > 0

    def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
        """Return True when the session outlived ``max_lifetime`` seconds, or
        is idle and has not been used for more than ``max_idle`` seconds."""
        now = time.time()

        # Total lifetime applies regardless of whether the session is in use.
        if now - self.created_at > max_lifetime:
            return True

        # Idle timeout only applies to sessions that are not checked out.
        return not self.in_use and now - self.last_used > max_idle

    async def close(self):
        """Close the wrapped session.

        Cancellation is re-raised; any other error is logged and swallowed.
        """
        try:
            # shield() keeps the close running even if the awaiting task is
            # cancelled; cast() only tells the type checker what shield returns.
            from typing import cast

            await cast(asyncio.Future, asyncio.shield(self.session.close()))
            logger.debug("连接已关闭")
        except asyncio.CancelledError:
            # Expected, e.g. when a streaming chat is interrupted.
            logger.debug("关闭连接时任务被取消")
            raise
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")
|
||||
|
||||
|
||||
class ConnectionPoolManager:
|
||||
"""透明的连接池管理器"""
|
||||
|
||||
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
|
||||
self.max_pool_size = max_pool_size
|
||||
self.max_lifetime = max_lifetime
|
||||
self.max_idle = max_idle
|
||||
|
||||
# 连接池
|
||||
self._connections: set[ConnectionInfo] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
"total_created": 0,
|
||||
"total_reused": 0,
|
||||
"total_expired": 0,
|
||||
"active_connections": 0,
|
||||
"pool_hits": 0,
|
||||
"pool_misses": 0,
|
||||
}
|
||||
|
||||
# 后台清理任务
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
self._should_cleanup = False
|
||||
|
||||
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
||||
|
||||
async def start(self):
|
||||
"""启动连接池管理器"""
|
||||
if self._cleanup_task is None:
|
||||
self._should_cleanup = True
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("✅ 连接池管理器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""停止连接池管理器"""
|
||||
self._should_cleanup = False
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
# 关闭所有连接
|
||||
await self._close_all_connections()
|
||||
logger.info("✅ 连接池管理器已停止")
|
||||
|
||||
@asynccontextmanager
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
    """Yield a database session, reusing a pooled connection when one is available.

    Transaction handling:
    - On normal exit the session is committed (only if ``session.is_active``).
    - On an exception *or task cancellation* the session is rolled back
      (only if ``session.is_active``) and the error is re-raised.
    - In every case the connection is marked as released afterwards.

    Args:
        session_factory: Factory called with no arguments to create a new
            session when no reusable connection is returned by the pool.

    Yields:
        The session held by the selected or newly created ``ConnectionInfo``.

    Raises:
        Exception: Whatever the caller's body or the final commit raises.
    """
    connection_info = None

    try:
        # Try to obtain an existing idle connection first.
        connection_info = await self._get_reusable_connection(session_factory)

        if connection_info:
            # Pool hit: reuse the existing connection.
            connection_info.mark_used()
            self._stats["total_reused"] += 1
            self._stats["pool_hits"] += 1
            logger.debug(f"♻️ 复用连接 (池大小: {len(self._connections)})")
        else:
            # Pool miss: create a new connection and register it in the pool.
            session = session_factory()
            connection_info = ConnectionInfo(session, time.time())

            async with self._lock:
                self._connections.add(connection_info)

            connection_info.mark_used()
            self._stats["total_created"] += 1
            self._stats["pool_misses"] += 1
            logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")

        yield connection_info.session

        # Normal exit: commit, because the session does not auto-commit.
        if connection_info and connection_info.session:
            try:
                # Skip the commit when the transaction is no longer active
                # (for example after the caller already rolled back).
                if connection_info.session.is_active:
                    await connection_info.session.commit()
            except Exception as commit_error:
                logger.warning(f"提交事务时出错: {commit_error}")
                try:
                    await connection_info.session.rollback()
                except Exception:
                    pass  # Best effort: the transaction may already be over.
                raise

    except (Exception, asyncio.CancelledError):
        # asyncio.CancelledError derives from BaseException, so a plain
        # ``except Exception`` would skip the rollback when the caller's task
        # is cancelled and hand a session with an open transaction back to
        # the pool. Catch it explicitly; GeneratorExit is deliberately left
        # alone because awaiting while handling it is not allowed.
        if connection_info and connection_info.session:
            try:
                # Roll back only when a transaction is still active.
                if connection_info.session.is_active:
                    await connection_info.session.rollback()
            except Exception as rollback_error:
                logger.warning(f"回滚连接时出错: {rollback_error}")
        raise
    finally:
        # Hand the connection back to the pool.
        if connection_info:
            connection_info.mark_released()
|
||||
|
||||
async def _get_reusable_connection(
    self, session_factory: async_sessionmaker[AsyncSession]
) -> ConnectionInfo | None:
    """Return an idle, validated pooled connection, or ``None`` if there is none.

    Runs entirely under ``self._lock``. Expired connections are purged first;
    each remaining idle candidate is probed with ``SELECT 1`` and dropped from
    the pool when the probe fails.
    """
    async with self._lock:
        # Purge expired connections before searching.
        await self._cleanup_expired_connections_locked()

        for connection_info in list(self._connections):
            if connection_info.in_use:
                continue
            if connection_info.is_expired(self.max_lifetime, self.max_idle):
                continue

            # Probe the connection with a trivial query before handing it out.
            try:
                await connection_info.session.execute(text("SELECT 1"))
            except Exception as e:
                logger.debug(f"连接验证失败,将移除: {e}")
                await connection_info.close()
                self._connections.remove(connection_info)
                self._stats["total_expired"] += 1
            else:
                return connection_info

        # Nothing reusable; only warn when the pool is already at capacity.
        if len(self._connections) >= self.max_pool_size:
            logger.warning(f"⚠️ 连接池已满 ({len(self._connections)}/{self.max_pool_size})")

        return None
|
||||
|
||||
async def _cleanup_expired_connections_locked(self):
    """Close and remove expired connections that are not in use.

    The caller must already hold ``self._lock``.
    """
    # Collect first, mutate afterwards, so the set is never changed mid-scan.
    expired_connections = []
    for candidate in list(self._connections):
        if candidate.is_expired(self.max_lifetime, self.max_idle) and not candidate.in_use:
            expired_connections.append(candidate)

    for stale in expired_connections:
        await stale.close()
        self._connections.remove(stale)
        self._stats["total_expired"] += 1

    if expired_connections:
        logger.debug(f"🧹 清理了 {len(expired_connections)} 个过期连接")
|
||||
|
||||
async def _cleanup_loop(self):
    """Background loop that purges expired connections every 30 seconds.

    Exits when ``self._should_cleanup`` becomes false or the task is cancelled.
    Any other error is logged and followed by a 10 second pause.
    """
    while True:
        if not self._should_cleanup:
            break

        try:
            await asyncio.sleep(30.0)  # One cleanup pass every 30 seconds.

            async with self._lock:
                await self._cleanup_expired_connections_locked()

                # Refresh the cached pool size in the statistics.
                self._stats["active_connections"] = len(self._connections)

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"连接池清理循环出错: {e}")
            await asyncio.sleep(10.0)
|
||||
|
||||
async def _close_all_connections(self):
    """Close every pooled connection and empty the pool, under ``self._lock``."""
    async with self._lock:
        # Iterate over a snapshot so the set is not touched during the awaits.
        snapshot = tuple(self._connections)
        for entry in snapshot:
            await entry.close()

        self._connections.clear()
        logger.info("所有连接已关闭")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
    """Return a snapshot of the pool statistics.

    The result is a copy of ``self._stats`` with ``active_connections`` set to
    the current pool size, plus ``max_pool_size`` and ``pool_efficiency``
    (hit rate formatted as a percentage string with two decimals).
    """
    hits = self._stats["pool_hits"]
    total_requests = hits + self._stats["pool_misses"]

    # Hit rate in percent; stays 0 when no request has been served yet.
    if total_requests > 0:
        pool_efficiency = (hits / max(1, total_requests)) * 100
    else:
        pool_efficiency = 0

    snapshot = dict(self._stats)
    snapshot["active_connections"] = len(self._connections)
    snapshot["max_pool_size"] = self.max_pool_size
    snapshot["pool_efficiency"] = f"{pool_efficiency:.2f}%"
    return snapshot
|
||||
|
||||
|
||||
# 全局连接池管理器实例
|
||||
_connection_pool_manager: ConnectionPoolManager | None = None
|
||||
|
||||
|
||||
def get_connection_pool_manager() -> ConnectionPoolManager:
    """Return the module-wide pool manager, creating it on first use."""
    global _connection_pool_manager
    if _connection_pool_manager is not None:
        return _connection_pool_manager

    # First call: build the singleton with default settings.
    _connection_pool_manager = ConnectionPoolManager()
    return _connection_pool_manager
|
||||
|
||||
|
||||
async def start_connection_pool():
    """Start the module-wide connection pool manager (created on demand)."""
    await get_connection_pool_manager().start()
|
||||
|
||||
|
||||
async def stop_connection_pool():
    """Stop the module-wide pool manager, if one exists, and drop the reference."""
    global _connection_pool_manager
    if not _connection_pool_manager:
        return

    await _connection_pool_manager.stop()
    # Reset so a later get_connection_pool_manager() builds a fresh instance.
    _connection_pool_manager = None
|
||||
@@ -923,6 +923,41 @@ class KokoroFlowChatterProactiveConfig(ValidatedConfigBase):
|
||||
)
|
||||
|
||||
|
||||
class KokoroFlowChatterWaitingConfig(ValidatedConfigBase):
    """Waiting-strategy settings for Kokoro Flow Chatter.

    All durations are in seconds. Each field carries its own bounds
    through ``ge`` / ``le``.
    """

    # Fallback wait used when the LLM supplies no wait time (0-3600).
    default_max_wait_seconds: int = Field(
        description="默认最大等待秒数(当LLM未给出等待时间时使用)", default=300, ge=0, le=3600
    )
    # Lower bound for a wait, to avoid disturbing the user too often (0-1800).
    min_wait_seconds: int = Field(
        description="允许的最小等待秒数,防止等待时间过短导致频繁打扰", default=30, ge=0, le=1800
    )
    # Upper bound for a wait (60-7200).
    max_wait_seconds: int = Field(
        description="允许的最大等待秒数,避免等待时间过长", default=1800, ge=60, le=7200
    )
    # Scale factor applied to the LLM-provided wait time (0.0-10.0).
    wait_duration_multiplier: float = Field(
        description="等待时长倍率,用于整体放大或缩短LLM给出的等待时间", default=1.0, ge=0.0, le=10.0
    )
    # Consecutive wait timeouts allowed before waiting stops; 0 means unlimited (0-10).
    max_consecutive_timeouts: int = Field(
        description="允许的连续等待超时次数上限,达到后不再等待用户回复 (0 表示不限制)", default=3, ge=0, le=10
    )
|
||||
|
||||
|
||||
class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
"""
|
||||
Kokoro Flow Chatter 配置类 - 私聊专用心流对话系统
|
||||
@@ -947,6 +982,11 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
description="是否在等待期间启用心理活动更新"
|
||||
)
|
||||
|
||||
waiting: KokoroFlowChatterWaitingConfig = Field(
|
||||
default_factory=KokoroFlowChatterWaitingConfig,
|
||||
description="等待策略配置(默认等待时间、倍率等)",
|
||||
)
|
||||
|
||||
# --- 私聊专属主动思考配置 ---
|
||||
proactive_thinking: KokoroFlowChatterProactiveConfig = Field(
|
||||
default_factory=KokoroFlowChatterProactiveConfig,
|
||||
|
||||
@@ -233,6 +233,7 @@ class Memory:
|
||||
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
|
||||
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime | None = None # 最近一次结构或元数据更新
|
||||
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
|
||||
access_count: int = 0 # 访问次数
|
||||
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
||||
@@ -245,6 +246,8 @@ class Memory:
|
||||
# 确保重要性和激活度在有效范围内
|
||||
self.importance = max(0.0, min(1.0, self.importance))
|
||||
self.activation = max(0.0, min(1.0, self.activation))
|
||||
if not self.updated_at:
|
||||
self.updated_at = self.created_at
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典(用于序列化)"""
|
||||
@@ -258,6 +261,7 @@ class Memory:
|
||||
"activation": self.activation,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"last_accessed": self.last_accessed.isoformat(),
|
||||
"access_count": self.access_count,
|
||||
"decay_factor": self.decay_factor,
|
||||
@@ -278,6 +282,13 @@ class Memory:
|
||||
# 备选:使用直接的 activation 字段
|
||||
activation_level = data.get("activation", 0.0)
|
||||
|
||||
updated_at_raw = data.get("updated_at")
|
||||
if updated_at_raw:
|
||||
updated_at = datetime.fromisoformat(updated_at_raw)
|
||||
else:
|
||||
# 旧数据没有 updated_at,退化为最后访问时间或创建时间
|
||||
updated_at = datetime.fromisoformat(data.get("last_accessed", data["created_at"]))
|
||||
|
||||
return cls(
|
||||
id=data["id"],
|
||||
subject_id=data["subject_id"],
|
||||
@@ -288,6 +299,7 @@ class Memory:
|
||||
activation=activation_level, # 使用统一的激活度值
|
||||
status=MemoryStatus(data.get("status", "staged")),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=updated_at,
|
||||
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
|
||||
access_count=data.get("access_count", 0),
|
||||
decay_factor=data.get("decay_factor", 1.0),
|
||||
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -187,8 +187,8 @@ class ShortTermMemoryManager:
|
||||
"importance": 0.7,
|
||||
"attributes": {{
|
||||
"time": "时间信息",
|
||||
"attribute1": "其他属性1"
|
||||
"attribute2": "其他属性2"
|
||||
"attribute1": "其他属性1",
|
||||
"attribute2": "其他属性2",
|
||||
...
|
||||
}}
|
||||
}}
|
||||
@@ -327,7 +327,7 @@ class ShortTermMemoryManager:
|
||||
# 创建决策对象
|
||||
# 将 LLM 返回的大写操作名转换为小写(适配枚举定义)
|
||||
operation_str = data.get("operation", "CREATE_NEW").lower()
|
||||
|
||||
|
||||
decision = ShortTermDecision(
|
||||
operation=ShortTermOperation(operation_str),
|
||||
target_memory_id=data.get("target_memory_id"),
|
||||
@@ -531,7 +531,7 @@ class ShortTermMemoryManager:
|
||||
json_str = re.sub(r"//.*", "", json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
|
||||
data = json.loads(json_str)
|
||||
data = json_repair.loads(json_str)
|
||||
return data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -597,35 +597,35 @@ class ShortTermMemoryManager:
|
||||
# 1. 正常筛选:重要性达标的记忆
|
||||
candidates = [mem for mem in self.memories if mem.importance >= self.transfer_importance_threshold]
|
||||
candidate_ids = {mem.id for mem in candidates}
|
||||
|
||||
|
||||
# 2. 检查低重要性记忆是否积压
|
||||
# 剩余的都是低重要性记忆
|
||||
low_importance_memories = [mem for mem in self.memories if mem.id not in candidate_ids]
|
||||
|
||||
|
||||
# 如果低重要性记忆数量超过了上限(说明积压严重)
|
||||
# 我们需要清理掉一部分,而不是转移它们
|
||||
if len(low_importance_memories) > self.max_memories:
|
||||
# 目标保留数量(降至上限的 90%)
|
||||
target_keep_count = int(self.max_memories * 0.9)
|
||||
num_to_remove = len(low_importance_memories) - target_keep_count
|
||||
|
||||
|
||||
if num_to_remove > 0:
|
||||
# 按创建时间排序,删除最早的
|
||||
low_importance_memories.sort(key=lambda x: x.created_at)
|
||||
to_remove = low_importance_memories[:num_to_remove]
|
||||
|
||||
|
||||
for mem in to_remove:
|
||||
if mem in self.memories:
|
||||
self.memories.remove(mem)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"短期记忆清理: 移除了 {len(to_remove)} 条低重要性记忆 "
|
||||
f"(保留 {len(self.memories)} 条)"
|
||||
)
|
||||
|
||||
|
||||
# 触发保存
|
||||
asyncio.create_task(self._save_to_disk())
|
||||
|
||||
|
||||
return candidates
|
||||
|
||||
async def clear_transferred_memories(self, memory_ids: list[str]) -> None:
|
||||
|
||||
@@ -26,7 +26,7 @@ from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from .config import KFCMode, get_config
|
||||
from .config import KFCMode, apply_wait_duration_rules, get_config
|
||||
from .models import SessionStatus
|
||||
from .session import get_session_manager
|
||||
|
||||
@@ -179,6 +179,30 @@ class KokoroFlowChatter(BaseChatter):
|
||||
)
|
||||
|
||||
# 10. 执行动作
|
||||
raw_wait = plan_response.max_wait_seconds
|
||||
adjusted_wait = apply_wait_duration_rules(
|
||||
raw_wait,
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
timeout_limit = max(0, self._config.waiting.max_consecutive_timeouts)
|
||||
if (
|
||||
timeout_limit
|
||||
and session.consecutive_timeout_count >= timeout_limit
|
||||
and raw_wait > 0
|
||||
and adjusted_wait == 0
|
||||
):
|
||||
logger.info(
|
||||
"[KFC] 连续等待 %s 次未收到回复,暂停继续等待",
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
elif adjusted_wait != raw_wait:
|
||||
logger.debug(
|
||||
"[KFC] 调整等待时长: raw=%ss adjusted=%ss",
|
||||
raw_wait,
|
||||
adjusted_wait,
|
||||
)
|
||||
plan_response.max_wait_seconds = adjusted_wait
|
||||
|
||||
exec_results = []
|
||||
has_reply = False
|
||||
|
||||
|
||||
@@ -48,6 +48,12 @@ class WaitingDefaults:
|
||||
# 最大等待时间
|
||||
max_wait_seconds: int = 1800
|
||||
|
||||
# 等待时长倍率(>1 放大等待时间,<1 缩短)
|
||||
wait_duration_multiplier: float = 1.0
|
||||
|
||||
# 连续等待超时上限(达到后不再继续等待,0 表示不限制)
|
||||
max_consecutive_timeouts: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProactiveConfig:
|
||||
@@ -202,6 +208,8 @@ def load_config() -> KokoroFlowChatterConfig:
|
||||
default_max_wait_seconds=getattr(wait_cfg, 'default_max_wait_seconds', 300),
|
||||
min_wait_seconds=getattr(wait_cfg, 'min_wait_seconds', 30),
|
||||
max_wait_seconds=getattr(wait_cfg, 'max_wait_seconds', 1800),
|
||||
wait_duration_multiplier=getattr(wait_cfg, 'wait_duration_multiplier', 1.0),
|
||||
max_consecutive_timeouts=getattr(wait_cfg, 'max_consecutive_timeouts', 3),
|
||||
)
|
||||
|
||||
# 主动思考配置 - 支持 proactive 和 proactive_thinking 两种写法
|
||||
@@ -262,3 +270,35 @@ def reload_config() -> KokoroFlowChatterConfig:
|
||||
global _config
|
||||
_config = load_config()
|
||||
return _config
|
||||
|
||||
|
||||
def apply_wait_duration_rules(raw_wait_seconds: int, consecutive_timeouts: int = 0) -> int:
    """Compute the final wait time from the raw value and the waiting config.

    Args:
        raw_wait_seconds: Wait time requested upstream; values <= 0 yield 0.
        consecutive_timeouts: Number of consecutive wait timeouts so far.

    Returns:
        The raw wait scaled by ``wait_duration_multiplier`` and clamped to
        ``[min_wait_seconds, max_wait_seconds]`` (a bound of 0 disables that
        side of the clamp), or 0 when the multiplier is 0 or the
        consecutive-timeout limit has been reached.
    """
    if raw_wait_seconds <= 0:
        return 0

    cfg = get_config().waiting

    scale = max(cfg.wait_duration_multiplier, 0.0)
    if scale == 0:
        return 0

    scaled = raw_wait_seconds * scale
    adjusted = int(round(scaled))

    lower = max(0, cfg.min_wait_seconds)
    upper = max(cfg.max_wait_seconds, 0)

    # An upper bound below the lower bound is raised to the lower bound.
    if 0 < upper < lower:
        upper = lower

    if upper > 0:
        adjusted = min(adjusted, upper)
    if lower > 0:
        adjusted = max(adjusted, lower)
    adjusted = max(adjusted, 0)

    # A limit of 0 means "no limit"; otherwise stop waiting once it is reached.
    limit = max(0, cfg.max_consecutive_timeouts)
    limit_reached = bool(limit) and consecutive_timeouts >= limit
    return 0 if limit_reached else adjusted
|
||||
|
||||
@@ -24,7 +24,7 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.unified_scheduler import TriggerType, unified_scheduler
|
||||
|
||||
from .config import KFCMode, get_config
|
||||
from .config import KFCMode, apply_wait_duration_rules, get_config
|
||||
from .models import EventType, SessionStatus
|
||||
from .session import KokoroSession, get_session_manager
|
||||
|
||||
@@ -83,6 +83,7 @@ class ProactiveThinker:
|
||||
"""加载配置 - 使用统一的配置系统"""
|
||||
config = get_config()
|
||||
proactive_cfg = config.proactive
|
||||
self._waiting_cfg = config.waiting
|
||||
|
||||
# 工作模式
|
||||
self._mode = config.mode
|
||||
@@ -460,6 +461,30 @@ class ProactiveThinker:
|
||||
action.params["thought"] = plan_response.thought
|
||||
action.params["situation_type"] = "timeout"
|
||||
action.params["extra_context"] = extra_context
|
||||
|
||||
raw_wait = plan_response.max_wait_seconds
|
||||
adjusted_wait = apply_wait_duration_rules(
|
||||
raw_wait,
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
timeout_limit = max(0, getattr(self._waiting_cfg, "max_consecutive_timeouts", 0))
|
||||
if (
|
||||
timeout_limit
|
||||
and session.consecutive_timeout_count >= timeout_limit
|
||||
and raw_wait > 0
|
||||
and adjusted_wait == 0
|
||||
):
|
||||
logger.info(
|
||||
"[ProactiveThinker] 连续等待 %s 次未获回复,停止继续等待",
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
elif adjusted_wait != raw_wait:
|
||||
logger.debug(
|
||||
"[ProactiveThinker] 调整超时等待: raw=%ss adjusted=%ss",
|
||||
raw_wait,
|
||||
adjusted_wait,
|
||||
)
|
||||
plan_response.max_wait_seconds = adjusted_wait
|
||||
|
||||
# ★ 在执行动作前最后一次检查状态,防止与 Chatter 并发
|
||||
if session.status != SessionStatus.WAITING:
|
||||
@@ -683,6 +708,30 @@ class ProactiveThinker:
|
||||
action.params["thought"] = plan_response.thought
|
||||
action.params["situation_type"] = "proactive"
|
||||
action.params["extra_context"] = extra_context
|
||||
|
||||
raw_wait = plan_response.max_wait_seconds
|
||||
adjusted_wait = apply_wait_duration_rules(
|
||||
raw_wait,
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
timeout_limit = max(0, getattr(self._waiting_cfg, "max_consecutive_timeouts", 0))
|
||||
if (
|
||||
timeout_limit
|
||||
and session.consecutive_timeout_count >= timeout_limit
|
||||
and raw_wait > 0
|
||||
and adjusted_wait == 0
|
||||
):
|
||||
logger.info(
|
||||
"[ProactiveThinker] 连续等待 %s 次未获回复,主动无需再等",
|
||||
session.consecutive_timeout_count,
|
||||
)
|
||||
elif adjusted_wait != raw_wait:
|
||||
logger.debug(
|
||||
"[ProactiveThinker] 调整主动等待: raw=%ss adjusted=%ss",
|
||||
raw_wait,
|
||||
adjusted_wait,
|
||||
)
|
||||
plan_response.max_wait_seconds = adjusted_wait
|
||||
|
||||
# 执行动作(回复生成在 Action.execute() 中完成)
|
||||
for action in plan_response.actions:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "7.9.6"
|
||||
version = "7.9.8"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||
#如果你想要修改配置文件,请递增version的值
|
||||
@@ -622,6 +622,14 @@ mode = "split"
|
||||
max_wait_seconds_default = 300 # 默认的最大等待秒数(AI发送消息后愿意等待用户回复的时间)
|
||||
enable_continuous_thinking = true # 是否在等待期间启用心理活动更新
|
||||
|
||||
# --- 等待策略 ---
|
||||
[kokoro_flow_chatter.waiting]
|
||||
default_max_wait_seconds = 300 # LLM 未给出等待时间时的默认值
|
||||
min_wait_seconds = 30 # 允许的最短等待时间,防止太快打扰用户
|
||||
max_wait_seconds = 1800 # 允许的最长等待时间(秒)
|
||||
wait_duration_multiplier = 1.0 # 对 LLM 给出的等待时间应用的倍率(>1 放大,<1 缩短)
|
||||
max_consecutive_timeouts = 3 # 连续等待超时达到该次数后,强制不再继续等待(0 表示不限制)
|
||||
|
||||
# --- 私聊专属主动思考配置 ---
|
||||
# 注意:这是KFC专属的主动思考配置,只有当KFC启用时才生效。
|
||||
# 它旨在模拟更真实、情感驱动的互动,而非简单的定时任务。
|
||||
|
||||
Reference in New Issue
Block a user