Compare commits
7 Commits
767aad407a
...
0050bfff09
| Author | SHA1 | Date | |
|---|---|---|---|
|
0050bfff09
|
|||
|
e4192bb47c
|
|||
|
c9479bd7f4
|
|||
|
6432f339b4
|
|||
|
|
43dbfb2a1e | ||
|
|
9f666b580e | ||
|
|
fbc37bbcaf |
32
.gitea/workflows/build.yaml
Normal file
32
.gitea/workflows/build.yaml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
name: Build and Push Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- dev
|
||||||
|
- gitea
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-push:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to Docker Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: docker.gardel.top
|
||||||
|
username: ${{ secrets.DOCKER_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||||
|
- name: Build and Push Docker Image
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: docker.gardel.top/gardel/mofox:dev
|
||||||
|
build-args: |
|
||||||
|
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
||||||
|
VCS_REF=${{ github.sha }}
|
||||||
149
.github/workflows/docker-image.yml
vendored
149
.github/workflows/docker-image.yml
vendored
@@ -1,149 +0,0 @@
|
|||||||
name: Docker Build and Push
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- master
|
|
||||||
- dev
|
|
||||||
tags:
|
|
||||||
- "v*.*.*"
|
|
||||||
- "v*"
|
|
||||||
- "*.*.*"
|
|
||||||
- "*.*.*-*"
|
|
||||||
workflow_dispatch: # 允许手动触发工作流
|
|
||||||
|
|
||||||
# Workflow's jobs
|
|
||||||
jobs:
|
|
||||||
build-amd64:
|
|
||||||
name: Build AMD64 Image
|
|
||||||
runs-on: ubuntu-24.04
|
|
||||||
outputs:
|
|
||||||
digest: ${{ steps.build.outputs.digest }}
|
|
||||||
steps:
|
|
||||||
- name: Check out git repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
with:
|
|
||||||
buildkitd-flags: --debug
|
|
||||||
|
|
||||||
# Log in docker hub
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
# Generate metadata for Docker images
|
|
||||||
- name: Docker meta
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
|
||||||
|
|
||||||
# Build and push AMD64 image by digest
|
|
||||||
- name: Build and push AMD64
|
|
||||||
id: build
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
platforms: linux/amd64
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
file: ./Dockerfile
|
|
||||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
|
|
||||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
|
|
||||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
|
|
||||||
build-args: |
|
|
||||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
VCS_REF=${{ github.sha }}
|
|
||||||
|
|
||||||
build-arm64:
|
|
||||||
name: Build ARM64 Image
|
|
||||||
runs-on: ubuntu-24.04-arm
|
|
||||||
outputs:
|
|
||||||
digest: ${{ steps.build.outputs.digest }}
|
|
||||||
steps:
|
|
||||||
- name: Check out git repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
with:
|
|
||||||
buildkitd-flags: --debug
|
|
||||||
|
|
||||||
# Log in docker hub
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
# Generate metadata for Docker images
|
|
||||||
- name: Docker meta
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
|
||||||
|
|
||||||
# Build and push ARM64 image by digest
|
|
||||||
- name: Build and push ARM64
|
|
||||||
id: build
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
platforms: linux/arm64/v8
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
file: ./Dockerfile
|
|
||||||
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
|
|
||||||
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
|
|
||||||
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
|
|
||||||
build-args: |
|
|
||||||
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
|
|
||||||
VCS_REF=${{ github.sha }}
|
|
||||||
|
|
||||||
create-manifest:
|
|
||||||
name: Create Multi-Arch Manifest
|
|
||||||
runs-on: ubuntu-24.04
|
|
||||||
needs:
|
|
||||||
- build-amd64
|
|
||||||
- build-arm64
|
|
||||||
steps:
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
# Log in docker hub
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
# Generate metadata for Docker images
|
|
||||||
- name: Docker meta
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
|
|
||||||
tags: |
|
|
||||||
type=ref,event=branch
|
|
||||||
type=ref,event=tag
|
|
||||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
|
||||||
type=semver,pattern={{version}}
|
|
||||||
type=semver,pattern={{major}}.{{minor}}
|
|
||||||
type=semver,pattern={{major}}
|
|
||||||
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
|
|
||||||
|
|
||||||
- name: Create and Push Manifest
|
|
||||||
run: |
|
|
||||||
# 为每个标签创建多架构镜像
|
|
||||||
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
|
|
||||||
echo "Creating manifest for $tag"
|
|
||||||
docker buildx imagetools create -t $tag \
|
|
||||||
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
|
|
||||||
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
|
|
||||||
done
|
|
||||||
@@ -4,6 +4,7 @@ import binascii
|
|||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import json_repair
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
@@ -1023,6 +1024,15 @@ class EmojiManager:
|
|||||||
- 必须是表情包,非普通截图。
|
- 必须是表情包,非普通截图。
|
||||||
- 图中文字不超过5个。
|
- 图中文字不超过5个。
|
||||||
请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。
|
请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。
|
||||||
|
输出格式:
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"detailed_description": "",
|
||||||
|
"keywords": [],
|
||||||
|
"refined_sentence": "",
|
||||||
|
"is_compliant": true
|
||||||
|
}}
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_data_for_vlm, image_format_for_vlm = image_base64, image_format
|
image_data_for_vlm, image_format_for_vlm = image_base64, image_format
|
||||||
@@ -1042,16 +1052,14 @@ class EmojiManager:
|
|||||||
if not vlm_response_str:
|
if not vlm_response_str:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL)
|
vlm_response_json = self._parse_json_response(vlm_response_str)
|
||||||
if match:
|
description = vlm_response_json.get("detailed_description", "")
|
||||||
vlm_response_json = json.loads(match.group(0))
|
emotions = vlm_response_json.get("keywords", [])
|
||||||
description = vlm_response_json.get("detailed_description", "")
|
refined_description = vlm_response_json.get("refined_sentence", "")
|
||||||
emotions = vlm_response_json.get("keywords", [])
|
is_compliant = vlm_response_json.get("is_compliant", False)
|
||||||
refined_description = vlm_response_json.get("refined_sentence", "")
|
if description and emotions and refined_description:
|
||||||
is_compliant = vlm_response_json.get("is_compliant", False)
|
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
||||||
if description and emotions and refined_description:
|
break
|
||||||
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
|
|
||||||
break
|
|
||||||
logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。")
|
logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。")
|
||||||
except (json.JSONDecodeError, AttributeError) as e:
|
except (json.JSONDecodeError, AttributeError) as e:
|
||||||
logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
|
logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
|
||||||
@@ -1196,6 +1204,29 @@ class EmojiManager:
|
|||||||
logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
|
logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_json_response(cls, response: str) -> dict[str, Any] | None:
|
||||||
|
"""解析 LLM 的 JSON 响应"""
|
||||||
|
try:
|
||||||
|
# 尝试提取 JSON 代码块
|
||||||
|
json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
json_str = json_match.group(1)
|
||||||
|
else:
|
||||||
|
# 尝试直接解析
|
||||||
|
json_str = response.strip()
|
||||||
|
|
||||||
|
# 移除可能的注释
|
||||||
|
json_str = re.sub(r"//.*", "", json_str)
|
||||||
|
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||||
|
|
||||||
|
data = json_repair.loads(json_str)
|
||||||
|
return data
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
emoji_manager = None
|
emoji_manager = None
|
||||||
|
|
||||||
|
|||||||
@@ -129,16 +129,6 @@ class ChatStream:
|
|||||||
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||||
self.context.set_current_message(message)
|
self.context.set_current_message(message)
|
||||||
|
|
||||||
# 调试日志
|
|
||||||
logger.debug(
|
|
||||||
f"消息上下文已设置 - message_id: {message.message_id}, "
|
|
||||||
f"chat_id: {message.chat_id}, "
|
|
||||||
f"is_mentioned: {message.is_mentioned}, "
|
|
||||||
f"is_emoji: {message.is_emoji}, "
|
|
||||||
f"is_picid: {message.is_picid}, "
|
|
||||||
f"interest_value: {message.interest_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
||||||
"""安全获取消息的actions字段"""
|
"""安全获取消息的actions字段"""
|
||||||
import json
|
import json
|
||||||
|
|||||||
@@ -70,8 +70,6 @@ def init_prompt():
|
|||||||
{keywords_reaction_prompt}
|
{keywords_reaction_prompt}
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
不要复读你前面发过的内容,意思相近也不行。
|
不要复读你前面发过的内容,意思相近也不行。
|
||||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包),只输出一条回复就好。
|
|
||||||
⛔ 绝对禁止输出任何艾特:不要输出@、@xxx等格式。你看到的聊天记录中的艾特是系统显示格式,你无法通过模仿来实现真正的艾特。想称呼某人直接写名字。
|
|
||||||
|
|
||||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
*你叫{bot_name},也有人叫你{bot_nickname}*
|
||||||
|
|
||||||
@@ -140,11 +138,15 @@ def init_prompt():
|
|||||||
{time_block}
|
{time_block}
|
||||||
|
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
||||||
⛔ 绝对禁止输出任何形式的艾特:不要输出@、@xxx等。你看到的聊天记录中的艾特格式是系统显示用的,你无法通过模仿它来实现真正的艾特功能,只会输出一串无意义的假文本。想称呼某人直接写名字即可。
|
不要模仿任何系统消息的格式,你的回复应该是自然的对话内容,例如:
|
||||||
|
- 当你想要打招呼时,直接输出“你好!”而不是“[回复<xxx>]: 用户你好!”
|
||||||
|
- 当你想要提及某人时,直接叫对方名字,而不是“@xxx”
|
||||||
|
|
||||||
|
你只能输出文字,不能输出任何表情包、图片、文件等内容!如果用户要求你发送非文字内容,请输出"PASS",而不是[表情包:xxx]
|
||||||
|
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
|
|
||||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
*你叫{bot_name},也有人叫你{bot_nickname},请你清楚你的身份,分清对方到底有没有叫你*
|
||||||
|
|
||||||
现在,你说:
|
现在,你说:
|
||||||
""",
|
""",
|
||||||
@@ -211,8 +213,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||||||
*{chat_scene}*
|
*{chat_scene}*
|
||||||
|
|
||||||
### 核心任务
|
### 核心任务
|
||||||
- 你需要对以上未读历史消息进行统一回应。这些消息可能来自不同的参与者,你需要理解整体对话动态,生成一段自然、连贯的回复。
|
- 你需要对以上未读历史消息用一句简单的话统一回应。这些消息可能来自不同的参与者,你需要理解整体对话动态,生成一段自然、连贯的回复。
|
||||||
- 你的回复应该能够推动对话继续,可以回应其中一个或多个话题,也可以提出新的观点。
|
|
||||||
|
|
||||||
## 规则
|
## 规则
|
||||||
{safety_guidelines_block}
|
{safety_guidelines_block}
|
||||||
@@ -224,11 +225,15 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||||||
{time_block}
|
{time_block}
|
||||||
|
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,系统格式化文字)。只输出回复内容。
|
||||||
⛔ 绝对禁止输出任何形式的艾特:不要输出@、@xxx等。你看到的聊天记录中的艾特格式是系统显示用的,你无法通过模仿它来实现真正的艾特功能,只会输出一串无意义的假文本。想称呼某人直接写名字即可。
|
不要模仿任何系统消息的格式,你的回复应该是自然的对话内容,例如:
|
||||||
|
- 当你想要打招呼时,直接输出“你好!”而不是“[回复<xxx>]: 用户你好!”
|
||||||
|
- 当你想要提及某人时,直接叫对方名字,而不是“@xxx”
|
||||||
|
|
||||||
|
你只能输出文字,不能输出任何表情包、图片、文件等内容!如果用户要求你发送非文字内容,请输出"PASS",而不是[表情包:xxx]
|
||||||
|
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
|
|
||||||
*你叫{bot_name},也有人叫你{bot_nickname}*
|
*你叫{bot_name},也有人叫你{bot_nickname},请你清楚你的身份,分清对方到底有没有叫你*
|
||||||
|
|
||||||
现在,你说:
|
现在,你说:
|
||||||
""",
|
""",
|
||||||
|
|||||||
@@ -405,6 +405,12 @@ def recover_quoted_content(sentences: list[str], placeholder_map: dict[str, str]
|
|||||||
|
|
||||||
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
||||||
assert global_config is not None
|
assert global_config is not None
|
||||||
|
|
||||||
|
normalized_text = text.strip() if isinstance(text, str) else ""
|
||||||
|
if normalized_text.upper() == "PASS":
|
||||||
|
logger.info("[回复内容过滤器] 检测到PASS信号,跳过发送。")
|
||||||
|
return []
|
||||||
|
|
||||||
if not global_config.response_post_process.enable_response_post_process:
|
if not global_config.response_post_process.enable_response_post_process:
|
||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
|
|||||||
@@ -616,20 +616,20 @@ class StreamContext(BaseDataModel):
|
|||||||
# 如果没有指定类型要求,默认为支持
|
# 如果没有指定类型要求,默认为支持
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.debug(f"[check_types] 检查消息是否支持类型: {types}")
|
# logger.debug(f"[check_types] 检查消息是否支持类型: {types}") # 简化日志,避免冗余
|
||||||
|
|
||||||
# 优先从additional_config中获取format_info
|
# 优先从additional_config中获取format_info
|
||||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||||
import orjson
|
import orjson
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[check_types] additional_config 类型: {type(self.current_message.additional_config)}")
|
# logger.debug(f"[check_types] additional_config 类型: {type(self.current_message.additional_config)}") # 简化日志
|
||||||
config = orjson.loads(self.current_message.additional_config)
|
config = orjson.loads(self.current_message.additional_config)
|
||||||
logger.debug(f"[check_types] 解析后的 config 键: {config.keys() if isinstance(config, dict) else 'N/A'}")
|
# logger.debug(f"[check_types] 解析后的 config 键: {config.keys() if isinstance(config, dict) else 'N/A'}") # 简化日志
|
||||||
|
|
||||||
# 检查format_info结构
|
# 检查format_info结构
|
||||||
if "format_info" in config:
|
if "format_info" in config:
|
||||||
format_info = config["format_info"]
|
format_info = config["format_info"]
|
||||||
logger.debug(f"[check_types] 找到 format_info: {format_info}")
|
# logger.debug(f"[check_types] 找到 format_info: {format_info}") # 简化日志
|
||||||
|
|
||||||
# 方法1: 直接检查accept_format字段
|
# 方法1: 直接检查accept_format字段
|
||||||
if "accept_format" in format_info:
|
if "accept_format" in format_info:
|
||||||
@@ -646,9 +646,9 @@ class StreamContext(BaseDataModel):
|
|||||||
# 检查所有请求的类型是否都被支持
|
# 检查所有请求的类型是否都被支持
|
||||||
for requested_type in types:
|
for requested_type in types:
|
||||||
if requested_type not in accept_format:
|
if requested_type not in accept_format:
|
||||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
# logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}") # 简化日志
|
||||||
return False
|
return False
|
||||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
# logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)") # 简化日志
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 方法2: 检查content_format字段(向后兼容)
|
# 方法2: 检查content_format字段(向后兼容)
|
||||||
@@ -665,9 +665,9 @@ class StreamContext(BaseDataModel):
|
|||||||
# 检查所有请求的类型是否都被支持
|
# 检查所有请求的类型是否都被支持
|
||||||
for requested_type in types:
|
for requested_type in types:
|
||||||
if requested_type not in content_format:
|
if requested_type not in content_format:
|
||||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
# logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}") # 简化日志
|
||||||
return False
|
return False
|
||||||
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
# logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)") # 简化日志
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
||||||
@@ -679,16 +679,16 @@ class StreamContext(BaseDataModel):
|
|||||||
|
|
||||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||||
# 大多数消息至少支持text类型
|
# 大多数消息至少支持text类型
|
||||||
logger.debug("[check_types] 使用备用方案:默认支持类型检查")
|
# logger.debug("[check_types] 使用备用方案:默认支持类型检查") # 简化日志
|
||||||
default_supported_types = ["text", "emoji"]
|
default_supported_types = ["text", "emoji"]
|
||||||
for requested_type in types:
|
for requested_type in types:
|
||||||
if requested_type not in default_supported_types:
|
if requested_type not in default_supported_types:
|
||||||
logger.debug(f"[check_types] 使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
# logger.debug(f"[check_types] 使用默认类型检查,消息可能不支持类型 '{requested_type}'") # 简化日志
|
||||||
# 对于非基础类型,返回False以避免错误
|
# 对于非基础类型,返回False以避免错误
|
||||||
if requested_type not in ["text", "emoji", "reply"]:
|
if requested_type not in ["text", "emoji", "reply"]:
|
||||||
logger.warning(f"[check_types] ❌ 备用方案拒绝类型 '{requested_type}'")
|
logger.warning(f"[check_types] ❌ 备用方案拒绝类型 '{requested_type}'")
|
||||||
return False
|
return False
|
||||||
logger.debug("[check_types] ✅ 备用方案通过所有类型检查")
|
# logger.debug("[check_types] ✅ 备用方案通过所有类型检查") # 简化日志
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# ==================== 消息缓存系统方法 ====================
|
# ==================== 消息缓存系统方法 ====================
|
||||||
@@ -736,7 +736,7 @@ class StreamContext(BaseDataModel):
|
|||||||
list[DatabaseMessages]: 刷新的消息列表
|
list[DatabaseMessages]: 刷新的消息列表
|
||||||
"""
|
"""
|
||||||
if not self.message_cache:
|
if not self.message_cache:
|
||||||
logger.debug(f"StreamContext {self.stream_id} 缓存为空,无需刷新")
|
# 缓存为空是正常情况,不需要记录日志
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,281 +0,0 @@
|
|||||||
"""
|
|
||||||
透明连接复用管理器
|
|
||||||
在不改变原有API的情况下,实现数据库连接的智能复用
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("connection_pool_manager")
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionInfo:
|
|
||||||
"""连接信息包装器"""
|
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession, created_at: float):
|
|
||||||
self.session = session
|
|
||||||
self.created_at = created_at
|
|
||||||
self.last_used = created_at
|
|
||||||
self.in_use = False
|
|
||||||
self.ref_count = 0
|
|
||||||
|
|
||||||
def mark_used(self):
|
|
||||||
"""标记连接被使用"""
|
|
||||||
self.last_used = time.time()
|
|
||||||
self.in_use = True
|
|
||||||
self.ref_count += 1
|
|
||||||
|
|
||||||
def mark_released(self):
|
|
||||||
"""标记连接被释放"""
|
|
||||||
self.in_use = False
|
|
||||||
self.ref_count = max(0, self.ref_count - 1)
|
|
||||||
|
|
||||||
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
|
|
||||||
"""检查连接是否过期"""
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 检查总生命周期
|
|
||||||
if current_time - self.created_at > max_lifetime:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查空闲时间
|
|
||||||
if not self.in_use and current_time - self.last_used > max_idle:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
|
|
||||||
# 通过 `cast` 明确告知类型检查器 `shield` 的返回类型,避免类型错误
|
|
||||||
from typing import cast
|
|
||||||
await cast(asyncio.Future, asyncio.shield(self.session.close()))
|
|
||||||
logger.debug("连接已关闭")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# 这是一个预期的行为,例如在流式聊天中断时
|
|
||||||
logger.debug("关闭连接时任务被取消")
|
|
||||||
# 重新抛出异常以确保任务状态正确
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"关闭连接时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPoolManager:
|
|
||||||
"""透明的连接池管理器"""
|
|
||||||
|
|
||||||
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
|
|
||||||
self.max_pool_size = max_pool_size
|
|
||||||
self.max_lifetime = max_lifetime
|
|
||||||
self.max_idle = max_idle
|
|
||||||
|
|
||||||
# 连接池
|
|
||||||
self._connections: set[ConnectionInfo] = set()
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self._stats = {
|
|
||||||
"total_created": 0,
|
|
||||||
"total_reused": 0,
|
|
||||||
"total_expired": 0,
|
|
||||||
"active_connections": 0,
|
|
||||||
"pool_hits": 0,
|
|
||||||
"pool_misses": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 后台清理任务
|
|
||||||
self._cleanup_task: asyncio.Task | None = None
|
|
||||||
self._should_cleanup = False
|
|
||||||
|
|
||||||
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动连接池管理器"""
|
|
||||||
if self._cleanup_task is None:
|
|
||||||
self._should_cleanup = True
|
|
||||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
|
||||||
logger.info("连接池管理器已启动")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止连接池管理器"""
|
|
||||||
self._should_cleanup = False
|
|
||||||
|
|
||||||
if self._cleanup_task:
|
|
||||||
self._cleanup_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._cleanup_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
|
||||||
|
|
||||||
# 关闭所有连接
|
|
||||||
await self._close_all_connections()
|
|
||||||
logger.info("连接池管理器已停止")
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
|
|
||||||
"""
|
|
||||||
获取数据库会话的透明包装器
|
|
||||||
如果有可用连接则复用,否则创建新连接
|
|
||||||
"""
|
|
||||||
connection_info = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试获取现有连接
|
|
||||||
connection_info = await self._get_reusable_connection(session_factory)
|
|
||||||
|
|
||||||
if connection_info:
|
|
||||||
# 复用现有连接
|
|
||||||
connection_info.mark_used()
|
|
||||||
self._stats["total_reused"] += 1
|
|
||||||
self._stats["pool_hits"] += 1
|
|
||||||
logger.debug(f"复用现有连接 (活跃连接数: {len(self._connections)})")
|
|
||||||
else:
|
|
||||||
# 创建新连接
|
|
||||||
session = session_factory()
|
|
||||||
connection_info = ConnectionInfo(session, time.time())
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
self._connections.add(connection_info)
|
|
||||||
|
|
||||||
connection_info.mark_used()
|
|
||||||
self._stats["total_created"] += 1
|
|
||||||
self._stats["pool_misses"] += 1
|
|
||||||
logger.debug(f"创建新连接 (活跃连接数: {len(self._connections)})")
|
|
||||||
|
|
||||||
yield connection_info.session
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# 发生错误时回滚连接
|
|
||||||
if connection_info and connection_info.session:
|
|
||||||
try:
|
|
||||||
await connection_info.session.rollback()
|
|
||||||
except Exception as rollback_error:
|
|
||||||
logger.warning(f"回滚连接时出错: {rollback_error}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# 释放连接回池中
|
|
||||||
if connection_info:
|
|
||||||
connection_info.mark_released()
|
|
||||||
|
|
||||||
async def _get_reusable_connection(
|
|
||||||
self, session_factory: async_sessionmaker[AsyncSession]
|
|
||||||
) -> ConnectionInfo | None:
|
|
||||||
"""获取可复用的连接"""
|
|
||||||
# 导入方言适配器获取 ping 查询
|
|
||||||
from src.common.database.core.dialect_adapter import DialectAdapter
|
|
||||||
|
|
||||||
ping_query = DialectAdapter.get_ping_query()
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
# 清理过期连接
|
|
||||||
await self._cleanup_expired_connections_locked()
|
|
||||||
|
|
||||||
# 查找可复用的连接
|
|
||||||
for connection_info in list(self._connections):
|
|
||||||
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
|
|
||||||
# 验证连接是否仍然有效
|
|
||||||
try:
|
|
||||||
# 执行 ping 查询来验证连接
|
|
||||||
await connection_info.session.execute(text(ping_query))
|
|
||||||
return connection_info
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"连接验证失败,将移除: {e}")
|
|
||||||
await connection_info.close()
|
|
||||||
self._connections.remove(connection_info)
|
|
||||||
self._stats["total_expired"] += 1
|
|
||||||
|
|
||||||
# 检查是否可以创建新连接
|
|
||||||
if len(self._connections) >= self.max_pool_size:
|
|
||||||
logger.warning(f"连接池已满 ({len(self._connections)}/{self.max_pool_size}),等待复用")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _cleanup_expired_connections_locked(self):
|
|
||||||
"""清理过期连接(需要在锁内调用)"""
|
|
||||||
time.time()
|
|
||||||
expired_connections = [
|
|
||||||
connection_info for connection_info in list(self._connections)
|
|
||||||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
|
|
||||||
]
|
|
||||||
|
|
||||||
for connection_info in expired_connections:
|
|
||||||
await connection_info.close()
|
|
||||||
self._connections.remove(connection_info)
|
|
||||||
self._stats["total_expired"] += 1
|
|
||||||
|
|
||||||
if expired_connections:
|
|
||||||
logger.debug(f"清理了 {len(expired_connections)} 个过期连接")
|
|
||||||
|
|
||||||
async def _cleanup_loop(self):
|
|
||||||
"""后台清理循环"""
|
|
||||||
while self._should_cleanup:
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(30.0) # 每30秒清理一次
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
await self._cleanup_expired_connections_locked()
|
|
||||||
|
|
||||||
# 更新统计信息
|
|
||||||
self._stats["active_connections"] = len(self._connections)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"连接池清理循环出错: {e}")
|
|
||||||
await asyncio.sleep(10.0)
|
|
||||||
|
|
||||||
async def _close_all_connections(self):
|
|
||||||
"""关闭所有连接"""
|
|
||||||
async with self._lock:
|
|
||||||
for connection_info in list(self._connections):
|
|
||||||
await connection_info.close()
|
|
||||||
|
|
||||||
self._connections.clear()
|
|
||||||
logger.info("所有连接已关闭")
|
|
||||||
|
|
||||||
def get_stats(self) -> dict[str, Any]:
|
|
||||||
"""获取连接池统计信息"""
|
|
||||||
return {
|
|
||||||
**self._stats,
|
|
||||||
"active_connections": len(self._connections),
|
|
||||||
"max_pool_size": self.max_pool_size,
|
|
||||||
"pool_efficiency": (
|
|
||||||
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
|
|
||||||
)
|
|
||||||
* 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局连接池管理器实例
|
|
||||||
_connection_pool_manager: ConnectionPoolManager | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_connection_pool_manager() -> ConnectionPoolManager:
|
|
||||||
"""获取全局连接池管理器实例"""
|
|
||||||
global _connection_pool_manager
|
|
||||||
if _connection_pool_manager is None:
|
|
||||||
_connection_pool_manager = ConnectionPoolManager()
|
|
||||||
return _connection_pool_manager
|
|
||||||
|
|
||||||
|
|
||||||
async def start_connection_pool():
|
|
||||||
"""启动连接池"""
|
|
||||||
manager = get_connection_pool_manager()
|
|
||||||
await manager.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_connection_pool():
|
|
||||||
"""停止连接池"""
|
|
||||||
global _connection_pool_manager
|
|
||||||
if _connection_pool_manager:
|
|
||||||
await _connection_pool_manager.stop()
|
|
||||||
_connection_pool_manager = None
|
|
||||||
@@ -87,7 +87,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
|||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""获取数据库会话上下文管理器
|
"""获取数据库会话上下文管理器
|
||||||
|
|
||||||
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
这是数据库操作的主要入口点,直接从会话工厂获取独立会话。
|
||||||
|
|
||||||
支持的数据库:
|
支持的数据库:
|
||||||
- SQLite: 自动设置 busy_timeout 和外键约束
|
- SQLite: 自动设置 busy_timeout 和外键约束
|
||||||
@@ -101,20 +101,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
Yields:
|
Yields:
|
||||||
AsyncSession: SQLAlchemy异步会话对象
|
AsyncSession: SQLAlchemy异步会话对象
|
||||||
"""
|
"""
|
||||||
# 延迟导入避免循环依赖
|
async with get_db_session_direct() as session:
|
||||||
from ..optimization.connection_pool import get_connection_pool_manager
|
|
||||||
|
|
||||||
session_factory = await get_session_factory()
|
|
||||||
pool_manager = get_connection_pool_manager()
|
|
||||||
|
|
||||||
# 使用连接池管理器(透明复用连接)
|
|
||||||
async with pool_manager.get_session(session_factory) as session:
|
|
||||||
# 获取数据库类型并应用特定设置
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
assert global_config is not None
|
|
||||||
await _apply_session_settings(session, global_config.database.database_type)
|
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""数据库优化层
|
"""数据库优化层
|
||||||
|
|
||||||
职责:
|
职责:
|
||||||
- 连接池管理
|
|
||||||
- 批量调度
|
- 批量调度
|
||||||
- 多级缓存
|
- 多级缓存
|
||||||
- 数据预加载
|
- 数据预加载
|
||||||
@@ -23,12 +22,6 @@ from .cache_manager import (
|
|||||||
close_cache,
|
close_cache,
|
||||||
get_cache,
|
get_cache,
|
||||||
)
|
)
|
||||||
from .connection_pool import (
|
|
||||||
ConnectionPoolManager,
|
|
||||||
get_connection_pool_manager,
|
|
||||||
start_connection_pool,
|
|
||||||
stop_connection_pool,
|
|
||||||
)
|
|
||||||
from .preloader import (
|
from .preloader import (
|
||||||
AccessPattern,
|
AccessPattern,
|
||||||
CommonDataPreloader,
|
CommonDataPreloader,
|
||||||
@@ -46,8 +39,6 @@ __all__ = [
|
|||||||
"CacheEntry",
|
"CacheEntry",
|
||||||
"CacheStats",
|
"CacheStats",
|
||||||
"CommonDataPreloader",
|
"CommonDataPreloader",
|
||||||
# Connection Pool
|
|
||||||
"ConnectionPoolManager",
|
|
||||||
# Preloader
|
# Preloader
|
||||||
"DataPreloader",
|
"DataPreloader",
|
||||||
"LRUCache",
|
"LRUCache",
|
||||||
@@ -59,8 +50,5 @@ __all__ = [
|
|||||||
"close_preloader",
|
"close_preloader",
|
||||||
"get_batch_scheduler",
|
"get_batch_scheduler",
|
||||||
"get_cache",
|
"get_cache",
|
||||||
"get_connection_pool_manager",
|
|
||||||
"get_preloader",
|
"get_preloader",
|
||||||
"start_connection_pool",
|
|
||||||
"stop_connection_pool",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -304,13 +304,11 @@ class MultiLevelCache:
|
|||||||
# 1. 尝试从L1获取
|
# 1. 尝试从L1获取
|
||||||
value = await self.l1_cache.get(key)
|
value = await self.l1_cache.get(key)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
logger.debug(f"L1缓存命中: {key}")
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# 2. 尝试从L2获取
|
# 2. 尝试从L2获取
|
||||||
value = await self.l2_cache.get(key)
|
value = await self.l2_cache.get(key)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
logger.debug(f"L2缓存命中: {key}")
|
|
||||||
# 提升到L1
|
# 提升到L1
|
||||||
await self.l1_cache.set(key, value)
|
await self.l1_cache.set(key, value)
|
||||||
return value
|
return value
|
||||||
|
|||||||
@@ -1,299 +0,0 @@
|
|||||||
"""
|
|
||||||
透明连接复用管理器
|
|
||||||
|
|
||||||
在不改变原有API的情况下,实现数据库连接的智能复用
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("database.connection_pool")
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionInfo:
|
|
||||||
"""连接信息包装器"""
|
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession, created_at: float):
|
|
||||||
self.session = session
|
|
||||||
self.created_at = created_at
|
|
||||||
self.last_used = created_at
|
|
||||||
self.in_use = False
|
|
||||||
self.ref_count = 0
|
|
||||||
|
|
||||||
def mark_used(self):
|
|
||||||
"""标记连接被使用"""
|
|
||||||
self.last_used = time.time()
|
|
||||||
self.in_use = True
|
|
||||||
self.ref_count += 1
|
|
||||||
|
|
||||||
def mark_released(self):
|
|
||||||
"""标记连接被释放"""
|
|
||||||
self.in_use = False
|
|
||||||
self.ref_count = max(0, self.ref_count - 1)
|
|
||||||
|
|
||||||
def is_expired(self, max_lifetime: float = 300.0, max_idle: float = 60.0) -> bool:
|
|
||||||
"""检查连接是否过期"""
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 检查总生命周期
|
|
||||||
if current_time - self.created_at > max_lifetime:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查空闲时间
|
|
||||||
if not self.in_use and current_time - self.last_used > max_idle:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
# 使用 shield 保护 close 操作,确保即使任务被取消也能完成关闭
|
|
||||||
from typing import cast
|
|
||||||
await cast(asyncio.Future, asyncio.shield(self.session.close()))
|
|
||||||
logger.debug("连接已关闭")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# 这是一个预期的行为,例如在流式聊天中断时
|
|
||||||
logger.debug("关闭连接时任务被取消")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"关闭连接时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPoolManager:
|
|
||||||
"""透明的连接池管理器"""
|
|
||||||
|
|
||||||
def __init__(self, max_pool_size: int = 10, max_lifetime: float = 300.0, max_idle: float = 60.0):
|
|
||||||
self.max_pool_size = max_pool_size
|
|
||||||
self.max_lifetime = max_lifetime
|
|
||||||
self.max_idle = max_idle
|
|
||||||
|
|
||||||
# 连接池
|
|
||||||
self._connections: set[ConnectionInfo] = set()
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self._stats = {
|
|
||||||
"total_created": 0,
|
|
||||||
"total_reused": 0,
|
|
||||||
"total_expired": 0,
|
|
||||||
"active_connections": 0,
|
|
||||||
"pool_hits": 0,
|
|
||||||
"pool_misses": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 后台清理任务
|
|
||||||
self._cleanup_task: asyncio.Task | None = None
|
|
||||||
self._should_cleanup = False
|
|
||||||
|
|
||||||
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动连接池管理器"""
|
|
||||||
if self._cleanup_task is None:
|
|
||||||
self._should_cleanup = True
|
|
||||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
|
||||||
logger.info("✅ 连接池管理器已启动")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止连接池管理器"""
|
|
||||||
self._should_cleanup = False
|
|
||||||
|
|
||||||
if self._cleanup_task:
|
|
||||||
self._cleanup_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._cleanup_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._cleanup_task = None
|
|
||||||
|
|
||||||
# 关闭所有连接
|
|
||||||
await self._close_all_connections()
|
|
||||||
logger.info("✅ 连接池管理器已停止")
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_session(self, session_factory: async_sessionmaker[AsyncSession]):
|
|
||||||
"""
|
|
||||||
获取数据库会话的透明包装器
|
|
||||||
如果有可用连接则复用,否则创建新连接
|
|
||||||
|
|
||||||
事务管理说明:
|
|
||||||
- 正常退出时自动提交事务
|
|
||||||
- 发生异常时自动回滚事务
|
|
||||||
- 如果用户代码已手动调用 commit/rollback,再次调用是安全的(空操作)
|
|
||||||
- 支持所有数据库类型:SQLite、PostgreSQL
|
|
||||||
"""
|
|
||||||
connection_info = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试获取现有连接
|
|
||||||
connection_info = await self._get_reusable_connection(session_factory)
|
|
||||||
|
|
||||||
if connection_info:
|
|
||||||
# 复用现有连接
|
|
||||||
connection_info.mark_used()
|
|
||||||
self._stats["total_reused"] += 1
|
|
||||||
self._stats["pool_hits"] += 1
|
|
||||||
logger.debug(f"♻️ 复用连接 (池大小: {len(self._connections)})")
|
|
||||||
else:
|
|
||||||
# 创建新连接
|
|
||||||
session = session_factory()
|
|
||||||
connection_info = ConnectionInfo(session, time.time())
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
self._connections.add(connection_info)
|
|
||||||
|
|
||||||
connection_info.mark_used()
|
|
||||||
self._stats["total_created"] += 1
|
|
||||||
self._stats["pool_misses"] += 1
|
|
||||||
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
|
|
||||||
|
|
||||||
yield connection_info.session
|
|
||||||
|
|
||||||
# 🔧 正常退出时提交事务
|
|
||||||
# 这对所有数据库(SQLite、PostgreSQL)都很重要
|
|
||||||
# 因为 SQLAlchemy 默认使用事务模式,不会自动提交
|
|
||||||
# 注意:如果用户代码已调用 commit(),这里的 commit() 是安全的空操作
|
|
||||||
if connection_info and connection_info.session:
|
|
||||||
try:
|
|
||||||
# 检查事务是否处于活动状态,避免在已回滚的事务上提交
|
|
||||||
if connection_info.session.is_active:
|
|
||||||
await connection_info.session.commit()
|
|
||||||
except Exception as commit_error:
|
|
||||||
logger.warning(f"提交事务时出错: {commit_error}")
|
|
||||||
try:
|
|
||||||
await connection_info.session.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass # 忽略回滚错误,因为事务可能已经结束
|
|
||||||
raise
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# 发生错误时回滚连接
|
|
||||||
if connection_info and connection_info.session:
|
|
||||||
try:
|
|
||||||
# 检查是否需要回滚(事务是否活动)
|
|
||||||
if connection_info.session.is_active:
|
|
||||||
await connection_info.session.rollback()
|
|
||||||
except Exception as rollback_error:
|
|
||||||
logger.warning(f"回滚连接时出错: {rollback_error}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# 释放连接回池中
|
|
||||||
if connection_info:
|
|
||||||
connection_info.mark_released()
|
|
||||||
|
|
||||||
async def _get_reusable_connection(
|
|
||||||
self, session_factory: async_sessionmaker[AsyncSession]
|
|
||||||
) -> ConnectionInfo | None:
|
|
||||||
"""获取可复用的连接"""
|
|
||||||
async with self._lock:
|
|
||||||
# 清理过期连接
|
|
||||||
await self._cleanup_expired_connections_locked()
|
|
||||||
|
|
||||||
# 查找可复用的连接
|
|
||||||
for connection_info in list(self._connections):
|
|
||||||
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
|
|
||||||
# 验证连接是否仍然有效
|
|
||||||
try:
|
|
||||||
# 执行一个简单的查询来验证连接
|
|
||||||
await connection_info.session.execute(text("SELECT 1"))
|
|
||||||
return connection_info
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"连接验证失败,将移除: {e}")
|
|
||||||
await connection_info.close()
|
|
||||||
self._connections.remove(connection_info)
|
|
||||||
self._stats["total_expired"] += 1
|
|
||||||
|
|
||||||
# 检查是否可以创建新连接
|
|
||||||
if len(self._connections) >= self.max_pool_size:
|
|
||||||
logger.warning(f"⚠️ 连接池已满 ({len(self._connections)}/{self.max_pool_size})")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _cleanup_expired_connections_locked(self):
|
|
||||||
"""清理过期连接(需要在锁内调用)"""
|
|
||||||
expired_connections = [
|
|
||||||
connection_info for connection_info in list(self._connections)
|
|
||||||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
|
|
||||||
]
|
|
||||||
|
|
||||||
for connection_info in expired_connections:
|
|
||||||
await connection_info.close()
|
|
||||||
self._connections.remove(connection_info)
|
|
||||||
self._stats["total_expired"] += 1
|
|
||||||
|
|
||||||
if expired_connections:
|
|
||||||
logger.debug(f"🧹 清理了 {len(expired_connections)} 个过期连接")
|
|
||||||
|
|
||||||
async def _cleanup_loop(self):
|
|
||||||
"""后台清理循环"""
|
|
||||||
while self._should_cleanup:
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(30.0) # 每30秒清理一次
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
await self._cleanup_expired_connections_locked()
|
|
||||||
|
|
||||||
# 更新统计信息
|
|
||||||
self._stats["active_connections"] = len(self._connections)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"连接池清理循环出错: {e}")
|
|
||||||
await asyncio.sleep(10.0)
|
|
||||||
|
|
||||||
async def _close_all_connections(self):
|
|
||||||
"""关闭所有连接"""
|
|
||||||
async with self._lock:
|
|
||||||
for connection_info in list(self._connections):
|
|
||||||
await connection_info.close()
|
|
||||||
|
|
||||||
self._connections.clear()
|
|
||||||
logger.info("所有连接已关闭")
|
|
||||||
|
|
||||||
def get_stats(self) -> dict[str, Any]:
|
|
||||||
"""获取连接池统计信息"""
|
|
||||||
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
|
|
||||||
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
|
|
||||||
|
|
||||||
return {
|
|
||||||
**self._stats,
|
|
||||||
"active_connections": len(self._connections),
|
|
||||||
"max_pool_size": self.max_pool_size,
|
|
||||||
"pool_efficiency": f"{pool_efficiency:.2f}%",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局连接池管理器实例
|
|
||||||
_connection_pool_manager: ConnectionPoolManager | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_connection_pool_manager() -> ConnectionPoolManager:
|
|
||||||
"""获取全局连接池管理器实例"""
|
|
||||||
global _connection_pool_manager
|
|
||||||
if _connection_pool_manager is None:
|
|
||||||
_connection_pool_manager = ConnectionPoolManager()
|
|
||||||
return _connection_pool_manager
|
|
||||||
|
|
||||||
|
|
||||||
async def start_connection_pool():
|
|
||||||
"""启动连接池"""
|
|
||||||
manager = get_connection_pool_manager()
|
|
||||||
await manager.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_connection_pool():
|
|
||||||
"""停止连接池"""
|
|
||||||
global _connection_pool_manager
|
|
||||||
if _connection_pool_manager:
|
|
||||||
await _connection_pool_manager.stop()
|
|
||||||
_connection_pool_manager = None
|
|
||||||
@@ -923,6 +923,41 @@ class KokoroFlowChatterProactiveConfig(ValidatedConfigBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KokoroFlowChatterWaitingConfig(ValidatedConfigBase):
|
||||||
|
"""Kokoro Flow Chatter 等待策略配置"""
|
||||||
|
|
||||||
|
default_max_wait_seconds: int = Field(
|
||||||
|
default=300,
|
||||||
|
ge=0,
|
||||||
|
le=3600,
|
||||||
|
description="默认最大等待秒数(当LLM未给出等待时间时使用)",
|
||||||
|
)
|
||||||
|
min_wait_seconds: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=0,
|
||||||
|
le=1800,
|
||||||
|
description="允许的最小等待秒数,防止等待时间过短导致频繁打扰",
|
||||||
|
)
|
||||||
|
max_wait_seconds: int = Field(
|
||||||
|
default=1800,
|
||||||
|
ge=60,
|
||||||
|
le=7200,
|
||||||
|
description="允许的最大等待秒数,避免等待时间过长",
|
||||||
|
)
|
||||||
|
wait_duration_multiplier: float = Field(
|
||||||
|
default=1.0,
|
||||||
|
ge=0.0,
|
||||||
|
le=10.0,
|
||||||
|
description="等待时长倍率,用于整体放大或缩短LLM给出的等待时间",
|
||||||
|
)
|
||||||
|
max_consecutive_timeouts: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=0,
|
||||||
|
le=10,
|
||||||
|
description="允许的连续等待超时次数上限,达到后不再等待用户回复 (0 表示不限制)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KokoroFlowChatterConfig(ValidatedConfigBase):
|
class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||||
"""
|
"""
|
||||||
Kokoro Flow Chatter 配置类 - 私聊专用心流对话系统
|
Kokoro Flow Chatter 配置类 - 私聊专用心流对话系统
|
||||||
@@ -947,6 +982,11 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
|||||||
description="是否在等待期间启用心理活动更新"
|
description="是否在等待期间启用心理活动更新"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
waiting: KokoroFlowChatterWaitingConfig = Field(
|
||||||
|
default_factory=KokoroFlowChatterWaitingConfig,
|
||||||
|
description="等待策略配置(默认等待时间、倍率等)",
|
||||||
|
)
|
||||||
|
|
||||||
# --- 私聊专属主动思考配置 ---
|
# --- 私聊专属主动思考配置 ---
|
||||||
proactive_thinking: KokoroFlowChatterProactiveConfig = Field(
|
proactive_thinking: KokoroFlowChatterProactiveConfig = Field(
|
||||||
default_factory=KokoroFlowChatterProactiveConfig,
|
default_factory=KokoroFlowChatterProactiveConfig,
|
||||||
|
|||||||
@@ -233,6 +233,7 @@ class Memory:
|
|||||||
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
|
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
|
||||||
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
|
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
updated_at: datetime | None = None # 最近一次结构或元数据更新
|
||||||
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
|
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
|
||||||
access_count: int = 0 # 访问次数
|
access_count: int = 0 # 访问次数
|
||||||
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
||||||
@@ -245,6 +246,8 @@ class Memory:
|
|||||||
# 确保重要性和激活度在有效范围内
|
# 确保重要性和激活度在有效范围内
|
||||||
self.importance = max(0.0, min(1.0, self.importance))
|
self.importance = max(0.0, min(1.0, self.importance))
|
||||||
self.activation = max(0.0, min(1.0, self.activation))
|
self.activation = max(0.0, min(1.0, self.activation))
|
||||||
|
if not self.updated_at:
|
||||||
|
self.updated_at = self.created_at
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典(用于序列化)"""
|
"""转换为字典(用于序列化)"""
|
||||||
@@ -258,6 +261,7 @@ class Memory:
|
|||||||
"activation": self.activation,
|
"activation": self.activation,
|
||||||
"status": self.status.value,
|
"status": self.status.value,
|
||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
"last_accessed": self.last_accessed.isoformat(),
|
"last_accessed": self.last_accessed.isoformat(),
|
||||||
"access_count": self.access_count,
|
"access_count": self.access_count,
|
||||||
"decay_factor": self.decay_factor,
|
"decay_factor": self.decay_factor,
|
||||||
@@ -278,6 +282,13 @@ class Memory:
|
|||||||
# 备选:使用直接的 activation 字段
|
# 备选:使用直接的 activation 字段
|
||||||
activation_level = data.get("activation", 0.0)
|
activation_level = data.get("activation", 0.0)
|
||||||
|
|
||||||
|
updated_at_raw = data.get("updated_at")
|
||||||
|
if updated_at_raw:
|
||||||
|
updated_at = datetime.fromisoformat(updated_at_raw)
|
||||||
|
else:
|
||||||
|
# 旧数据没有 updated_at,退化为最后访问时间或创建时间
|
||||||
|
updated_at = datetime.fromisoformat(data.get("last_accessed", data["created_at"]))
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data["id"],
|
||||||
subject_id=data["subject_id"],
|
subject_id=data["subject_id"],
|
||||||
@@ -288,6 +299,7 @@ class Memory:
|
|||||||
activation=activation_level, # 使用统一的激活度值
|
activation=activation_level, # 使用统一的激活度值
|
||||||
status=MemoryStatus(data.get("status", "staged")),
|
status=MemoryStatus(data.get("status", "staged")),
|
||||||
created_at=datetime.fromisoformat(data["created_at"]),
|
created_at=datetime.fromisoformat(data["created_at"]),
|
||||||
|
updated_at=updated_at,
|
||||||
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
|
last_accessed=datetime.fromisoformat(data.get("last_accessed", data["created_at"])),
|
||||||
access_count=data.get("access_count", 0),
|
access_count=data.get("access_count", 0),
|
||||||
decay_factor=data.get("decay_factor", 1.0),
|
decay_factor=data.get("decay_factor", 1.0),
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
import json_repair
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -187,8 +187,8 @@ class ShortTermMemoryManager:
|
|||||||
"importance": 0.7,
|
"importance": 0.7,
|
||||||
"attributes": {{
|
"attributes": {{
|
||||||
"time": "时间信息",
|
"time": "时间信息",
|
||||||
"attribute1": "其他属性1"
|
"attribute1": "其他属性1",
|
||||||
"attribute2": "其他属性2"
|
"attribute2": "其他属性2",
|
||||||
...
|
...
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
@@ -531,7 +531,7 @@ class ShortTermMemoryManager:
|
|||||||
json_str = re.sub(r"//.*", "", json_str)
|
json_str = re.sub(r"//.*", "", json_str)
|
||||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||||
|
|
||||||
data = json.loads(json_str)
|
data = json_repair.loads(json_str)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from src.common.logger import get_logger
|
|||||||
from src.plugin_system.base.base_chatter import BaseChatter
|
from src.plugin_system.base.base_chatter import BaseChatter
|
||||||
from src.plugin_system.base.component_types import ChatType
|
from src.plugin_system.base.component_types import ChatType
|
||||||
|
|
||||||
from .config import KFCMode, get_config
|
from .config import KFCMode, apply_wait_duration_rules, get_config
|
||||||
from .models import SessionStatus
|
from .models import SessionStatus
|
||||||
from .session import get_session_manager
|
from .session import get_session_manager
|
||||||
|
|
||||||
@@ -179,6 +179,30 @@ class KokoroFlowChatter(BaseChatter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 10. 执行动作
|
# 10. 执行动作
|
||||||
|
raw_wait = plan_response.max_wait_seconds
|
||||||
|
adjusted_wait = apply_wait_duration_rules(
|
||||||
|
raw_wait,
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
timeout_limit = max(0, self._config.waiting.max_consecutive_timeouts)
|
||||||
|
if (
|
||||||
|
timeout_limit
|
||||||
|
and session.consecutive_timeout_count >= timeout_limit
|
||||||
|
and raw_wait > 0
|
||||||
|
and adjusted_wait == 0
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"[KFC] 连续等待 %s 次未收到回复,暂停继续等待",
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
elif adjusted_wait != raw_wait:
|
||||||
|
logger.debug(
|
||||||
|
"[KFC] 调整等待时长: raw=%ss adjusted=%ss",
|
||||||
|
raw_wait,
|
||||||
|
adjusted_wait,
|
||||||
|
)
|
||||||
|
plan_response.max_wait_seconds = adjusted_wait
|
||||||
|
|
||||||
exec_results = []
|
exec_results = []
|
||||||
has_reply = False
|
has_reply = False
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,12 @@ class WaitingDefaults:
|
|||||||
# 最大等待时间
|
# 最大等待时间
|
||||||
max_wait_seconds: int = 1800
|
max_wait_seconds: int = 1800
|
||||||
|
|
||||||
|
# 等待时长倍率(>1 放大等待时间,<1 缩短)
|
||||||
|
wait_duration_multiplier: float = 1.0
|
||||||
|
|
||||||
|
# 连续等待超时上限(达到后不再继续等待,0 表示不限制)
|
||||||
|
max_consecutive_timeouts: int = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProactiveConfig:
|
class ProactiveConfig:
|
||||||
@@ -202,6 +208,8 @@ def load_config() -> KokoroFlowChatterConfig:
|
|||||||
default_max_wait_seconds=getattr(wait_cfg, 'default_max_wait_seconds', 300),
|
default_max_wait_seconds=getattr(wait_cfg, 'default_max_wait_seconds', 300),
|
||||||
min_wait_seconds=getattr(wait_cfg, 'min_wait_seconds', 30),
|
min_wait_seconds=getattr(wait_cfg, 'min_wait_seconds', 30),
|
||||||
max_wait_seconds=getattr(wait_cfg, 'max_wait_seconds', 1800),
|
max_wait_seconds=getattr(wait_cfg, 'max_wait_seconds', 1800),
|
||||||
|
wait_duration_multiplier=getattr(wait_cfg, 'wait_duration_multiplier', 1.0),
|
||||||
|
max_consecutive_timeouts=getattr(wait_cfg, 'max_consecutive_timeouts', 3),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 主动思考配置 - 支持 proactive 和 proactive_thinking 两种写法
|
# 主动思考配置 - 支持 proactive 和 proactive_thinking 两种写法
|
||||||
@@ -262,3 +270,35 @@ def reload_config() -> KokoroFlowChatterConfig:
|
|||||||
global _config
|
global _config
|
||||||
_config = load_config()
|
_config = load_config()
|
||||||
return _config
|
return _config
|
||||||
|
|
||||||
|
|
||||||
|
def apply_wait_duration_rules(raw_wait_seconds: int, consecutive_timeouts: int = 0) -> int:
|
||||||
|
"""根据配置计算最终的等待时间"""
|
||||||
|
if raw_wait_seconds <= 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
waiting_cfg = get_config().waiting
|
||||||
|
multiplier = max(waiting_cfg.wait_duration_multiplier, 0.0)
|
||||||
|
if multiplier == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
adjusted = int(round(raw_wait_seconds * multiplier))
|
||||||
|
|
||||||
|
min_wait = max(0, waiting_cfg.min_wait_seconds)
|
||||||
|
max_wait = max(waiting_cfg.max_wait_seconds, 0)
|
||||||
|
|
||||||
|
if max_wait > 0 and min_wait > 0 and max_wait < min_wait:
|
||||||
|
max_wait = min_wait
|
||||||
|
|
||||||
|
if max_wait > 0:
|
||||||
|
adjusted = min(adjusted, max_wait)
|
||||||
|
if min_wait > 0:
|
||||||
|
adjusted = max(adjusted, min_wait)
|
||||||
|
|
||||||
|
adjusted = max(adjusted, 0)
|
||||||
|
|
||||||
|
timeout_limit = max(0, waiting_cfg.max_consecutive_timeouts)
|
||||||
|
if timeout_limit and consecutive_timeouts >= timeout_limit:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return adjusted
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.plugin_system.apis.unified_scheduler import TriggerType, unified_scheduler
|
from src.plugin_system.apis.unified_scheduler import TriggerType, unified_scheduler
|
||||||
|
|
||||||
from .config import KFCMode, get_config
|
from .config import KFCMode, apply_wait_duration_rules, get_config
|
||||||
from .models import EventType, SessionStatus
|
from .models import EventType, SessionStatus
|
||||||
from .session import KokoroSession, get_session_manager
|
from .session import KokoroSession, get_session_manager
|
||||||
|
|
||||||
@@ -83,6 +83,7 @@ class ProactiveThinker:
|
|||||||
"""加载配置 - 使用统一的配置系统"""
|
"""加载配置 - 使用统一的配置系统"""
|
||||||
config = get_config()
|
config = get_config()
|
||||||
proactive_cfg = config.proactive
|
proactive_cfg = config.proactive
|
||||||
|
self._waiting_cfg = config.waiting
|
||||||
|
|
||||||
# 工作模式
|
# 工作模式
|
||||||
self._mode = config.mode
|
self._mode = config.mode
|
||||||
@@ -461,6 +462,30 @@ class ProactiveThinker:
|
|||||||
action.params["situation_type"] = "timeout"
|
action.params["situation_type"] = "timeout"
|
||||||
action.params["extra_context"] = extra_context
|
action.params["extra_context"] = extra_context
|
||||||
|
|
||||||
|
raw_wait = plan_response.max_wait_seconds
|
||||||
|
adjusted_wait = apply_wait_duration_rules(
|
||||||
|
raw_wait,
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
timeout_limit = max(0, getattr(self._waiting_cfg, "max_consecutive_timeouts", 0))
|
||||||
|
if (
|
||||||
|
timeout_limit
|
||||||
|
and session.consecutive_timeout_count >= timeout_limit
|
||||||
|
and raw_wait > 0
|
||||||
|
and adjusted_wait == 0
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"[ProactiveThinker] 连续等待 %s 次未获回复,停止继续等待",
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
elif adjusted_wait != raw_wait:
|
||||||
|
logger.debug(
|
||||||
|
"[ProactiveThinker] 调整超时等待: raw=%ss adjusted=%ss",
|
||||||
|
raw_wait,
|
||||||
|
adjusted_wait,
|
||||||
|
)
|
||||||
|
plan_response.max_wait_seconds = adjusted_wait
|
||||||
|
|
||||||
# ★ 在执行动作前最后一次检查状态,防止与 Chatter 并发
|
# ★ 在执行动作前最后一次检查状态,防止与 Chatter 并发
|
||||||
if session.status != SessionStatus.WAITING:
|
if session.status != SessionStatus.WAITING:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -684,6 +709,30 @@ class ProactiveThinker:
|
|||||||
action.params["situation_type"] = "proactive"
|
action.params["situation_type"] = "proactive"
|
||||||
action.params["extra_context"] = extra_context
|
action.params["extra_context"] = extra_context
|
||||||
|
|
||||||
|
raw_wait = plan_response.max_wait_seconds
|
||||||
|
adjusted_wait = apply_wait_duration_rules(
|
||||||
|
raw_wait,
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
timeout_limit = max(0, getattr(self._waiting_cfg, "max_consecutive_timeouts", 0))
|
||||||
|
if (
|
||||||
|
timeout_limit
|
||||||
|
and session.consecutive_timeout_count >= timeout_limit
|
||||||
|
and raw_wait > 0
|
||||||
|
and adjusted_wait == 0
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"[ProactiveThinker] 连续等待 %s 次未获回复,主动无需再等",
|
||||||
|
session.consecutive_timeout_count,
|
||||||
|
)
|
||||||
|
elif adjusted_wait != raw_wait:
|
||||||
|
logger.debug(
|
||||||
|
"[ProactiveThinker] 调整主动等待: raw=%ss adjusted=%ss",
|
||||||
|
raw_wait,
|
||||||
|
adjusted_wait,
|
||||||
|
)
|
||||||
|
plan_response.max_wait_seconds = adjusted_wait
|
||||||
|
|
||||||
# 执行动作(回复生成在 Action.execute() 中完成)
|
# 执行动作(回复生成在 Action.execute() 中完成)
|
||||||
for action in plan_response.actions:
|
for action in plan_response.actions:
|
||||||
await action_manager.execute_action(
|
await action_manager.execute_action(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "7.9.6"
|
version = "7.9.8"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||||
#如果你想要修改配置文件,请递增version的值
|
#如果你想要修改配置文件,请递增version的值
|
||||||
@@ -622,6 +622,14 @@ mode = "split"
|
|||||||
max_wait_seconds_default = 300 # 默认的最大等待秒数(AI发送消息后愿意等待用户回复的时间)
|
max_wait_seconds_default = 300 # 默认的最大等待秒数(AI发送消息后愿意等待用户回复的时间)
|
||||||
enable_continuous_thinking = true # 是否在等待期间启用心理活动更新
|
enable_continuous_thinking = true # 是否在等待期间启用心理活动更新
|
||||||
|
|
||||||
|
# --- 等待策略 ---
|
||||||
|
[kokoro_flow_chatter.waiting]
|
||||||
|
default_max_wait_seconds = 300 # LLM 未给出等待时间时的默认值
|
||||||
|
min_wait_seconds = 30 # 允许的最短等待时间,防止太快打扰用户
|
||||||
|
max_wait_seconds = 1800 # 允许的最长等待时间(秒)
|
||||||
|
wait_duration_multiplier = 1.0 # 对 LLM 给出的等待时间应用的倍率(>1 放大,<1 缩短)
|
||||||
|
max_consecutive_timeouts = 3 # 连续等待超时达到该次数后,强制不再继续等待(0 表示不限制)
|
||||||
|
|
||||||
# --- 私聊专属主动思考配置 ---
|
# --- 私聊专属主动思考配置 ---
|
||||||
# 注意:这是KFC专属的主动思考配置,只有当KFC启用时才生效。
|
# 注意:这是KFC专属的主动思考配置,只有当KFC启用时才生效。
|
||||||
# 它旨在模拟更真实、情感驱动的互动,而非简单的定时任务。
|
# 它旨在模拟更真实、情感驱动的互动,而非简单的定时任务。
|
||||||
|
|||||||
Reference in New Issue
Block a user