Compare commits

31 Commits: 6fbfe735c9...1aa09ee340

| SHA1 |
|---|
| 1aa09ee340 |
| 25bd23ad3f |
| 179b5b7222 |
| f39b0eaa44 |
| b55df150d4 |
| 70217d7df8 |
| f1bfcd1cff |
| 5a1d5052ca |
| 35502914a7 |
| 7d547b7b80 |
| 700cf477fb |
| 1f0b8fa04d |
| 1087d46ce2 |
| da3752725e |
| e5e552df65 |
| 0193913841 |
| e6a4f855a2 |
| 9d01b81cef |
| ef0c569348 |
| e8bffe4a87 |
| 59e7a1a846 |
| 633585e6af |
| c75cc88fb5 |
| 2d02bf4631 |
| 4592e37c10 |
| c870af768d |
| 7735b161c8 |
| 016c8647f7 |
| f269034b6a |
| eac1ef2869 |
| 8f3338f845 |
32  .gitea/workflows/build.yaml  (Normal file)
@@ -0,0 +1,32 @@
name: Build and Push Docker Image

on:
  push:
    branches:
      - dev
      - gitea

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to Docker Registry
        uses: docker/login-action@v3
        with:
          registry: docker.gardel.top
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
      - name: Build and Push Docker Image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: docker.gardel.top/gardel/mofox:dev
          build-args: |
            BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
            VCS_REF=${{ github.sha }}
1  .github/copilot-instructions.md  (vendored)
@@ -34,7 +34,6 @@ MoFox_Bot 是基于 MaiCore 的增强型 QQ 聊天机器人,集成了 LLM、
- `PLUS_COMMAND`: 增强命令(支持参数解析、权限检查)
- `TOOL`: LLM 工具调用(函数调用集成)
- `EVENT_HANDLER`: 事件订阅处理器
- `INTEREST_CALCULATOR`: 兴趣值计算器
- `PROMPT`: 自定义提示词注入

**插件开发流程**:
149  .github/workflows/docker-image.yml  (vendored)
@@ -1,149 +0,0 @@
name: Docker Build and Push

on:
  push:
    branches:
      - master
      - dev
    tags:
      - "v*.*.*"
      - "v*"
      - "*.*.*"
      - "*.*.*-*"
  workflow_dispatch: # 允许手动触发工作流

# Workflow's jobs
jobs:
  build-amd64:
    name: Build AMD64 Image
    runs-on: ubuntu-24.04
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    steps:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          buildkitd-flags: --debug

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox

      # Build and push AMD64 image by digest
      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          file: ./Dockerfile
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
          outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
          build-args: |
            BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
            VCS_REF=${{ github.sha }}

  build-arm64:
    name: Build ARM64 Image
    runs-on: ubuntu-24.04-arm
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    steps:
      - name: Check out git repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          buildkitd-flags: --debug

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox

      # Build and push ARM64 image by digest
      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@v5
        with:
          context: .
          platforms: linux/arm64/v8
          labels: ${{ steps.meta.outputs.labels }}
          file: ./Dockerfile
          cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
          cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
          outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
          build-args: |
            BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
            VCS_REF=${{ github.sha }}

  create-manifest:
    name: Create Multi-Arch Manifest
    runs-on: ubuntu-24.04
    needs:
      - build-amd64
      - build-arm64
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      # Log in docker hub
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      # Generate metadata for Docker images
      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
          tags: |
            type=ref,event=branch
            type=ref,event=tag
            type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=semver,pattern={{major}}
            type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}

      - name: Create and Push Manifest
        run: |
          # 为每个标签创建多架构镜像
          for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
            echo "Creating manifest for $tag"
            docker buildx imagetools create -t $tag \
              ${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
              ${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
          done
1  .gitignore  (vendored)
@@ -342,3 +342,4 @@ package.json
/backup
mofox_bot_statistics.html
src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
depends-data/pinyin_dict.json
@@ -9,6 +9,10 @@ RUN apt-get update && apt-get install -y build-essential
# 复制依赖列表和锁文件
COPY pyproject.toml uv.lock ./

COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/ffmpeg
COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/ffprobe
RUN ldconfig && ffmpeg -version

# 安装依赖(使用 --frozen 确保使用锁文件中的版本)
RUN uv sync --frozen --no-dev
21  bot.py
@@ -35,7 +35,6 @@ class StartupStageReporter:
        else:
            self._logger.info(title)


startup_stage = StartupStageReporter(logger)

# 常量定义
@@ -567,6 +566,7 @@ class MaiBotMain:

    def __init__(self):
        self.main_system = None
        self._typo_prewarm_task = None

    def setup_timezone(self):
        """设置时区"""
@@ -663,6 +663,25 @@ class MaiBotMain:
    async def run_async_init(self, main_system):
        """执行异步初始化步骤"""

        # 后台预热中文错别字生成器,避免首次使用阻塞主流程
        try:
            from src.chat.utils.typo_generator import get_typo_generator

            typo_cfg = getattr(global_config, "chinese_typo", None)
            self._typo_prewarm_task = asyncio.create_task(
                asyncio.to_thread(
                    get_typo_generator,
                    error_rate=getattr(typo_cfg, "error_rate", 0.3),
                    min_freq=getattr(typo_cfg, "min_freq", 5),
                    tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2),
                    word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3),
                    max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200),
                )
            )
            logger.debug("已启动 ChineseTypoGenerator 后台预热任务")
        except Exception as e:
            logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}")

        # 初始化数据库表结构
        await self.initialize_database_async()
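Note: the hunk above offloads the blocking ChineseTypoGenerator construction to a worker thread with asyncio.to_thread, so startup stays responsive. A minimal, self-contained sketch of that prewarm pattern, using an invented slow initializer rather than the project's real one:

```python
import asyncio
import time


def build_expensive_resource() -> dict:
    # stand-in for get_typo_generator(): blocking and slow on first call
    time.sleep(2)
    return {"ready": True}


async def main() -> None:
    # run the blocking call on a thread; the event loop keeps doing other init
    prewarm_task = asyncio.create_task(asyncio.to_thread(build_expensive_resource))
    # ... other async initialization would continue here ...
    resource = await prewarm_task  # first real use only awaits the finished warm-up
    print(resource)


asyncio.run(main())
```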
@@ -4,6 +4,7 @@ import binascii
import hashlib
import io
import json
import json_repair
import os
import random
import re
@@ -1022,6 +1023,15 @@ class EmojiManager:
- 必须是表情包,非普通截图。
- 图中文字不超过5个。
请确保你的最终输出是严格的JSON对象,不要添加任何额外解释或文本。
输出格式:
```json
{{
    "detailed_description": "",
    "keywords": [],
    "refined_sentence": "",
    "is_compliant": true
}}
```
"""

        image_data_for_vlm, image_format_for_vlm = image_base64, image_format
@@ -1041,16 +1051,14 @@ class EmojiManager:
                if not vlm_response_str:
                    continue

                match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL)
                if match:
                    vlm_response_json = json.loads(match.group(0))
                    description = vlm_response_json.get("detailed_description", "")
                    emotions = vlm_response_json.get("keywords", [])
                    refined_description = vlm_response_json.get("refined_sentence", "")
                    is_compliant = vlm_response_json.get("is_compliant", False)
                    if description and emotions and refined_description:
                        logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
                        break
                vlm_response_json = self._parse_json_response(vlm_response_str)
                description = vlm_response_json.get("detailed_description", "")
                emotions = vlm_response_json.get("keywords", [])
                refined_description = vlm_response_json.get("refined_sentence", "")
                is_compliant = vlm_response_json.get("is_compliant", False)
                if description and emotions and refined_description:
                    logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
                    break
                logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误,准备重试。")
            except (json.JSONDecodeError, AttributeError) as e:
                logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
@@ -1195,6 +1203,29 @@ class EmojiManager:
                logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
            return False

    @classmethod
    def _parse_json_response(cls, response: str) -> dict[str, Any] | None:
        """解析 LLM 的 JSON 响应"""
        try:
            # 尝试提取 JSON 代码块
            json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                # 尝试直接解析
                json_str = response.strip()

            # 移除可能的注释
            json_str = re.sub(r"//.*", "", json_str)
            json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)

            data = json_repair.loads(json_str)
            return data

        except json.JSONDecodeError as e:
            logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}")
            return None


emoji_manager = None
@@ -1,5 +1,6 @@
import asyncio
import hashlib
import math
import random
import time
from typing import Any
@@ -76,6 +77,45 @@ def weighted_sample(population: list[dict], weights: list[float], k: int) -> lis


class ExpressionSelector:
    @staticmethod
    def _sample_with_temperature(
        candidates: list[tuple[Any, float, float, str]],
        max_num: int,
        temperature: float,
    ) -> list[tuple[Any, float, float, str]]:
        """
        对候选表达按温度采样,温度越高越均匀。

        Args:
            candidates: (expr, similarity, count, best_predicted) 列表
            max_num: 需要返回的数量
            temperature: 温度参数,0 表示贪婪选择
        """
        if max_num <= 0 or not candidates:
            return []

        if temperature <= 0:
            return candidates[:max_num]

        adjusted_temp = max(temperature, 1e-6)
        # 使用与排序相同的打分,但通过 softmax/temperature 放大尾部概率
        scores = [max(c[1] * (c[2] ** 0.5), 1e-8) for c in candidates]
        max_score = max(scores)
        weights = [math.exp((s - max_score) / adjusted_temp) for s in scores]

        # 始终保留最高分一个,剩余的按温度采样,避免过度集中
        best_idx = scores.index(max_score)
        selected = [candidates[best_idx]]
        remaining_indices = [i for i in range(len(candidates)) if i != best_idx]

        while remaining_indices and len(selected) < max_num:
            current_weights = [weights[i] for i in remaining_indices]
            picked_pos = random.choices(range(len(remaining_indices)), weights=current_weights, k=1)[0]
            picked_idx = remaining_indices.pop(picked_pos)
            selected.append(candidates[picked_idx])

        return selected

    def __init__(self, chat_id: str = ""):
        self.chat_id = chat_id
        if model_config is None:
@@ -517,12 +557,21 @@ class ExpressionSelector:
            )
            return []

        # 按照相似度*count排序,选择最佳匹配
        # 按照相似度*count排序,并根据温度采样,避免过度集中
        matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
        expressions_objs = [e[0] for e in matched_expressions[:max_num]]
        temperature = getattr(global_config.expression, "model_temperature", 0.0)
        sampled_matches = self._sample_with_temperature(
            candidates=matched_expressions,
            max_num=max_num,
            temperature=temperature,
        )
        expressions_objs = [e[0] for e in sampled_matches]

        # 显示最佳匹配的详细信息
        logger.debug(f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式")
        logger.debug(
            f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式 "
            f"(候选 {len(matched_expressions)},temperature={temperature})"
        )

        # 转换为字典格式
        expressions = [
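Note: `_sample_with_temperature` keeps the top-scoring candidate and softmax-samples the rest with the same score the sort uses, `similarity * count ** 0.5`. A small self-contained illustration of that weighting (the candidate tuples below are invented, not data from the bot):

```python
import math
import random

# (expr, similarity, count, best_predicted) tuples, shaped like ExpressionSelector's candidates
candidates = [("A", 0.9, 25, ""), ("B", 0.8, 9, ""), ("C", 0.4, 100, "")]
temperature = 0.5

scores = [max(sim * (cnt ** 0.5), 1e-8) for _, sim, cnt, _ in candidates]
max_score = max(scores)
weights = [math.exp((s - max_score) / max(temperature, 1e-6)) for s in scores]

# temperature -> 0 collapses to the greedy top-1 pick; larger values flatten the weights
print([(c[0], round(w, 3)) for c, w in zip(candidates, weights)])
picked = random.choices(candidates, weights=weights, k=1)[0]
print("sampled:", picked[0])
```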
@@ -1,21 +1,15 @@
"""
兴趣度系统模块
提供机器人兴趣标签和智能匹配功能,以及消息兴趣值计算功能
目前仅保留兴趣计算器管理入口
"""

from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from src.common.data_models.bot_interest_data_model import InterestMatchResult

from .bot_interest_manager import BotInterestManager, bot_interest_manager
from .interest_manager import InterestManager, get_interest_manager

__all__ = [
    # 机器人兴趣标签管理
    "BotInterestManager",
    "BotInterestTag",
    "BotPersonalityInterests",
    # 消息兴趣值计算管理
    "InterestManager",
    "InterestMatchResult",
    "bot_interest_manager",
    "get_interest_manager",
]
File diff suppressed because it is too large
@@ -3,10 +3,10 @@ import re
import time
import traceback
from collections import deque
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Any, cast

import orjson
from sqlalchemy import desc, select, update
from sqlalchemy import desc, insert, select, update
from sqlalchemy.engine import CursorResult

from src.common.data_models.database_data_model import DatabaseMessages
@@ -25,29 +25,55 @@ class MessageStorageBatcher:
    消息存储批处理器

    优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
    2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。
    """

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
    def __init__(
        self,
        batch_size: int = 50,
        flush_interval: float = 5.0,
        *,
        commit_batch_size: int | None = None,
        commit_interval: float | None = None,
        db_chunk_size: int = 200,
    ):
        """
        初始化批处理器

        Args:
            batch_size: 批量大小,达到此数量立即写入
            flush_interval: 自动刷新间隔(秒)
            batch_size: 写入队列中触发准备阶段的消息条数
            flush_interval: 自动刷新/检查间隔(秒)
            commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size,至少100)
            commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s))
            db_chunk_size: 单次SQL语句批量写入数量上限
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100)
        self.commit_interval = commit_interval or max(flush_interval * 2, 10.0)
        self.db_chunk_size = max(50, db_chunk_size)

        self.pending_messages: deque = deque()
        self._prepared_buffer: list[dict[str, Any]] = []
        self._lock = asyncio.Lock()
        self._flush_barrier = asyncio.Lock()
        self._flush_task = None
        self._running = False
        self._last_commit_ts = time.monotonic()

    async def start(self):
        """启动自动刷新任务"""
        if self._flush_task is None and not self._running:
            self._running = True
            self._last_commit_ts = time.monotonic()
            self._flush_task = asyncio.create_task(self._auto_flush_loop())
            logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
            logger.info(
                "消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)",
                self.batch_size,
                self.flush_interval,
                self.commit_batch_size,
                self.commit_interval,
            )

    async def stop(self):
        """停止批处理器"""
@@ -62,7 +88,7 @@
        self._flush_task = None

        # 刷新剩余的消息
        await self.flush()
        await self.flush(force=True)
        logger.info("消息存储批处理器已停止")

    async def add_message(self, message_data: dict):
@@ -76,61 +102,82 @@
            'chat_stream': ChatStream
        }
        """
        should_force_flush = False
        async with self._lock:
            self.pending_messages.append(message_data)

            # 如果达到批量大小,立即刷新
            if len(self.pending_messages) >= self.batch_size:
                logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
                await self.flush()
                should_force_flush = True

    async def flush(self):
        """执行批量写入"""
        async with self._lock:
            if not self.pending_messages:
                return
        if should_force_flush:
            logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新")
            await self.flush(force=True)

            messages_to_store = list(self.pending_messages)
            self.pending_messages.clear()
    async def flush(self, force: bool = False):
        """执行批量写入, 支持强制落库和延迟提交策略。"""
        async with self._flush_barrier:
            async with self._lock:
                messages_to_store = list(self.pending_messages)
                self.pending_messages.clear()

        if not messages_to_store:
            if messages_to_store:
                prepared_messages: list[dict[str, Any]] = []
                for msg_data in messages_to_store:
                    try:
                        message_dict = await self._prepare_message_dict(
                            msg_data["message"],
                            msg_data["chat_stream"],
                        )
                        if message_dict:
                            prepared_messages.append(message_dict)
                    except Exception as e:
                        logger.error(f"准备消息数据失败: {e}")

                if prepared_messages:
                    self._prepared_buffer.extend(prepared_messages)

            await self._maybe_commit_buffer(force=force)

    async def _maybe_commit_buffer(self, *, force: bool = False) -> None:
        """根据阈值/时间窗口判断是否需要真正写库。"""
        if not self._prepared_buffer:
            return

        now = time.monotonic()
        enough_rows = len(self._prepared_buffer) >= self.commit_batch_size
        waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval

        if not (force or enough_rows or waited_long_enough):
            return

        await self._write_buffer_to_database()

    async def _write_buffer_to_database(self) -> None:
        payload = self._prepared_buffer
        if not payload:
            return

        self._prepared_buffer = []
        start_time = time.time()
        success_count = 0
        total = len(payload)

        try:
            # 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
            messages_dicts = []

            for msg_data in messages_to_store:
                try:
                    message_dict = await self._prepare_message_dict(
                        msg_data["message"],
                        msg_data["chat_stream"]
                    )
                    if message_dict:
                        messages_dicts.append(message_dict)
                except Exception as e:
                    logger.error(f"准备消息数据失败: {e}")
                    continue

            # 批量写入数据库 - 使用高效的批量INSERT
            if messages_dicts:
                from sqlalchemy import insert
                async with get_db_session() as session:
                    stmt = insert(Messages).values(messages_dicts)
                    await session.execute(stmt)
                    await session.commit()
                    success_count = len(messages_dicts)
            async with get_db_session() as session:
                for start in range(0, total, self.db_chunk_size):
                    chunk = payload[start : start + self.db_chunk_size]
                    if chunk:
                        await session.execute(insert(Messages), chunk)
                await session.commit()

            elapsed = time.time() - start_time
            self._last_commit_ts = time.monotonic()
            per_item = (elapsed / total) * 1000 if total else 0
            logger.info(
                f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
                f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
                f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})"
            )

        except Exception as e:
            # 回滚到缓冲区, 等待下一次尝试
            self._prepared_buffer = payload + self._prepared_buffer
            logger.error(f"批量存储消息失败: {e}")

    async def _prepare_message_dict(self, message, chat_stream):
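Note: the reworked batcher separates collecting (pending_messages) from committing (_prepared_buffer); the actual write only happens when forced, when enough rows accumulate, or when the buffer has waited long enough. A minimal, runnable restatement of that decision rule, distilled from `_maybe_commit_buffer` above (function name and defaults are illustrative, not the project's API):

```python
import time


def should_commit(buffer_len: int, last_commit_ts: float, *, force: bool,
                  commit_batch_size: int = 100, commit_interval: float = 10.0) -> bool:
    """Commit when forced, when enough rows accumulated, or when the time window elapsed."""
    if buffer_len == 0:
        return False
    enough_rows = buffer_len >= commit_batch_size
    waited_long_enough = (time.monotonic() - last_commit_ts) >= commit_interval
    return force or enough_rows or waited_long_enough


# e.g. 40 rows buffered, last commit 12 s ago -> commits because of the time window
print(should_commit(40, time.monotonic() - 12.0, force=False))
```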
@@ -614,7 +614,7 @@ class DefaultReplyer:
            # 使用统一管理器的智能检索(Judge模型决策)
            search_result = await unified_manager.search_memories(
                query_text=query_text,
                use_judge=True,
                use_judge=global_config.memory.use_judge,
                recent_chat_history=chat_history,  # 传递最近聊天历史
            )

@@ -1799,8 +1799,9 @@
            )

            if content:
                # 移除 [SPLIT] 标记,防止消息被分割
                content = content.replace("[SPLIT]", "")
                if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
                    # 移除 [SPLIT] 标记,防止消息被分割
                    content = content.replace("[SPLIT]", "")

                # 应用统一的格式过滤器
                from src.chat.utils.utils import filter_system_format_content
67  src/chat/semantic_interest/__init__.py  (Normal file)
@@ -0,0 +1,67 @@
"""语义兴趣度计算模块

基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
支持人设感知的自动训练和模型切换

2024.12 优化更新:
- 新增 FastScorer:绕过 sklearn,使用 token→weight 字典直接计算
- 全局线程池:避免重复创建 ThreadPoolExecutor
- 批处理队列:攒消息一起算,提高 CPU 利用率
- TF-IDF 降维:max_features 10000, ngram_range (2,3)
- 权重剪枝:只保留高贡献 token
"""

from .auto_trainer import AutoTrainer, get_auto_trainer
from .dataset import DatasetGenerator, generate_training_dataset
from .features_tfidf import TfidfFeatureExtractor
from .model_lr import SemanticInterestModel, train_semantic_model
from .optimized_scorer import (
    BatchScoringQueue,
    FastScorer,
    FastScorerConfig,
    clear_fast_scorer_instances,
    convert_sklearn_to_fast,
    get_fast_scorer,
    get_global_executor,
    shutdown_global_executor,
)
from .runtime_scorer import (
    ModelManager,
    SemanticInterestScorer,
    clear_scorer_instances,
    get_all_scorer_instances,
    get_semantic_scorer,
    get_semantic_scorer_sync,
)
from .trainer import SemanticInterestTrainer

__all__ = [
    # 运行时评分
    "SemanticInterestScorer",
    "ModelManager",
    "get_semantic_scorer",  # 单例获取(异步)
    "get_semantic_scorer_sync",  # 单例获取(同步)
    "clear_scorer_instances",  # 清空单例
    "get_all_scorer_instances",  # 查看所有实例
    # 优化评分器(推荐用于高频场景)
    "FastScorer",
    "FastScorerConfig",
    "BatchScoringQueue",
    "get_fast_scorer",
    "convert_sklearn_to_fast",
    "clear_fast_scorer_instances",
    "get_global_executor",
    "shutdown_global_executor",
    # 训练组件
    "TfidfFeatureExtractor",
    "SemanticInterestModel",
    "train_semantic_model",
    # 数据集生成
    "DatasetGenerator",
    "generate_training_dataset",
    # 训练器
    "SemanticInterestTrainer",
    # 自动训练
    "AutoTrainer",
    "get_auto_trainer",
]
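Note: the package docstring above says FastScorer bypasses sklearn by scoring with a pruned token→weight dictionary taken from the trained TF-IDF + logistic-regression model. The optimized_scorer implementation itself is not part of this diff; the following is only a rough, self-contained sketch of that idea, with invented names and a simplified scoring rule:

```python
import math


class TinyTokenScorer:
    """Illustrative only: score text against a pruned token->weight table."""

    def __init__(self, token_weights: dict[str, float], bias: float = 0.0):
        # keep only tokens whose weight contributes meaningfully ("weight pruning")
        self.token_weights = {t: w for t, w in token_weights.items() if abs(w) > 1e-3}
        self.bias = bias

    def score(self, text: str) -> float:
        # character bigrams as a stand-in for the char-level n-gram features
        grams = [text[i : i + 2] for i in range(len(text) - 1)]
        z = self.bias + sum(self.token_weights.get(g, 0.0) for g in grams)
        return 1.0 / (1.0 + math.exp(-z))  # sigmoid -> pseudo "interest" probability


scorer = TinyTokenScorer({"星空": 1.2, "好看": 0.9, "广告": -1.5})
print(round(scorer.score("今晚的星空很好看"), 3))
```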
375  src/chat/semantic_interest/auto_trainer.py  (Normal file)
@@ -0,0 +1,375 @@
|
||||
"""自动训练调度器
|
||||
|
||||
监控人设变化,自动触发模型训练和切换
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.semantic_interest.trainer import SemanticInterestTrainer
|
||||
|
||||
logger = get_logger("semantic_interest.auto_trainer")
|
||||
|
||||
|
||||
class AutoTrainer:
|
||||
"""自动训练调度器
|
||||
|
||||
功能:
|
||||
1. 监控人设变化
|
||||
2. 自动构建训练数据集
|
||||
3. 定期重新训练模型
|
||||
4. 管理多个人设的模型
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path | None = None,
|
||||
model_dir: Path | None = None,
|
||||
min_train_interval_hours: int = 720, # 最小训练间隔(小时,30天)
|
||||
min_samples_for_training: int = 100, # 最小训练样本数
|
||||
):
|
||||
"""初始化自动训练器
|
||||
|
||||
Args:
|
||||
data_dir: 数据集目录
|
||||
model_dir: 模型目录
|
||||
min_train_interval_hours: 最小训练间隔(小时)
|
||||
min_samples_for_training: 触发训练的最小样本数
|
||||
"""
|
||||
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
|
||||
self.model_dir = Path(model_dir or "data/semantic_interest/models")
|
||||
self.min_train_interval = timedelta(hours=min_train_interval_hours)
|
||||
self.min_samples = min_samples_for_training
|
||||
|
||||
# 人设状态缓存
|
||||
self.persona_cache_file = self.data_dir / "persona_cache.json"
|
||||
self.last_persona_hash: str | None = None
|
||||
self.last_train_time: datetime | None = None
|
||||
|
||||
# 训练器实例
|
||||
self.trainer = SemanticInterestTrainer(
|
||||
data_dir=self.data_dir,
|
||||
model_dir=self.model_dir,
|
||||
)
|
||||
|
||||
# 确保目录存在
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 加载缓存的人设状态
|
||||
self._load_persona_cache()
|
||||
|
||||
# 定时任务标志(防止重复启动)
|
||||
self._scheduled_task_running = False
|
||||
self._scheduled_task = None
|
||||
|
||||
logger.info("[自动训练器] 初始化完成")
|
||||
logger.info(f" - 数据目录: {self.data_dir}")
|
||||
logger.info(f" - 模型目录: {self.model_dir}")
|
||||
logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时")
|
||||
|
||||
def _load_persona_cache(self):
|
||||
"""加载缓存的人设状态"""
|
||||
if self.persona_cache_file.exists():
|
||||
try:
|
||||
with open(self.persona_cache_file, "r", encoding="utf-8") as f:
|
||||
cache = json.load(f)
|
||||
self.last_persona_hash = cache.get("persona_hash")
|
||||
last_train_str = cache.get("last_train_time")
|
||||
if last_train_str:
|
||||
self.last_train_time = datetime.fromisoformat(last_train_str)
|
||||
logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[自动训练器] 加载人设缓存失败: {e}")
|
||||
|
||||
def _save_persona_cache(self, persona_hash: str):
|
||||
"""保存人设状态到缓存"""
|
||||
cache = {
|
||||
"persona_hash": persona_hash,
|
||||
"last_train_time": datetime.now().isoformat(),
|
||||
}
|
||||
try:
|
||||
with open(self.persona_cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache, f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}")
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 保存人设缓存失败: {e}")
|
||||
|
||||
def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str:
|
||||
"""计算人设信息的哈希值
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息字典
|
||||
|
||||
Returns:
|
||||
SHA256 哈希值
|
||||
"""
|
||||
# 只关注影响模型的关键字段
|
||||
key_fields = {
|
||||
"name": persona_info.get("name", ""),
|
||||
"interests": sorted(persona_info.get("interests", [])),
|
||||
"dislikes": sorted(persona_info.get("dislikes", [])),
|
||||
"personality": persona_info.get("personality", ""),
|
||||
# 可选的更完整人设字段(存在则纳入哈希)
|
||||
"personality_core": persona_info.get("personality_core", ""),
|
||||
"personality_side": persona_info.get("personality_side", ""),
|
||||
"identity": persona_info.get("identity", ""),
|
||||
}
|
||||
|
||||
# 转为JSON并计算哈希
|
||||
json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.sha256(json_str.encode()).hexdigest()
|
||||
|
||||
def check_persona_changed(self, persona_info: dict[str, Any]) -> bool:
|
||||
"""检查人设是否发生变化
|
||||
|
||||
Args:
|
||||
persona_info: 当前人设信息
|
||||
|
||||
Returns:
|
||||
True 如果人设发生变化
|
||||
"""
|
||||
current_hash = self._calculate_persona_hash(persona_info)
|
||||
|
||||
if self.last_persona_hash is None:
|
||||
logger.info("[自动训练器] 首次检测人设")
|
||||
return True
|
||||
|
||||
if current_hash != self.last_persona_hash:
|
||||
logger.info(f"[自动训练器] 检测到人设变化")
|
||||
logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
|
||||
logger.info(f" - 新哈希: {current_hash[:8]}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
|
||||
"""判断是否应该训练模型
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
force: 强制训练
|
||||
|
||||
Returns:
|
||||
(是否应该训练, 原因说明)
|
||||
"""
|
||||
# 强制训练
|
||||
if force:
|
||||
return True, "强制训练"
|
||||
|
||||
# 检查人设是否变化
|
||||
persona_changed = self.check_persona_changed(persona_info)
|
||||
if persona_changed:
|
||||
return True, "人设发生变化"
|
||||
|
||||
# 检查训练间隔
|
||||
if self.last_train_time is None:
|
||||
return True, "从未训练过"
|
||||
|
||||
time_since_last_train = datetime.now() - self.last_train_time
|
||||
if time_since_last_train >= self.min_train_interval:
|
||||
return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时"
|
||||
|
||||
return False, "无需训练"
|
||||
|
||||
async def auto_train_if_needed(
|
||||
self,
|
||||
persona_info: dict[str, Any],
|
||||
days: int = 7,
|
||||
max_samples: int = 1000,
|
||||
force: bool = False,
|
||||
) -> tuple[bool, Path | None]:
|
||||
"""自动训练(如果需要)
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
days: 采样天数
|
||||
max_samples: 最大采样数(默认1000条)
|
||||
force: 强制训练
|
||||
|
||||
Returns:
|
||||
(是否训练了, 模型路径)
|
||||
"""
|
||||
# 检查是否需要训练
|
||||
should_train, reason = self.should_train(persona_info, force)
|
||||
|
||||
if not should_train:
|
||||
logger.debug(f"[自动训练器] {reason},跳过训练")
|
||||
return False, None
|
||||
|
||||
logger.info(f"[自动训练器] 开始自动训练: {reason}")
|
||||
|
||||
try:
|
||||
# 计算人设哈希作为版本标识
|
||||
persona_hash = self._calculate_persona_hash(persona_info)
|
||||
model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
# 执行训练
|
||||
dataset_path, model_path, metrics = await self.trainer.full_training_pipeline(
|
||||
persona_info=persona_info,
|
||||
days=days,
|
||||
max_samples=max_samples,
|
||||
model_version=model_version,
|
||||
tfidf_config={
|
||||
"analyzer": "char",
|
||||
"ngram_range": (2, 4),
|
||||
"max_features": 10000,
|
||||
"min_df": 3,
|
||||
},
|
||||
model_config={
|
||||
"class_weight": "balanced",
|
||||
"max_iter": 1000,
|
||||
},
|
||||
)
|
||||
|
||||
# 更新缓存
|
||||
self.last_persona_hash = persona_hash
|
||||
self.last_train_time = datetime.now()
|
||||
self._save_persona_cache(persona_hash)
|
||||
|
||||
# 创建"latest"符号链接
|
||||
self._create_latest_link(model_path)
|
||||
|
||||
logger.info(f"[自动训练器] 训练完成!")
|
||||
logger.info(f" - 模型: {model_path.name}")
|
||||
logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")
|
||||
|
||||
return True, model_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 训练失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
def _create_latest_link(self, model_path: Path):
|
||||
"""创建指向最新模型的符号链接
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径
|
||||
"""
|
||||
latest_path = self.model_dir / "semantic_interest_latest.pkl"
|
||||
|
||||
try:
|
||||
# 删除旧链接
|
||||
if latest_path.exists() or latest_path.is_symlink():
|
||||
latest_path.unlink()
|
||||
|
||||
# 创建新链接(Windows 需要管理员权限,使用复制代替)
|
||||
import shutil
|
||||
shutil.copy2(model_path, latest_path)
|
||||
|
||||
logger.info(f"[自动训练器] 已更新 latest 模型")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")
|
||||
|
||||
async def scheduled_train(
|
||||
self,
|
||||
persona_info: dict[str, Any],
|
||||
interval_hours: int = 24,
|
||||
):
|
||||
"""定时训练任务
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
interval_hours: 检查间隔(小时)
|
||||
"""
|
||||
# 检查是否已经有任务在运行
|
||||
if self._scheduled_task_running:
|
||||
logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
|
||||
return
|
||||
|
||||
self._scheduled_task_running = True
|
||||
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
|
||||
logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 检查并训练
|
||||
trained, model_path = await self.auto_train_if_needed(persona_info)
|
||||
|
||||
if trained:
|
||||
logger.info(f"[自动训练器] 定时训练完成: {model_path}")
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(interval_hours * 3600)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 定时训练出错: {e}")
|
||||
# 出错后等待较短时间再试
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None:
|
||||
"""获取当前人设对应的模型
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
|
||||
Returns:
|
||||
模型文件路径,如果不存在则返回 None
|
||||
"""
|
||||
persona_hash = self._calculate_persona_hash(persona_info)
|
||||
|
||||
# 查找匹配的模型
|
||||
pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
|
||||
matching_models = list(self.model_dir.glob(pattern))
|
||||
|
||||
if matching_models:
|
||||
# 返回最新的
|
||||
latest = max(matching_models, key=lambda p: p.stat().st_mtime)
|
||||
logger.debug(f"[自动训练器] 找到人设模型: {latest.name}")
|
||||
return latest
|
||||
|
||||
# 没有找到,返回 latest
|
||||
latest_path = self.model_dir / "semantic_interest_latest.pkl"
|
||||
if latest_path.exists():
|
||||
logger.debug(f"[自动训练器] 使用 latest 模型")
|
||||
return latest_path
|
||||
|
||||
logger.warning(f"[自动训练器] 未找到可用模型")
|
||||
return None
|
||||
|
||||
def cleanup_old_models(self, keep_count: int = 5):
|
||||
"""清理旧模型文件
|
||||
|
||||
Args:
|
||||
keep_count: 保留最新的 N 个模型
|
||||
"""
|
||||
try:
|
||||
# 获取所有自动训练的模型
|
||||
all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))
|
||||
|
||||
if len(all_models) <= keep_count:
|
||||
return
|
||||
|
||||
# 按修改时间排序
|
||||
all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
# 删除旧模型
|
||||
for old_model in all_models[keep_count:]:
|
||||
old_model.unlink()
|
||||
logger.info(f"[自动训练器] 清理旧模型: {old_model.name}")
|
||||
|
||||
logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[自动训练器] 清理模型失败: {e}")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_auto_trainer: AutoTrainer | None = None
|
||||
|
||||
|
||||
def get_auto_trainer() -> AutoTrainer:
|
||||
"""获取自动训练器单例"""
|
||||
global _auto_trainer
|
||||
if _auto_trainer is None:
|
||||
_auto_trainer = AutoTrainer()
|
||||
return _auto_trainer
|
||||
818  src/chat/semantic_interest/dataset.py  (Normal file)
@@ -0,0 +1,818 @@
|
||||
"""数据集生成与 LLM 标注
|
||||
|
||||
从数据库采样消息并使用 LLM 进行兴趣度标注
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("semantic_interest.dataset")
|
||||
|
||||
|
||||
class DatasetGenerator:
|
||||
"""训练数据集生成器
|
||||
|
||||
从历史消息中采样并使用 LLM 进行标注
|
||||
"""
|
||||
|
||||
# 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
|
||||
HARD_MAX_SAMPLES = 2000
|
||||
|
||||
# 标注提示词模板(单条)
|
||||
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
|
||||
|
||||
## 人格信息
|
||||
{persona_info}
|
||||
|
||||
## 消息内容
|
||||
{message_text}
|
||||
|
||||
## 标注规则
|
||||
请判断角色对这条消息的兴趣程度,返回以下之一:
|
||||
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
|
||||
- **0**: 中立(可以回应但不特别感兴趣)
|
||||
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
|
||||
|
||||
只需返回数字 -1、0 或 1,不要其他内容。"""
|
||||
|
||||
# 批量标注提示词模板
|
||||
BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。
|
||||
|
||||
## 人格信息
|
||||
{persona_info}
|
||||
|
||||
## 标注规则
|
||||
对每条消息判断角色的兴趣程度:
|
||||
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
|
||||
- **0**: 中立(可以回应但不特别感兴趣)
|
||||
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
|
||||
|
||||
## 消息列表
|
||||
{messages_list}
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下JSON格式返回,每条消息一个标签:
|
||||
```json
|
||||
{example_output}
|
||||
```
|
||||
|
||||
只返回JSON,不要其他内容。"""
|
||||
|
||||
# 关键词生成提示词模板
|
||||
KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
|
||||
|
||||
## 人格信息
|
||||
{persona_info}
|
||||
|
||||
## 任务说明
|
||||
请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
|
||||
|
||||
1. **感兴趣的关键词**:包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等(约30-50个)
|
||||
2. **不感兴趣的关键词**:包括该角色不关心、反感、无聊的话题、价值观冲突的内容等(约30-50个)
|
||||
|
||||
## 输出格式
|
||||
请严格按照以下JSON格式返回:
|
||||
```json
|
||||
{{
|
||||
"interested": ["关键词1", "关键词2", "关键词3", ...],
|
||||
"not_interested": ["关键词1", "关键词2", "关键词3", ...]
|
||||
}}
|
||||
```
|
||||
|
||||
注意:
|
||||
- 关键词可以是单个词语或短语(2-10个字)
|
||||
- 尽量覆盖多样化的话题和场景
|
||||
- 确保关键词与人格设定高度相关
|
||||
|
||||
只返回JSON,不要其他内容。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str | None = None,
|
||||
max_samples_per_batch: int = 50,
|
||||
):
|
||||
"""初始化数据集生成器
|
||||
|
||||
Args:
|
||||
model_name: LLM 模型名称(None 则使用默认模型)
|
||||
max_samples_per_batch: 每批次最大采样数
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.max_samples_per_batch = max_samples_per_batch
|
||||
self.model_client = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化 LLM 客户端"""
|
||||
try:
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
# 使用 utilities 模型配置(标注更偏工具型)
|
||||
if hasattr(model_config.model_task_config, 'utils'):
|
||||
self.model_client = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="semantic_annotation"
|
||||
)
|
||||
logger.info(f"数据集生成器初始化完成,使用 utils 模型")
|
||||
else:
|
||||
logger.error("未找到 utils 模型配置")
|
||||
self.model_client = None
|
||||
except ImportError as e:
|
||||
logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
|
||||
self.model_client = None
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 客户端初始化失败: {e}")
|
||||
self.model_client = None
|
||||
|
||||
async def sample_messages(
|
||||
self,
|
||||
days: int = 7,
|
||||
min_length: int = 5,
|
||||
max_samples: int = 1000,
|
||||
priority_ranges: list[tuple[float, float]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""从数据库采样消息(优化版:减少查询量和内存使用)
|
||||
|
||||
Args:
|
||||
days: 采样最近 N 天的消息
|
||||
min_length: 最小消息长度
|
||||
max_samples: 最大采样数量
|
||||
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
|
||||
|
||||
Returns:
|
||||
消息样本列表
|
||||
"""
|
||||
from src.common.database.api.query import QueryBuilder
|
||||
from src.common.database.core.models import Messages
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
|
||||
|
||||
# 限制采样数量硬上限
|
||||
requested_max_samples = max_samples
|
||||
if max_samples is None:
|
||||
max_samples = self.HARD_MAX_SAMPLES
|
||||
else:
|
||||
max_samples = int(max_samples)
|
||||
if max_samples <= 0:
|
||||
logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
|
||||
return []
|
||||
if max_samples > self.HARD_MAX_SAMPLES:
|
||||
logger.warning(
|
||||
f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES},"
|
||||
f"已截断为 {self.HARD_MAX_SAMPLES}"
|
||||
)
|
||||
max_samples = self.HARD_MAX_SAMPLES
|
||||
|
||||
# 查询条件
|
||||
cutoff_time = datetime.now() - timedelta(days=days)
|
||||
cutoff_ts = cutoff_time.timestamp()
|
||||
|
||||
# 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
|
||||
# 这样可以在保证足够样本的同时减少查询量
|
||||
prefetch_limit = int(max_samples * 1.5)
|
||||
|
||||
# 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
|
||||
query_builder = QueryBuilder(Messages)
|
||||
|
||||
# 过滤条件:时间范围 + 消息文本不为空
|
||||
messages = await query_builder.filter(
|
||||
time__gte=cutoff_ts,
|
||||
).order_by(
|
||||
"-time" # 按时间倒序,优先采样最新消息
|
||||
).limit(
|
||||
prefetch_limit # 限制预取数量
|
||||
).all(as_dict=True)
|
||||
|
||||
logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})")
|
||||
|
||||
# 过滤消息长度和提取文本
|
||||
filtered = []
|
||||
for msg in messages:
|
||||
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
|
||||
text = text.strip()
|
||||
if text and len(text) >= min_length:
|
||||
filtered.append({**msg, "message_text": text})
|
||||
# 达到目标数量即可停止
|
||||
if len(filtered) >= max_samples:
|
||||
break
|
||||
|
||||
logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})")
|
||||
|
||||
# 如果过滤后数量不足,记录警告
|
||||
if len(filtered) < max_samples:
|
||||
logger.warning(
|
||||
f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples}),"
|
||||
f"可能需要扩大采样范围(增加 days 参数或降低 min_length)"
|
||||
)
|
||||
|
||||
# 随机打乱样本顺序(避免时间偏向)
|
||||
if len(filtered) > 0:
|
||||
random.shuffle(filtered)
|
||||
|
||||
# 转换为标准格式
|
||||
result = []
|
||||
for msg in filtered:
|
||||
result.append({
|
||||
"message_id": msg.get("message_id"),
|
||||
"user_id": msg.get("user_id"),
|
||||
"chat_id": msg.get("chat_id"),
|
||||
"message_text": msg.get("message_text", ""),
|
||||
"timestamp": msg.get("time"),
|
||||
"platform": msg.get("chat_info_platform"),
|
||||
})
|
||||
|
||||
logger.info(f"采样完成,共 {len(result)} 条消息")
|
||||
return result
|
||||
|
||||
async def generate_initial_keywords(
|
||||
self,
|
||||
persona_info: dict[str, Any],
|
||||
temperature: float = 0.7,
|
||||
num_iterations: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""使用 LLM 生成初始关键词数据集
|
||||
|
||||
根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
|
||||
|
||||
Args:
|
||||
persona_info: 人格信息
|
||||
temperature: 生成温度(默认0.7,较高温度增加多样性)
|
||||
num_iterations: 重复生成次数(默认3次)
|
||||
|
||||
Returns:
|
||||
初始数据集列表,每个元素包含 {"message_text": str, "label": int}
|
||||
"""
|
||||
if not self.model_client:
|
||||
await self.initialize()
|
||||
|
||||
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}次")
|
||||
|
||||
# 构造人格描述
|
||||
persona_desc = self._format_persona_info(persona_info)
|
||||
|
||||
# 构造提示词
|
||||
prompt = self.KEYWORD_GENERATION_PROMPT.format(
|
||||
persona_info=persona_desc,
|
||||
)
|
||||
|
||||
all_keywords_data = []
|
||||
|
||||
# 重复生成多次
|
||||
for iteration in range(num_iterations):
|
||||
try:
|
||||
if not self.model_client:
|
||||
logger.warning("LLM 客户端未初始化,跳过关键词生成")
|
||||
break
|
||||
|
||||
logger.info(f"第 {iteration + 1}/{num_iterations} 次生成关键词...")
|
||||
|
||||
# 调用 LLM(使用较高温度)
|
||||
response = await self.model_client.generate_response_async(
|
||||
prompt=prompt,
|
||||
max_tokens=1000, # 关键词列表需要较多token
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# 解析响应(generate_response_async 返回元组)
|
||||
response_text = response[0] if isinstance(response, tuple) else response
|
||||
keywords_data = self._parse_keywords_response(response_text)
|
||||
|
||||
if keywords_data:
|
||||
interested = keywords_data.get("interested", [])
|
||||
not_interested = keywords_data.get("not_interested", [])
|
||||
|
||||
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
|
||||
|
||||
# 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
|
||||
for keyword in interested:
|
||||
if keyword and keyword.strip():
|
||||
all_keywords_data.append({
|
||||
"message_text": keyword.strip(),
|
||||
"label": 1,
|
||||
"source": "llm_generated_initial",
|
||||
"iteration": iteration + 1,
|
||||
})
|
||||
|
||||
for keyword in not_interested:
|
||||
if keyword and keyword.strip():
|
||||
all_keywords_data.append({
|
||||
"message_text": keyword.strip(),
|
||||
"label": -1,
|
||||
"source": "llm_generated_initial",
|
||||
"iteration": iteration + 1,
|
||||
})
|
||||
else:
|
||||
logger.warning(f"第 {iteration + 1} 次生成失败,未能解析关键词")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第 {iteration + 1} 次关键词生成失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
|
||||
|
||||
# 统计标签分布
|
||||
label_counts = {}
|
||||
for item in all_keywords_data:
|
||||
label = item["label"]
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
logger.info(f"标签分布: {label_counts}")
|
||||
|
||||
return all_keywords_data
|
||||
|
||||
def _parse_keywords_response(self, response: str) -> dict | None:
|
||||
"""解析关键词生成的JSON响应
|
||||
|
||||
Args:
|
||||
response: LLM 响应文本
|
||||
|
||||
Returns:
|
||||
解析后的字典,包含 interested 和 not_interested 列表
|
||||
"""
|
||||
try:
|
||||
# 提取JSON部分(去除markdown代码块标记)
|
||||
response = response.strip()
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0].strip()
|
||||
|
||||
# 解析JSON
|
||||
import json_repair
|
||||
response = json_repair.repair_json(response)
|
||||
data = json.loads(response)
|
||||
|
||||
# 验证格式
|
||||
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
|
||||
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
|
||||
return data
|
||||
|
||||
logger.warning(f"关键词响应格式不正确: {data}")
|
||||
return None
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析关键词JSON失败: {e}")
|
||||
logger.debug(f"响应内容: {response}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"解析关键词响应失败: {e}")
|
||||
return None
|
||||
|
||||
async def annotate_message(
|
||||
self,
|
||||
message_text: str,
|
||||
persona_info: dict[str, Any],
|
||||
) -> int:
|
||||
"""使用 LLM 标注单条消息
|
||||
|
||||
Args:
|
||||
message_text: 消息文本
|
||||
persona_info: 人格信息
|
||||
|
||||
Returns:
|
||||
标签 (-1, 0, 1)
|
||||
"""
|
||||
if not self.model_client:
|
||||
await self.initialize()
|
||||
|
||||
# 构造人格描述
|
||||
persona_desc = self._format_persona_info(persona_info)
|
||||
|
||||
# 构造提示词
|
||||
prompt = self.ANNOTATION_PROMPT.format(
|
||||
persona_info=persona_desc,
|
||||
message_text=message_text,
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.model_client:
|
||||
logger.warning("LLM 客户端未初始化,返回默认标签")
|
||||
return 0
|
||||
|
||||
# 调用 LLM
|
||||
response = await self.model_client.generate_response_async(
|
||||
prompt=prompt,
|
||||
max_tokens=10,
|
||||
temperature=0.1, # 低温度保证一致性
|
||||
)
|
||||
|
||||
# 解析响应(generate_response_async 返回元组)
|
||||
response_text = response[0] if isinstance(response, tuple) else response
|
||||
label = self._parse_label(response_text)
|
||||
return label
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 标注失败: {e}")
|
||||
return 0 # 默认返回中立
|
||||
|
||||
async def annotate_batch(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
persona_info: dict[str, Any],
|
||||
save_path: Path | None = None,
|
||||
batch_size: int = 50,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""批量标注消息(真正的批量模式)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
persona_info: 人格信息
|
||||
save_path: 保存路径(可选)
|
||||
batch_size: 每次LLM请求处理的消息数(默认20)
|
||||
|
||||
Returns:
|
||||
标注后的数据集
|
||||
"""
|
||||
logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size} 条")
|
||||
|
||||
annotated_data = []
|
||||
|
||||
for i in range(0, len(messages), batch_size):
|
||||
batch = messages[i : i + batch_size]
|
||||
|
||||
# 批量标注(一次LLM请求处理多条消息)
|
||||
labels = await self._annotate_batch_llm(batch, persona_info)
|
||||
|
||||
# 保存结果
|
||||
for msg, label in zip(batch, labels):
|
||||
annotated_data.append({
|
||||
"message_id": msg["message_id"],
|
||||
"message_text": msg["message_text"],
|
||||
"label": label,
|
||||
"user_id": msg.get("user_id"),
|
||||
"chat_id": msg.get("chat_id"),
|
||||
"timestamp": msg.get("timestamp"),
|
||||
})
|
||||
|
||||
logger.info(f"已标注 {len(annotated_data)}/{len(messages)} 条")
|
||||
|
||||
# 统计标签分布
|
||||
label_counts = {}
|
||||
for item in annotated_data:
|
||||
label = item["label"]
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
|
||||
logger.info(f"标注完成,标签分布: {label_counts}")
|
||||
|
||||
# 保存到文件
|
||||
if save_path:
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotated_data, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"数据集已保存到: {save_path}")
|
||||
|
||||
return annotated_data
|
||||
|
||||
async def _annotate_batch_llm(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
persona_info: dict[str, Any],
|
||||
) -> list[int]:
|
||||
"""使用一次LLM请求标注多条消息
|
||||
|
||||
Args:
|
||||
messages: 消息列表(通常20条)
|
||||
persona_info: 人格信息
|
||||
|
||||
Returns:
|
||||
标签列表
|
||||
"""
|
||||
if not self.model_client:
|
||||
logger.warning("LLM 客户端未初始化,返回默认标签")
|
||||
return [0] * len(messages)
|
||||
|
||||
# 构造人格描述
|
||||
persona_desc = self._format_persona_info(persona_info)
|
||||
|
||||
# 构造消息列表
|
||||
messages_list = ""
|
||||
for idx, msg in enumerate(messages, 1):
|
||||
messages_list += f"{idx}. {msg['message_text']}\n"
|
||||
|
||||
# 构造示例输出
|
||||
example_output = json.dumps(
|
||||
{str(i): 0 for i in range(1, len(messages) + 1)},
|
||||
ensure_ascii=False,
|
||||
indent=2
|
||||
)
|
||||
|
||||
# 构造提示词
|
||||
prompt = self.BATCH_ANNOTATION_PROMPT.format(
|
||||
persona_info=persona_desc,
|
||||
messages_list=messages_list,
|
||||
example_output=example_output,
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用 LLM(使用更大的token限制)
|
||||
response = await self.model_client.generate_response_async(
|
||||
prompt=prompt,
|
||||
max_tokens=500, # 批量标注需要更多token
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# 解析批量响应(generate_response_async 返回元组)
|
||||
response_text = response[0] if isinstance(response, tuple) else response
|
||||
labels = self._parse_batch_labels(response_text, len(messages))
|
||||
return labels
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量LLM标注失败: {e},返回默认值")
|
||||
return [0] * len(messages)
|
||||
|
||||
def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
|
||||
"""格式化人格信息
|
||||
|
||||
Args:
|
||||
persona_info: 人格信息字典
|
||||
|
||||
Returns:
|
||||
格式化后的人格描述
|
||||
"""
|
||||
def _stringify(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return "、".join([str(v) for v in value if v is not None and str(v).strip()])
|
||||
if isinstance(value, dict):
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True)
|
||||
except Exception:
|
||||
return str(value)
|
||||
return str(value).strip()
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
name = _stringify(persona_info.get("name"))
|
||||
if name:
|
||||
parts.append(f"角色名称: {name}")
|
||||
|
||||
# 核心/侧面/身份等完整人设信息
|
||||
personality_core = _stringify(persona_info.get("personality_core"))
|
||||
if personality_core:
|
||||
parts.append(f"核心人设: {personality_core}")
|
||||
|
||||
personality_side = _stringify(persona_info.get("personality_side"))
|
||||
if personality_side:
|
||||
parts.append(f"侧面特质: {personality_side}")
|
||||
|
||||
identity = _stringify(persona_info.get("identity"))
|
||||
if identity:
|
||||
parts.append(f"身份特征: {identity}")
|
||||
|
||||
# 追加其他未覆盖字段(保持信息完整)
|
||||
known_keys = {
|
||||
"name",
|
||||
"personality_core",
|
||||
"personality_side",
|
||||
"identity",
|
||||
}
|
||||
for key, value in persona_info.items():
|
||||
if key in known_keys:
|
||||
continue
|
||||
value_str = _stringify(value)
|
||||
if value_str:
|
||||
parts.append(f"{key}: {value_str}")
|
||||
|
||||
return "\n".join(parts) if parts else "无特定人格设定"
|
||||
|
||||
def _parse_label(self, response: str) -> int:
|
||||
"""解析 LLM 响应为标签
|
||||
|
||||
Args:
|
||||
response: LLM 响应文本
|
||||
|
||||
Returns:
|
||||
标签 (-1, 0, 1)
|
||||
"""
|
||||
# 部分 LLM 客户端可能返回 (text, meta) 的 tuple,这里取首元素并转为字符串
|
||||
if isinstance(response, (tuple, list)):
|
||||
response = response[0] if response else ""
|
||||
response = str(response).strip()
|
||||
|
||||
# 尝试直接解析数字
|
||||
if response in ["-1", "0", "1"]:
|
||||
return int(response)
|
||||
|
||||
# 尝试提取数字
|
||||
if "-1" in response:
|
||||
return -1
|
||||
elif "1" in response:
|
||||
return 1
|
||||
elif "0" in response:
|
||||
return 0
|
||||
|
||||
# 默认返回中立
|
||||
logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
|
||||
return 0
|
||||
|
||||
def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
|
||||
"""解析批量LLM响应为标签列表
|
||||
|
||||
Args:
|
||||
response: LLM 响应文本(JSON格式)
|
||||
expected_count: 期望的标签数量
|
||||
|
||||
Returns:
|
||||
标签列表
|
||||
"""
|
||||
try:
|
||||
# 兼容 tuple/list 返回格式
|
||||
if isinstance(response, (tuple, list)):
|
||||
response = response[0] if response else ""
|
||||
response = str(response)
|
||||
|
||||
# 提取JSON内容
|
||||
import re
|
||||
json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
# 尝试直接解析
|
||||
json_str = response
|
||||
import json_repair
|
||||
# 解析JSON
|
||||
labels_json = json_repair.repair_json(json_str)
|
||||
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
|
||||
|
||||
# 转换为列表
|
||||
labels = []
|
||||
for i in range(1, expected_count + 1):
|
||||
key = str(i)
|
||||
# 检查是否为字典且包含该键
|
||||
if isinstance(labels_dict, dict) and key in labels_dict:
|
||||
label = labels_dict[key]
|
||||
# 确保标签值有效
|
||||
if label in [-1, 0, 1]:
|
||||
labels.append(label)
|
||||
else:
|
||||
logger.warning(f"无效标签值 {label},使用默认值 0")
|
||||
labels.append(0)
|
||||
else:
|
||||
# 尝试从值列表或数组中顺序取值
|
||||
if isinstance(labels_dict, list) and len(labels_dict) >= i:
|
||||
label = labels_dict[i - 1]
|
||||
labels.append(label if label in [-1, 0, 1] else 0)
|
||||
else:
|
||||
labels.append(0)
|
||||
|
||||
if len(labels) != expected_count:
|
||||
logger.warning(
|
||||
f"标签数量不匹配:期望 {expected_count},实际 {len(labels)},"
|
||||
f"补齐为 {expected_count}"
|
||||
)
|
||||
# 补齐或截断
|
||||
if len(labels) < expected_count:
|
||||
labels.extend([0] * (expected_count - len(labels)))
|
||||
else:
|
||||
labels = labels[:expected_count]
|
||||
|
||||
return labels
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
|
||||
return [0] * expected_count
|
||||
except Exception as e:
|
||||
# 兜底:尝试直接提取所有标签数字
|
||||
try:
|
||||
import re
|
||||
numbers = re.findall(r"-?1|0", response)
|
||||
labels = [int(n) for n in numbers[:expected_count]]
|
||||
if len(labels) < expected_count:
|
||||
labels.extend([0] * (expected_count - len(labels)))
|
||||
return labels
|
||||
except Exception:
|
||||
logger.error(f"批量标签解析失败: {e}")
|
||||
return [0] * expected_count
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(path: Path) -> tuple[list[str], list[int]]:
|
||||
"""加载训练数据集
|
||||
|
||||
Args:
|
||||
path: 数据集文件路径
|
||||
|
||||
Returns:
|
||||
(文本列表, 标签列表)
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
texts = [item["message_text"] for item in data]
|
||||
labels = [item["label"] for item in data]
|
||||
|
||||
logger.info(f"加载数据集: {len(texts)} 条样本")
|
||||
return texts, labels
|
||||
|
||||
|
||||
async def generate_training_dataset(
|
||||
output_path: Path,
|
||||
persona_info: dict[str, Any],
|
||||
days: int = 7,
|
||||
max_samples: int = 1000,
|
||||
model_name: str | None = None,
|
||||
generate_initial_keywords: bool = True,
|
||||
keyword_temperature: float = 0.7,
|
||||
keyword_iterations: int = 3,
|
||||
) -> Path:
|
||||
"""生成训练数据集(主函数)
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径
|
||||
persona_info: 人格信息
|
||||
days: 采样最近 N 天的消息
|
||||
max_samples: 最大采样数
|
||||
model_name: LLM 模型名称
|
||||
generate_initial_keywords: 是否生成初始关键词数据集(默认True)
|
||||
keyword_temperature: 关键词生成温度(默认0.7)
|
||||
keyword_iterations: 关键词生成迭代次数(默认3)
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
generator = DatasetGenerator(model_name=model_name)
|
||||
await generator.initialize()
|
||||
|
||||
# 第一步:生成初始关键词数据集(如果启用)
|
||||
initial_keywords_data = []
|
||||
if generate_initial_keywords:
|
||||
logger.info("=" * 60)
|
||||
logger.info("步骤 1/3: 生成初始关键词数据集")
|
||||
logger.info("=" * 60)
|
||||
initial_keywords_data = await generator.generate_initial_keywords(
|
||||
persona_info=persona_info,
|
||||
temperature=keyword_temperature,
|
||||
num_iterations=keyword_iterations,
|
||||
)
|
||||
logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)} 条")
|
||||
else:
|
||||
logger.info("跳过初始关键词生成")
|
||||
|
||||
# 第二步:采样真实消息
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"步骤 2/3: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
|
||||
logger.info("=" * 60)
|
||||
messages = await generator.sample_messages(
|
||||
days=days,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
logger.info(f"✓ 消息采样完成: {len(messages)} 条")
|
||||
|
||||
# 第三步:批量标注真实消息
|
||||
logger.info("=" * 60)
|
||||
logger.info("步骤 3/3: LLM 标注真实消息")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 注意:不保存到文件,返回标注后的数据
|
||||
annotated_messages = await generator.annotate_batch(
|
||||
messages=messages,
|
||||
persona_info=persona_info,
|
||||
save_path=None, # 暂不保存
|
||||
)
|
||||
logger.info(f"✓ 消息标注完成: {len(annotated_messages)} 条")
|
||||
|
||||
# 第四步:合并数据集
|
||||
logger.info("=" * 60)
|
||||
logger.info("步骤 4/4: 合并数据集")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 合并初始关键词和标注后的消息(不去重,保持所有重复项)
|
||||
combined_dataset = []
|
||||
|
||||
# 添加初始关键词数据
|
||||
if initial_keywords_data:
|
||||
combined_dataset.extend(initial_keywords_data)
|
||||
logger.info(f" + 初始关键词: {len(initial_keywords_data)} 条")
|
||||
|
||||
# 添加标注后的消息
|
||||
combined_dataset.extend(annotated_messages)
|
||||
logger.info(f" + 标注消息: {len(annotated_messages)} 条")
|
||||
|
||||
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
|
||||
|
||||
# 统计标签分布
|
||||
label_counts = {}
|
||||
for item in combined_dataset:
|
||||
label = item.get("label", 0)
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
logger.info(f" 最终标签分布: {label_counts}")
|
||||
|
||||
# 保存合并后的数据集
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"✓ 训练数据集已保存: {output_path}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
return output_path
|
||||
|
||||
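Below is a minimal, illustrative sketch of driving this dataset pipeline end to end. The persona fields and the output path are placeholders, and running it assumes the bot's database and LLM provider are configured, since sampling and annotation hit both:

import asyncio
from pathlib import Path

from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset


async def build_dataset():
    # The concrete persona keys are up to the caller; these values are placeholders.
    persona_info = {"name": "ExampleBot", "interests": ["科技", "游戏"]}
    path = await generate_training_dataset(
        output_path=Path("data/semantic_interest/datasets/example.json"),
        persona_info=persona_info,
        days=7,
        max_samples=500,
    )
    # The saved file is a JSON array of items with "message_text" and "label" (-1/0/1).
    texts, labels = DatasetGenerator.load_dataset(path)
    print(len(texts), labels[:5])


asyncio.run(build_dataset())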
147
src/chat/semantic_interest/features_tfidf.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""TF-IDF 特征向量化器
|
||||
|
||||
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.features")
|
||||
|
||||
|
||||
class TfidfFeatureExtractor:
|
||||
"""TF-IDF 特征提取器
|
||||
|
||||
使用字符级 n-gram 策略,适合中文/多语言场景
|
||||
|
||||
优化说明(2024.12):
|
||||
- max_features 从 20000 降到 10000,减少计算量
|
||||
- ngram_range 默认 (2, 3),对于兴趣任务足够
|
||||
- min_df 提高到 3,过滤低频噪声
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
analyzer: str = "char", # type: ignore
|
||||
ngram_range: tuple[int, int] = (2, 4), # 优化:缩小 n-gram 范围
|
||||
max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半
|
||||
min_df: int = 3, # 优化:过滤低频 n-gram
|
||||
max_df: float = 0.95,
|
||||
):
|
||||
"""初始化特征提取器
|
||||
|
||||
Args:
|
||||
analyzer: 分析器类型 ('char' 或 'word')
|
||||
ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram
|
||||
max_features: 词表最大大小,防止特征爆炸
|
||||
min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表
|
||||
max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词)
|
||||
"""
|
||||
self.vectorizer = TfidfVectorizer(
|
||||
analyzer=analyzer,
|
||||
ngram_range=ngram_range,
|
||||
max_features=max_features,
|
||||
min_df=min_df,
|
||||
max_df=max_df,
|
||||
lowercase=True,
|
||||
strip_accents=None, # 保留中文字符
|
||||
sublinear_tf=True, # 使用对数 TF 缩放
|
||||
norm="l2", # L2 归一化
|
||||
)
|
||||
self.is_fitted = False
|
||||
|
||||
logger.info(
|
||||
f"TF-IDF 特征提取器初始化: analyzer={analyzer}, "
|
||||
f"ngram_range={ngram_range}, max_features={max_features}"
|
||||
)
|
||||
|
||||
def fit(self, texts: list[str]) -> "TfidfFeatureExtractor":
|
||||
"""训练向量化器
|
||||
|
||||
Args:
|
||||
texts: 训练文本列表
|
||||
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
|
||||
self.vectorizer.fit(texts)
|
||||
self.is_fitted = True
|
||||
|
||||
vocab_size = len(self.vectorizer.vocabulary_)
|
||||
logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, texts: list[str]):
|
||||
"""将文本转换为 TF-IDF 向量
|
||||
|
||||
Args:
|
||||
texts: 待转换文本列表
|
||||
|
||||
Returns:
|
||||
稀疏矩阵
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("向量化器尚未训练,请先调用 fit() 方法")
|
||||
|
||||
return self.vectorizer.transform(texts)
|
||||
|
||||
def fit_transform(self, texts: list[str]):
|
||||
"""训练并转换文本
|
||||
|
||||
Args:
|
||||
texts: 训练文本列表
|
||||
|
||||
Returns:
|
||||
稀疏矩阵
|
||||
"""
|
||||
logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
|
||||
result = self.vectorizer.fit_transform(texts)
|
||||
self.is_fitted = True
|
||||
|
||||
vocab_size = len(self.vectorizer.vocabulary_)
|
||||
logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")
|
||||
|
||||
return result
|
||||
|
||||
def get_feature_names(self) -> list[str]:
|
||||
"""获取特征名称列表
|
||||
|
||||
Returns:
|
||||
特征名称列表
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("向量化器尚未训练")
|
||||
|
||||
return self.vectorizer.get_feature_names_out().tolist()
|
||||
|
||||
def get_vocabulary_size(self) -> int:
|
||||
"""获取词表大小
|
||||
|
||||
Returns:
|
||||
词表大小
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
return 0
|
||||
return len(self.vectorizer.vocabulary_)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
"""获取配置信息
|
||||
|
||||
Returns:
|
||||
配置字典
|
||||
"""
|
||||
params = self.vectorizer.get_params()
|
||||
return {
|
||||
"analyzer": params["analyzer"],
|
||||
"ngram_range": params["ngram_range"],
|
||||
"max_features": params["max_features"],
|
||||
"min_df": params["min_df"],
|
||||
"max_df": params["max_df"],
|
||||
"vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0,
|
||||
"is_fitted": self.is_fitted,
|
||||
}
|
||||
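A small usage sketch of the extractor above. The sample texts are made up, and min_df is relaxed only so the toy vocabulary is non-empty:

from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor

texts = ["今天天气真好", "我喜欢打游戏", "这个话题没什么意思"]  # toy samples
extractor = TfidfFeatureExtractor(ngram_range=(2, 3), max_features=5000, min_df=1)
X = extractor.fit_transform(texts)            # sparse matrix, one row per text
print(extractor.get_vocabulary_size(), X.shape)
new_X = extractor.transform(["周末去哪里玩"])  # reuses the fitted vocabulary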
263
src/chat/semantic_interest/model_lr.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Logistic Regression 模型训练与推理
|
||||
|
||||
使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1)
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
|
||||
logger = get_logger("semantic_interest.model")
|
||||
|
||||
|
||||
class SemanticInterestModel:
|
||||
"""语义兴趣度模型
|
||||
|
||||
使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
class_weight: str | dict | None = "balanced",
|
||||
max_iter: int = 1000,
|
||||
solver: str = "lbfgs", # type: ignore
|
||||
n_jobs: int = -1,
|
||||
):
|
||||
"""初始化模型
|
||||
|
||||
Args:
|
||||
class_weight: 类别权重配置
|
||||
- "balanced": 自动平衡类别权重
|
||||
- dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6}
|
||||
- None: 不使用权重
|
||||
max_iter: 最大迭代次数
|
||||
solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等)
|
||||
n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
|
||||
"""
|
||||
self.clf = LogisticRegression(
|
||||
solver=solver,
|
||||
max_iter=max_iter,
|
||||
class_weight=class_weight,
|
||||
n_jobs=n_jobs,
|
||||
random_state=42,
|
||||
)
|
||||
self.is_fitted = False
|
||||
self.label_mapping = {-1: 0, 0: 1, 1: 2} # 内部类别映射
|
||||
self.training_metrics = {}
|
||||
|
||||
logger.info(
|
||||
f"Logistic Regression 模型初始化: class_weight={class_weight}, "
|
||||
f"max_iter={max_iter}, solver={solver}"
|
||||
)
|
||||
|
||||
def train(
|
||||
self,
|
||||
X_train,
|
||||
y_train,
|
||||
X_val=None,
|
||||
y_val=None,
|
||||
verbose: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""训练模型
|
||||
|
||||
Args:
|
||||
X_train: 训练集特征矩阵
|
||||
y_train: 训练集标签(-1, 0, 1)
|
||||
X_val: 验证集特征矩阵(可选)
|
||||
y_val: 验证集标签(可选)
|
||||
verbose: 是否输出详细日志
|
||||
|
||||
Returns:
|
||||
训练指标字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"开始训练模型,训练样本数: {len(y_train)}")
|
||||
|
||||
# 训练模型
|
||||
self.clf.fit(X_train, y_train)
|
||||
self.is_fitted = True
|
||||
|
||||
training_time = time.time() - start_time
|
||||
logger.info(f"模型训练完成,耗时: {training_time:.2f}秒")
|
||||
|
||||
# 计算训练集指标
|
||||
y_train_pred = self.clf.predict(X_train)
|
||||
train_accuracy = (y_train_pred == y_train).mean()
|
||||
|
||||
metrics = {
|
||||
"training_time": training_time,
|
||||
"train_accuracy": train_accuracy,
|
||||
"train_samples": len(y_train),
|
||||
}
|
||||
|
||||
if verbose:
|
||||
logger.info(f"训练集准确率: {train_accuracy:.4f}")
|
||||
logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}")
|
||||
|
||||
# 如果提供了验证集,计算验证指标
|
||||
if X_val is not None and y_val is not None:
|
||||
val_metrics = self.evaluate(X_val, y_val, verbose=verbose)
|
||||
metrics.update(val_metrics)
|
||||
|
||||
self.training_metrics = metrics
|
||||
return metrics
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
X_test,
|
||||
y_test,
|
||||
verbose: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""评估模型
|
||||
|
||||
Args:
|
||||
X_test: 测试集特征矩阵
|
||||
y_test: 测试集标签
|
||||
verbose: 是否输出详细日志
|
||||
|
||||
Returns:
|
||||
评估指标字典
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("模型尚未训练")
|
||||
|
||||
y_pred = self.clf.predict(X_test)
|
||||
accuracy = (y_pred == y_test).mean()
|
||||
|
||||
metrics = {
|
||||
"test_accuracy": accuracy,
|
||||
"test_samples": len(y_test),
|
||||
}
|
||||
|
||||
if verbose:
|
||||
logger.info(f"测试集准确率: {accuracy:.4f}")
|
||||
logger.info("\n分类报告:")
|
||||
report = classification_report(
|
||||
y_test,
|
||||
y_pred,
|
||||
labels=[-1, 0, 1],
|
||||
target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"],
|
||||
zero_division=0,
|
||||
)
|
||||
logger.info(f"\n{report}")
|
||||
|
||||
logger.info("\n混淆矩阵:")
|
||||
cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1])
|
||||
logger.info(f"\n{cm}")
|
||||
|
||||
return metrics
|
||||
|
||||
def predict_proba(self, X) -> np.ndarray:
|
||||
"""预测概率分布
|
||||
|
||||
Args:
|
||||
X: 特征矩阵
|
||||
|
||||
Returns:
|
||||
概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("模型尚未训练")
|
||||
|
||||
proba = self.clf.predict_proba(X)
|
||||
|
||||
# 确保类别顺序为 [-1, 0, 1]
|
||||
classes = self.clf.classes_
|
||||
if not np.array_equal(classes, [-1, 0, 1]):
|
||||
# 需要重新排序
|
||||
sorted_proba = np.zeros_like(proba)
|
||||
for i, cls in enumerate([-1, 0, 1]):
|
||||
idx = np.where(classes == cls)[0]
|
||||
if len(idx) > 0:
|
||||
sorted_proba[:, i] = proba[:, idx[0]]
|
||||
return sorted_proba
|
||||
|
||||
return proba
|
||||
|
||||
def predict(self, X) -> np.ndarray:
|
||||
"""预测类别
|
||||
|
||||
Args:
|
||||
X: 特征矩阵
|
||||
|
||||
Returns:
|
||||
预测标签数组
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError("模型尚未训练")
|
||||
|
||||
return self.clf.predict(X)
|
||||
|
||||
def get_config(self) -> dict:
|
||||
"""获取模型配置
|
||||
|
||||
Returns:
|
||||
配置字典
|
||||
"""
|
||||
params = self.clf.get_params()
|
||||
return {
|
||||
"solver": params["solver"],
|
||||
"max_iter": params["max_iter"],
|
||||
"class_weight": params["class_weight"],
|
||||
"is_fitted": self.is_fitted,
|
||||
"classes": self.clf.classes_.tolist() if self.is_fitted else None,
|
||||
}
|
||||
|
||||
|
||||
def train_semantic_model(
|
||||
texts: list[str],
|
||||
labels: list[int],
|
||||
test_size: float = 0.1,
|
||||
random_state: int = 42,
|
||||
tfidf_config: dict | None = None,
|
||||
model_config: dict | None = None,
|
||||
) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]:
|
||||
"""训练完整的语义兴趣度模型
|
||||
|
||||
Args:
|
||||
texts: 消息文本列表
|
||||
labels: 对应的标签列表 (-1, 0, 1)
|
||||
test_size: 验证集比例
|
||||
random_state: 随机种子
|
||||
tfidf_config: TF-IDF 配置
|
||||
model_config: 模型配置
|
||||
|
||||
Returns:
|
||||
(特征提取器, 模型, 训练指标)
|
||||
"""
|
||||
logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}")
|
||||
|
||||
# 划分训练集和验证集
|
||||
X_train_texts, X_val_texts, y_train, y_val = train_test_split(
|
||||
texts,
|
||||
labels,
|
||||
test_size=test_size,
|
||||
stratify=labels,
|
||||
random_state=random_state,
|
||||
)
|
||||
|
||||
logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}")
|
||||
|
||||
# 初始化并训练 TF-IDF 向量化器
|
||||
tfidf_config = tfidf_config or {}
|
||||
feature_extractor = TfidfFeatureExtractor(**tfidf_config)
|
||||
X_train = feature_extractor.fit_transform(X_train_texts)
|
||||
X_val = feature_extractor.transform(X_val_texts)
|
||||
|
||||
# 初始化并训练模型
|
||||
model_config = model_config or {}
|
||||
model = SemanticInterestModel(**model_config)
|
||||
metrics = model.train(X_train, y_train, X_val, y_val)
|
||||
|
||||
logger.info("语义兴趣度模型训练完成")
|
||||
|
||||
return feature_extractor, model, metrics
|
||||
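An illustrative sketch of training with the helper above on a toy dataset. Real texts and labels would normally come from DatasetGenerator.load_dataset(); the relaxed tfidf_config is only there so the tiny vocabulary survives frequency filtering:

from src.chat.semantic_interest.model_lr import train_semantic_model

texts = ["好无聊", "今天打排位吗", "吃饭了没", "新显卡发布了", "别理我", "随便聊聊"] * 20
labels = [-1, 1, 0, 1, -1, 0] * 20
extractor, model, metrics = train_semantic_model(
    texts,
    labels,
    test_size=0.1,
    tfidf_config={"min_df": 1},  # relaxed for the toy vocabulary
)
print(metrics["train_accuracy"], metrics.get("test_accuracy"))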
641
src/chat/semantic_interest/optimized_scorer.py
Normal file
@@ -0,0 +1,641 @@
|
||||
"""优化的语义兴趣度评分器
|
||||
|
||||
实现关键优化:
|
||||
1. TF-IDF + LR 权重融合为 token→weight 字典
|
||||
2. 稀疏权重剪枝(只保留高贡献 token)
|
||||
3. 全局线程池 + 异步调度
|
||||
4. 批处理队列系统
|
||||
5. 绕过 sklearn 的纯 Python scorer
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("semantic_interest.optimized")
|
||||
|
||||
# ============================================================================
|
||||
# 全局线程池(避免每次创建新的 executor)
|
||||
# ============================================================================
|
||||
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
|
||||
_EXECUTOR_LOCK = asyncio.Lock()
|
||||
|
||||
def get_global_executor(max_workers: int = 4) -> ThreadPoolExecutor:
|
||||
"""获取全局线程池(单例)"""
|
||||
global _GLOBAL_EXECUTOR
|
||||
if _GLOBAL_EXECUTOR is None:
|
||||
_GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="semantic_scorer")
|
||||
logger.info(f"[优化评分器] 创建全局线程池,workers={max_workers}")
|
||||
return _GLOBAL_EXECUTOR
|
||||
|
||||
|
||||
def shutdown_global_executor():
|
||||
"""关闭全局线程池"""
|
||||
global _GLOBAL_EXECUTOR
|
||||
if _GLOBAL_EXECUTOR is not None:
|
||||
_GLOBAL_EXECUTOR.shutdown(wait=False)
|
||||
_GLOBAL_EXECUTOR = None
|
||||
logger.info("[优化评分器] 全局线程池已关闭")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 快速评分器(绕过 sklearn)
|
||||
# ============================================================================
|
||||
@dataclass
|
||||
class FastScorerConfig:
|
||||
"""快速评分器配置"""
|
||||
# n-gram 参数
|
||||
analyzer: str = "char"
|
||||
ngram_range: tuple[int, int] = (2, 4)
|
||||
lowercase: bool = True
|
||||
|
||||
# 权重剪枝阈值(绝对值小于此值的权重视为 0)
|
||||
weight_prune_threshold: float = 1e-4
|
||||
|
||||
# 只保留 top-k 权重(0 表示不限制)
|
||||
top_k_weights: int = 0
|
||||
|
||||
# sigmoid 缩放因子
|
||||
sigmoid_alpha: float = 1.0
|
||||
|
||||
# 评分超时(秒)
|
||||
score_timeout: float = 2.0
|
||||
|
||||
|
||||
class FastScorer:
|
||||
"""快速语义兴趣度评分器
|
||||
|
||||
将 TF-IDF + LR 融合成一个纯 Python 的 token→weight 字典 scorer。
|
||||
|
||||
核心公式:
|
||||
- TF-IDF: x_i = tf_i * idf_i
|
||||
- LR: z = Σ_i (w_i * x_i) + b = Σ_i (w_i * idf_i * tf_i) + b
|
||||
- 定义 w'_i = w_i * idf_i,则 z = Σ_i (w'_i * tf_i) + b
|
||||
|
||||
这样在线评分只需要:
|
||||
1. 手动做 n-gram tokenize
|
||||
2. 统计 tf
|
||||
3. 查表 w'_i,累加求和
|
||||
4. sigmoid 转 [0, 1]
|
||||
"""
|
||||
|
||||
def __init__(self, config: FastScorerConfig | None = None):
|
||||
"""初始化快速评分器"""
|
||||
self.config = config or FastScorerConfig()
|
||||
|
||||
# 融合后的权重字典: {token: combined_weight}
|
||||
# 对于三分类,我们计算 z_interest = z_pos - z_neg
|
||||
# 所以 combined_weight = (w_pos - w_neg) * idf
|
||||
self.token_weights: dict[str, float] = {}
|
||||
|
||||
# 偏置项: bias_pos - bias_neg
|
||||
self.bias: float = 0.0
|
||||
|
||||
# 元信息
|
||||
self.meta: dict[str, Any] = {}
|
||||
self.is_loaded = False
|
||||
|
||||
# 统计
|
||||
self.total_scores = 0
|
||||
self.total_time = 0.0
|
||||
|
||||
# n-gram 正则(预编译)
|
||||
self._tokenize_pattern = re.compile(r'\s+')
|
||||
|
||||
@classmethod
|
||||
def from_sklearn_model(
|
||||
cls,
|
||||
vectorizer, # TfidfVectorizer 或 TfidfFeatureExtractor
|
||||
model, # SemanticInterestModel 或 LogisticRegression
|
||||
config: FastScorerConfig | None = None,
|
||||
) -> "FastScorer":
|
||||
"""从 sklearn 模型创建快速评分器
|
||||
|
||||
Args:
|
||||
vectorizer: TF-IDF 向量化器
|
||||
model: Logistic Regression 模型
|
||||
config: 配置
|
||||
|
||||
Returns:
|
||||
FastScorer 实例
|
||||
"""
|
||||
scorer = cls(config)
|
||||
scorer._extract_weights(vectorizer, model)
|
||||
return scorer
|
||||
|
||||
def _extract_weights(self, vectorizer, model):
|
||||
"""从 sklearn 模型提取并融合权重
|
||||
|
||||
将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典
|
||||
"""
|
||||
# 获取底层 sklearn 对象
|
||||
if hasattr(vectorizer, 'vectorizer'):
|
||||
# TfidfFeatureExtractor 包装类
|
||||
tfidf = vectorizer.vectorizer
|
||||
else:
|
||||
tfidf = vectorizer
|
||||
|
||||
if hasattr(model, 'clf'):
|
||||
# SemanticInterestModel 包装类
|
||||
clf = model.clf
|
||||
else:
|
||||
clf = model
|
||||
|
||||
# 获取词表和 IDF
|
||||
vocabulary = tfidf.vocabulary_ # {token: index}
|
||||
idf = tfidf.idf_ # numpy array, shape (n_features,)
|
||||
|
||||
# 获取 LR 权重
|
||||
# clf.coef_ shape: (n_classes, n_features) 对于多分类
|
||||
# classes_ 顺序应该是 [-1, 0, 1]
|
||||
coef = clf.coef_ # shape (3, n_features)
|
||||
intercept = clf.intercept_ # shape (3,)
|
||||
classes = clf.classes_
|
||||
|
||||
# 找到 -1 和 1 的索引
|
||||
idx_neg = np.where(classes == -1)[0][0]
|
||||
idx_pos = np.where(classes == 1)[0][0]
|
||||
|
||||
# 计算 z_interest = z_pos - z_neg 的权重
|
||||
w_interest = coef[idx_pos] - coef[idx_neg] # shape (n_features,)
|
||||
b_interest = intercept[idx_pos] - intercept[idx_neg]
|
||||
|
||||
# 融合: combined_weight = w_interest * idf
|
||||
combined_weights = w_interest * idf
|
||||
|
||||
# 构建 token→weight 字典
|
||||
token_weights = {}
|
||||
for token, idx in vocabulary.items():
|
||||
weight = combined_weights[idx]
|
||||
# 权重剪枝
|
||||
if abs(weight) >= self.config.weight_prune_threshold:
|
||||
token_weights[token] = weight
|
||||
|
||||
# 如果设置了 top-k 限制
|
||||
if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights:
|
||||
# 按绝对值排序,保留 top-k
|
||||
sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True)
|
||||
token_weights = dict(sorted_items[:self.config.top_k_weights])
|
||||
|
||||
self.token_weights = token_weights
|
||||
self.bias = float(b_interest)
|
||||
self.is_loaded = True
|
||||
|
||||
# 更新元信息
|
||||
self.meta = {
|
||||
"original_vocab_size": len(vocabulary),
|
||||
"pruned_vocab_size": len(token_weights),
|
||||
"prune_ratio": 1 - len(token_weights) / len(vocabulary) if vocabulary else 0,
|
||||
"weight_prune_threshold": self.config.weight_prune_threshold,
|
||||
"top_k_weights": self.config.top_k_weights,
|
||||
"bias": self.bias,
|
||||
"ngram_range": self.config.ngram_range,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[FastScorer] 权重提取完成: "
|
||||
f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, "
|
||||
f"剪枝率={self.meta['prune_ratio']:.2%}"
|
||||
)
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
"""将文本转换为 n-gram tokens
|
||||
|
||||
与 sklearn 的 char n-gram 保持一致
|
||||
"""
|
||||
if self.config.lowercase:
|
||||
text = text.lower()
|
||||
|
||||
# 字符级 n-gram
|
||||
min_n, max_n = self.config.ngram_range
|
||||
tokens = []
|
||||
|
||||
for n in range(min_n, max_n + 1):
|
||||
for i in range(len(text) - n + 1):
|
||||
tokens.append(text[i:i + n])
|
||||
|
||||
return tokens
|
||||
|
||||
def _compute_tf(self, tokens: list[str]) -> dict[str, float]:
|
||||
"""计算词频(TF)
|
||||
|
||||
注意:sklearn 使用 sublinear_tf=True 时是 1 + log(tf)
|
||||
这里简化为原始计数,因为对于短消息差异不大
|
||||
"""
|
||||
return dict(Counter(tokens))
|
||||
|
||||
def score(self, text: str) -> float:
|
||||
"""计算单条消息的语义兴趣度
|
||||
|
||||
Args:
|
||||
text: 消息文本
|
||||
|
||||
Returns:
|
||||
兴趣分 [0.0, 1.0]
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. Tokenize
|
||||
tokens = self._tokenize(text)
|
||||
|
||||
if not tokens:
|
||||
return 0.5 # 空文本返回中立值
|
||||
|
||||
# 2. 计算 TF
|
||||
tf = self._compute_tf(tokens)
|
||||
|
||||
# 3. 加权求和: z = Σ (w'_i * tf_i) + b
|
||||
z = self.bias
|
||||
for token, count in tf.items():
|
||||
if token in self.token_weights:
|
||||
z += self.token_weights[token] * count
|
||||
|
||||
# 4. Sigmoid 转换
|
||||
# interest = 1 / (1 + exp(-α * z))
|
||||
alpha = self.config.sigmoid_alpha
|
||||
try:
|
||||
interest = 1.0 / (1.0 + math.exp(-alpha * z))
|
||||
except OverflowError:
|
||||
interest = 0.0 if z < 0 else 1.0
|
||||
|
||||
# 统计
|
||||
self.total_scores += 1
|
||||
self.total_time += time.time() - start_time
|
||||
|
||||
return interest
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}")
|
||||
return 0.5
|
||||
|
||||
def score_batch(self, texts: list[str]) -> list[float]:
|
||||
"""批量计算兴趣度"""
|
||||
if not texts:
|
||||
return []
|
||||
return [self.score(text) for text in texts]
|
||||
|
||||
async def score_async(self, text: str, timeout: float | None = None) -> float:
|
||||
"""异步计算兴趣度(使用全局线程池)"""
|
||||
timeout = timeout or self.config.score_timeout
|
||||
executor = get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score, text),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...")
|
||||
return 0.5
|
||||
|
||||
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
|
||||
"""异步批量计算兴趣度"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
timeout = timeout or self.config.score_timeout * len(texts)
|
||||
executor = get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score_batch, texts),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
|
||||
return {
|
||||
"is_loaded": self.is_loaded,
|
||||
"total_scores": self.total_scores,
|
||||
"total_time": self.total_time,
|
||||
"avg_score_time_ms": avg_time * 1000,
|
||||
"vocab_size": len(self.token_weights),
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
def save(self, path: Path | str):
|
||||
"""保存快速评分器"""
|
||||
import joblib
|
||||
path = Path(path)
|
||||
|
||||
bundle = {
|
||||
"token_weights": self.token_weights,
|
||||
"bias": self.bias,
|
||||
"config": {
|
||||
"analyzer": self.config.analyzer,
|
||||
"ngram_range": self.config.ngram_range,
|
||||
"lowercase": self.config.lowercase,
|
||||
"weight_prune_threshold": self.config.weight_prune_threshold,
|
||||
"top_k_weights": self.config.top_k_weights,
|
||||
"sigmoid_alpha": self.config.sigmoid_alpha,
|
||||
"score_timeout": self.config.score_timeout,
|
||||
},
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
joblib.dump(bundle, path)
|
||||
logger.info(f"[FastScorer] 已保存到: {path}")
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path | str) -> "FastScorer":
|
||||
"""加载快速评分器"""
|
||||
import joblib
|
||||
path = Path(path)
|
||||
|
||||
bundle = joblib.load(path)
|
||||
|
||||
config = FastScorerConfig(**bundle["config"])
|
||||
scorer = cls(config)
|
||||
scorer.token_weights = bundle["token_weights"]
|
||||
scorer.bias = bundle["bias"]
|
||||
scorer.meta = bundle.get("meta", {})
|
||||
scorer.is_loaded = True
|
||||
|
||||
logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}")
|
||||
return scorer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 批处理评分队列
|
||||
# ============================================================================
|
||||
@dataclass
|
||||
class ScoringRequest:
|
||||
"""评分请求"""
|
||||
text: str
|
||||
future: asyncio.Future
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class BatchScoringQueue:
|
||||
"""批处理评分队列
|
||||
|
||||
攒一小撮消息一起算,提高 CPU 利用率
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorer: FastScorer,
|
||||
batch_size: int = 16,
|
||||
flush_interval_ms: float = 50.0,
|
||||
):
|
||||
"""初始化批处理队列
|
||||
|
||||
Args:
|
||||
scorer: 评分器实例
|
||||
batch_size: 批次大小,达到后立即处理
|
||||
flush_interval_ms: 刷新间隔(毫秒),超过后强制处理
|
||||
"""
|
||||
self.scorer = scorer
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval_ms / 1000.0
|
||||
|
||||
self._pending: list[ScoringRequest] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
# 统计
|
||||
self.total_batches = 0
|
||||
self.total_requests = 0
|
||||
|
||||
async def start(self):
|
||||
"""启动批处理队列"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
logger.info(f"[BatchQueue] 启动,batch_size={self.batch_size}, interval={self.flush_interval*1000}ms")
|
||||
|
||||
async def stop(self):
|
||||
"""停止批处理队列"""
|
||||
self._running = False
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 处理剩余请求
|
||||
await self._flush()
|
||||
logger.info("[BatchQueue] 已停止")
|
||||
|
||||
async def score(self, text: str) -> float:
|
||||
"""提交评分请求并等待结果
|
||||
|
||||
Args:
|
||||
text: 消息文本
|
||||
|
||||
Returns:
|
||||
兴趣分
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
|
||||
request = ScoringRequest(text=text, future=future)
|
||||
|
||||
async with self._lock:
|
||||
self._pending.append(request)
|
||||
self.total_requests += 1
|
||||
|
||||
# 达到批次大小,立即处理
|
||||
if len(self._pending) >= self.batch_size:
|
||||
asyncio.create_task(self._flush())
|
||||
|
||||
return await future
|
||||
|
||||
async def _flush_loop(self):
|
||||
"""定时刷新循环"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self._flush()
|
||||
|
||||
async def _flush(self):
|
||||
"""处理当前待处理的请求"""
|
||||
async with self._lock:
|
||||
if not self._pending:
|
||||
return
|
||||
|
||||
batch = self._pending.copy()
|
||||
self._pending.clear()
|
||||
|
||||
if not batch:
|
||||
return
|
||||
|
||||
self.total_batches += 1
|
||||
|
||||
try:
|
||||
# 批量评分
|
||||
texts = [req.text for req in batch]
|
||||
scores = await self.scorer.score_batch_async(texts)
|
||||
|
||||
# 分发结果
|
||||
for req, score in zip(batch, scores):
|
||||
if not req.future.done():
|
||||
req.future.set_result(score)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BatchQueue] 批量评分失败: {e}")
|
||||
# 返回默认值
|
||||
for req in batch:
|
||||
if not req.future.done():
|
||||
req.future.set_result(0.5)
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
|
||||
return {
|
||||
"total_batches": self.total_batches,
|
||||
"total_requests": self.total_requests,
|
||||
"avg_batch_size": avg_batch_size,
|
||||
"pending_count": len(self._pending),
|
||||
"batch_size": self.batch_size,
|
||||
"flush_interval_ms": self.flush_interval * 1000,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 优化评分器工厂
|
||||
# ============================================================================
|
||||
_fast_scorer_instances: dict[str, FastScorer] = {}
|
||||
_batch_queue_instances: dict[str, BatchScoringQueue] = {}
|
||||
|
||||
|
||||
async def get_fast_scorer(
|
||||
model_path: str | Path,
|
||||
use_batch_queue: bool = False,
|
||||
batch_size: int = 16,
|
||||
flush_interval_ms: float = 50.0,
|
||||
force_reload: bool = False,
|
||||
) -> FastScorer | BatchScoringQueue:
|
||||
"""获取快速评分器实例(单例)
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径(.pkl 格式,可以是 sklearn 模型或 FastScorer 保存的)
|
||||
use_batch_queue: 是否使用批处理队列
|
||||
batch_size: 批处理大小
|
||||
flush_interval_ms: 批处理刷新间隔(毫秒)
|
||||
force_reload: 是否强制重新加载
|
||||
|
||||
Returns:
|
||||
FastScorer 或 BatchScoringQueue 实例
|
||||
"""
|
||||
import joblib
|
||||
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve())
|
||||
|
||||
# 检查是否已存在
|
||||
if not force_reload:
|
||||
if use_batch_queue and path_key in _batch_queue_instances:
|
||||
return _batch_queue_instances[path_key]
|
||||
elif not use_batch_queue and path_key in _fast_scorer_instances:
|
||||
return _fast_scorer_instances[path_key]
|
||||
|
||||
# 加载模型
|
||||
logger.info(f"[优化评分器] 加载模型: {model_path}")
|
||||
|
||||
bundle = joblib.load(model_path)
|
||||
|
||||
# 检查是 FastScorer 还是 sklearn 模型
|
||||
if "token_weights" in bundle:
|
||||
# FastScorer 格式
|
||||
scorer = FastScorer.load(model_path)
|
||||
else:
|
||||
# sklearn 模型格式,需要转换
|
||||
vectorizer = bundle["vectorizer"]
|
||||
model = bundle["model"]
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
|
||||
|
||||
_fast_scorer_instances[path_key] = scorer
|
||||
|
||||
# 如果需要批处理队列
|
||||
if use_batch_queue:
|
||||
queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms)
|
||||
await queue.start()
|
||||
_batch_queue_instances[path_key] = queue
|
||||
return queue
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
def convert_sklearn_to_fast(
|
||||
sklearn_model_path: str | Path,
|
||||
output_path: str | Path | None = None,
|
||||
config: FastScorerConfig | None = None,
|
||||
) -> FastScorer:
|
||||
"""将 sklearn 模型转换为 FastScorer 格式
|
||||
|
||||
Args:
|
||||
sklearn_model_path: sklearn 模型路径
|
||||
output_path: 输出路径(可选)
|
||||
config: FastScorer 配置
|
||||
|
||||
Returns:
|
||||
FastScorer 实例
|
||||
"""
|
||||
import joblib
|
||||
|
||||
sklearn_model_path = Path(sklearn_model_path)
|
||||
bundle = joblib.load(sklearn_model_path)
|
||||
|
||||
vectorizer = bundle["vectorizer"]
|
||||
model = bundle["model"]
|
||||
|
||||
# 从 vectorizer 配置推断 n-gram range
|
||||
if config is None:
|
||||
vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
|
||||
config = FastScorerConfig(
|
||||
ngram_range=vconfig.get("ngram_range", (2, 4)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
|
||||
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
|
||||
|
||||
# 保存转换后的模型
|
||||
if output_path:
|
||||
output_path = Path(output_path)
|
||||
scorer.save(output_path)
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
def clear_fast_scorer_instances():
|
||||
"""清空所有快速评分器实例"""
|
||||
global _fast_scorer_instances, _batch_queue_instances
|
||||
|
||||
# 停止所有批处理队列
|
||||
for queue in _batch_queue_instances.values():
|
||||
asyncio.create_task(queue.stop())
|
||||
|
||||
_fast_scorer_instances.clear()
|
||||
_batch_queue_instances.clear()
|
||||
|
||||
logger.info("[优化评分器] 已清空所有实例")
|
||||
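A sketch of loading and using the fast path. The model path is a placeholder and must point at an existing bundle (either a FastScorer dump or a sklearn vectorizer/model bundle):

import asyncio

from src.chat.semantic_interest.optimized_scorer import get_fast_scorer


async def demo():
    # Placeholder path; with use_batch_queue=False this returns a FastScorer instance.
    scorer = await get_fast_scorer(
        "data/semantic_interest/models/semantic_interest_latest.pkl",
        use_batch_queue=False,
    )
    score = await scorer.score_async("新显卡的评测出来了")
    print(score)  # value in [0.0, 1.0]

    # Optional offline step (writes a pruned token→weight bundle):
    # convert_sklearn_to_fast("models/semantic_interest_v1.pkl", "models/fast_v1.pkl")


asyncio.run(demo())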
744
src/chat/semantic_interest/runtime_scorer.py
Normal file
@@ -0,0 +1,744 @@
|
||||
"""运行时语义兴趣度评分器
|
||||
|
||||
在线推理时使用,提供快速的兴趣度评分
|
||||
支持异步加载、超时保护、批量优化、模型预热
|
||||
|
||||
2024.12 优化更新:
|
||||
- 新增 FastScorer 模式,绕过 sklearn 直接使用 token→weight 字典
|
||||
- 全局线程池避免每次创建新的 executor
|
||||
- 可选的批处理队列模式
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
|
||||
from src.chat.semantic_interest.model_lr import SemanticInterestModel
|
||||
|
||||
logger = get_logger("semantic_interest.scorer")
|
||||
|
||||
# 全局配置
|
||||
DEFAULT_SCORE_TIMEOUT = 2.0 # 评分超时(秒),从 5.0 降低到 2.0
|
||||
|
||||
# 全局线程池(避免每次创建新的 executor)
|
||||
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
|
||||
_EXECUTOR_MAX_WORKERS = 4
|
||||
|
||||
|
||||
def _get_global_executor() -> ThreadPoolExecutor:
|
||||
"""获取全局线程池(单例)"""
|
||||
global _GLOBAL_EXECUTOR
|
||||
if _GLOBAL_EXECUTOR is None:
|
||||
_GLOBAL_EXECUTOR = ThreadPoolExecutor(
|
||||
max_workers=_EXECUTOR_MAX_WORKERS,
|
||||
thread_name_prefix="semantic_scorer"
|
||||
)
|
||||
logger.info(f"[评分器] 创建全局线程池,workers={_EXECUTOR_MAX_WORKERS}")
|
||||
return _GLOBAL_EXECUTOR
|
||||
|
||||
|
||||
# 单例管理
|
||||
_scorer_instances: dict[str, "SemanticInterestScorer"] = {} # 模型路径 -> 评分器实例
|
||||
_instance_lock = asyncio.Lock() # 创建实例的锁
|
||||
|
||||
|
||||
class SemanticInterestScorer:
|
||||
"""语义兴趣度评分器
|
||||
|
||||
加载训练好的模型,在运行时快速计算消息的语义兴趣度
|
||||
优化特性:
|
||||
- 异步加载支持(非阻塞)
|
||||
- 批量评分优化
|
||||
- 超时保护
|
||||
- 模型预热
|
||||
- 全局线程池(避免重复创建 executor)
|
||||
- 可选的 FastScorer 模式(绕过 sklearn)
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str | Path, use_fast_scorer: bool = True):
|
||||
"""初始化评分器
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径 (.pkl)
|
||||
use_fast_scorer: 是否使用快速评分器模式(推荐)
|
||||
"""
|
||||
self.model_path = Path(model_path)
|
||||
self.vectorizer: TfidfFeatureExtractor | None = None
|
||||
self.model: SemanticInterestModel | None = None
|
||||
self.meta: dict[str, Any] = {}
|
||||
self.is_loaded = False
|
||||
|
||||
# 快速评分器模式
|
||||
self._use_fast_scorer = use_fast_scorer
|
||||
self._fast_scorer = None # FastScorer 实例
|
||||
|
||||
# 统计信息
|
||||
self.total_scores = 0
|
||||
self.total_time = 0.0
|
||||
|
||||
def load(self):
|
||||
"""同步加载模型(阻塞)"""
|
||||
if not self.model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
|
||||
|
||||
logger.info(f"开始加载模型: {self.model_path}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
bundle = joblib.load(self.model_path)
|
||||
|
||||
self.vectorizer = bundle["vectorizer"]
|
||||
self.model = bundle["model"]
|
||||
self.meta = bundle.get("meta", {})
|
||||
|
||||
# 如果启用快速评分器模式,创建 FastScorer
|
||||
if self._use_fast_scorer:
|
||||
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
self._fast_scorer = FastScorer.from_sklearn_model(
|
||||
self.vectorizer, self.model, config
|
||||
)
|
||||
logger.info(
|
||||
f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
|
||||
f"剪枝到 {len(self._fast_scorer.token_weights)}"
|
||||
)
|
||||
|
||||
self.is_loaded = True
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(
|
||||
f"模型加载成功,耗时: {load_time:.3f}秒, "
|
||||
f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore
|
||||
)
|
||||
|
||||
if self.meta:
|
||||
logger.info(f"模型元信息: {self.meta}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {e}")
|
||||
raise
|
||||
|
||||
async def load_async(self):
|
||||
"""异步加载模型(非阻塞)"""
|
||||
if not self.model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
|
||||
|
||||
logger.info(f"开始异步加载模型: {self.model_path}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 在全局线程池中执行 I/O 密集型操作
|
||||
executor = _get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
bundle = await loop.run_in_executor(executor, joblib.load, self.model_path)
|
||||
|
||||
self.vectorizer = bundle["vectorizer"]
|
||||
self.model = bundle["model"]
|
||||
self.meta = bundle.get("meta", {})
|
||||
|
||||
# 如果启用快速评分器模式,创建 FastScorer
|
||||
if self._use_fast_scorer:
|
||||
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
|
||||
|
||||
config = FastScorerConfig(
|
||||
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
|
||||
weight_prune_threshold=1e-4,
|
||||
)
|
||||
self._fast_scorer = FastScorer.from_sklearn_model(
|
||||
self.vectorizer, self.model, config
|
||||
)
|
||||
logger.info(
|
||||
f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
|
||||
f"剪枝到 {len(self._fast_scorer.token_weights)}"
|
||||
)
|
||||
|
||||
self.is_loaded = True
|
||||
load_time = time.time() - start_time
|
||||
|
||||
logger.info(
|
||||
f"模型异步加载成功,耗时: {load_time:.3f}秒, "
|
||||
f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore
|
||||
)
|
||||
|
||||
if self.meta:
|
||||
logger.info(f"模型元信息: {self.meta}")
|
||||
|
||||
# 预热模型
|
||||
await self._warmup_async()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型异步加载失败: {e}")
|
||||
raise
|
||||
|
||||
def reload(self):
|
||||
"""重新加载模型(热更新)"""
|
||||
logger.info("重新加载模型...")
|
||||
self.is_loaded = False
|
||||
self.load()
|
||||
|
||||
async def reload_async(self):
|
||||
"""异步重新加载模型"""
|
||||
logger.info("异步重新加载模型...")
|
||||
self.is_loaded = False
|
||||
await self.load_async()
|
||||
|
||||
def score(self, text: str) -> float:
|
||||
"""计算单条消息的语义兴趣度
|
||||
|
||||
Args:
|
||||
text: 消息文本
|
||||
|
||||
Returns:
|
||||
兴趣分 [0.0, 1.0],越高表示越感兴趣
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("模型尚未加载,请先调用 load() 或 load_async() 方法")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 优先使用 FastScorer(绕过 sklearn,更快)
|
||||
if self._fast_scorer is not None:
|
||||
interest = self._fast_scorer.score(text)
|
||||
else:
|
||||
# 回退到原始 sklearn 路径
|
||||
# 向量化
|
||||
X = self.vectorizer.transform([text])
|
||||
|
||||
# 预测概率
|
||||
proba = self.model.predict_proba(X)[0]
|
||||
|
||||
# proba 顺序为 [-1, 0, 1]
|
||||
p_neg, p_neu, p_pos = proba
|
||||
|
||||
# 兴趣分计算策略:
|
||||
# interest = P(1) + 0.5 * P(0)
|
||||
# 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
|
||||
# 确保在 [0, 1] 范围内
|
||||
interest = max(0.0, min(1.0, interest))
|
||||
|
||||
# 统计
|
||||
self.total_scores += 1
|
||||
self.total_time += time.time() - start_time
|
||||
|
||||
return interest
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}")
|
||||
return 0.5 # 默认返回中立值
|
||||
|
||||
async def score_async(self, text: str, timeout: float = DEFAULT_SCORE_TIMEOUT) -> float:
|
||||
"""异步计算兴趣度(带超时保护)
|
||||
|
||||
Args:
|
||||
text: 消息文本
|
||||
timeout: 超时时间(秒),超时返回中立值 0.5
|
||||
|
||||
Returns:
|
||||
兴趣分 [0.0, 1.0]
|
||||
"""
|
||||
# 使用全局线程池,避免每次创建新的 executor
|
||||
executor = _get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score, text),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"兴趣度计算超时({timeout}秒),消息: {text[:50]}")
|
||||
return 0.5 # 默认中立值
|
||||
|
||||
def score_batch(self, texts: list[str]) -> list[float]:
|
||||
"""批量计算兴趣度
|
||||
|
||||
Args:
|
||||
texts: 消息文本列表
|
||||
|
||||
Returns:
|
||||
兴趣分列表
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("模型尚未加载")
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 优先使用 FastScorer
|
||||
if self._fast_scorer is not None:
|
||||
interests = self._fast_scorer.score_batch(texts)
|
||||
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
return interests
|
||||
else:
|
||||
# 回退到原始 sklearn 路径
|
||||
# 批量向量化
|
||||
X = self.vectorizer.transform(texts)
|
||||
|
||||
# 批量预测
|
||||
proba = self.model.predict_proba(X)
|
||||
|
||||
# 计算兴趣分
|
||||
interests = []
|
||||
for p_neg, p_neu, p_pos in proba:
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
interest = max(0.0, min(1.0, interest))
|
||||
interests.append(interest)
|
||||
|
||||
# 统计
|
||||
self.total_scores += len(texts)
|
||||
self.total_time += time.time() - start_time
|
||||
|
||||
return interests
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量兴趣度计算失败: {e}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
|
||||
"""异步批量计算兴趣度
|
||||
|
||||
Args:
|
||||
texts: 消息文本列表
|
||||
timeout: 超时时间(秒),None 则使用单条超时*文本数
|
||||
|
||||
Returns:
|
||||
兴趣分列表
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# 计算动态超时
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_SCORE_TIMEOUT * len(texts)
|
||||
|
||||
# 使用全局线程池
|
||||
executor = _get_global_executor()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, self.score_batch, texts),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}")
|
||||
return [0.5] * len(texts)
|
||||
|
||||
def _warmup(self, sample_texts: list[str] | None = None):
|
||||
"""预热模型(执行几次推理以优化性能)
|
||||
|
||||
Args:
|
||||
sample_texts: 预热用的样本文本,None 则使用默认样本
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
return
|
||||
|
||||
if sample_texts is None:
|
||||
sample_texts = [
|
||||
"你好",
|
||||
"今天天气怎么样?",
|
||||
"我对这个话题很感兴趣"
|
||||
]
|
||||
|
||||
logger.debug(f"开始预热模型,样本数: {len(sample_texts)}")
|
||||
start_time = time.time()
|
||||
|
||||
for text in sample_texts:
|
||||
try:
|
||||
self.score(text)
|
||||
except Exception:
|
||||
pass # 忽略预热错误
|
||||
|
||||
warmup_time = time.time() - start_time
|
||||
logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}秒")
|
||||
|
||||
async def _warmup_async(self, sample_texts: list[str] | None = None):
|
||||
"""异步预热模型"""
|
||||
loop = asyncio.get_running_loop()
await loop.run_in_executor(_get_global_executor(), self._warmup, sample_texts)
|
||||
|
||||
def get_detailed_score(self, text: str) -> dict[str, Any]:
|
||||
"""获取详细的兴趣度评分信息
|
||||
|
||||
Args:
|
||||
text: 消息文本
|
||||
|
||||
Returns:
|
||||
包含概率分布和最终分数的详细信息
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise ValueError("模型尚未加载")
|
||||
|
||||
X = self.vectorizer.transform([text])
|
||||
proba = self.model.predict_proba(X)[0]
|
||||
pred_label = self.model.predict(X)[0]
|
||||
|
||||
p_neg, p_neu, p_pos = proba
|
||||
interest = float(p_pos + 0.5 * p_neu)
|
||||
|
||||
return {
|
||||
"interest_score": max(0.0, min(1.0, interest)),
|
||||
"proba_distribution": {
|
||||
"dislike": float(p_neg),
|
||||
"neutral": float(p_neu),
|
||||
"like": float(p_pos),
|
||||
},
|
||||
"predicted_label": int(pred_label),
|
||||
"text_preview": text[:100],
|
||||
}
|
||||
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取评分器统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
|
||||
|
||||
stats = {
|
||||
"is_loaded": self.is_loaded,
|
||||
"model_path": str(self.model_path),
|
||||
"total_scores": self.total_scores,
|
||||
"total_time": self.total_time,
|
||||
"avg_score_time": avg_time,
|
||||
"avg_score_time_ms": avg_time * 1000, # 毫秒单位更直观
|
||||
"vocabulary_size": (
|
||||
self.vectorizer.get_vocabulary_size()
|
||||
if self.vectorizer and self.is_loaded
|
||||
else 0
|
||||
),
|
||||
"use_fast_scorer": self._use_fast_scorer,
|
||||
"fast_scorer_enabled": self._fast_scorer is not None,
|
||||
"meta": self.meta,
|
||||
}
|
||||
|
||||
# 如果启用了 FastScorer,添加其统计
|
||||
if self._fast_scorer is not None:
|
||||
stats["fast_scorer_stats"] = self._fast_scorer.get_statistics()
|
||||
|
||||
return stats
|
||||
|
||||
def __repr__(self) -> str:
|
||||
mode = "fast" if self._fast_scorer else "sklearn"
|
||||
return (
|
||||
f"SemanticInterestScorer("
|
||||
f"loaded={self.is_loaded}, "
|
||||
f"mode={mode}, "
|
||||
f"model={self.model_path.name})"
|
||||
)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""模型管理器
|
||||
|
||||
支持模型热更新、版本管理和人设感知的模型切换
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir: Path):
|
||||
"""初始化管理器
|
||||
|
||||
Args:
|
||||
model_dir: 模型目录
|
||||
"""
|
||||
self.model_dir = Path(model_dir)
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.current_scorer: SemanticInterestScorer | None = None
|
||||
self.current_version: str | None = None
|
||||
self.current_persona_info: dict[str, Any] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 自动训练器集成
|
||||
self._auto_trainer = None
|
||||
self._auto_training_started = False # 防止重复启动自动训练
|
||||
|
||||
async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None, use_async: bool = True) -> SemanticInterestScorer:
|
||||
"""加载指定版本的模型,支持人设感知(使用单例)
|
||||
|
||||
Args:
|
||||
version: 模型版本号或 "latest" 或 "auto"
|
||||
persona_info: 人设信息,用于自动选择匹配的模型
|
||||
use_async: 是否使用异步加载(推荐)
|
||||
|
||||
Returns:
|
||||
评分器实例(单例)
|
||||
"""
|
||||
async with self._lock:
|
||||
# 如果指定了人设信息,尝试使用自动训练器
|
||||
if persona_info is not None and version == "auto":
|
||||
model_path = await self._get_persona_model(persona_info)
|
||||
elif version == "latest":
|
||||
model_path = self._get_latest_model()
|
||||
else:
|
||||
model_path = self.model_dir / f"semantic_interest_{version}.pkl"
|
||||
|
||||
if not model_path or not model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
# 使用单例获取评分器
|
||||
scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async)
|
||||
|
||||
self.current_scorer = scorer
|
||||
self.current_version = version
|
||||
self.current_persona_info = persona_info
|
||||
|
||||
logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}")
|
||||
return scorer
|
||||
|
||||
async def reload_current_model(self):
|
||||
"""重新加载当前模型"""
|
||||
if not self.current_scorer:
|
||||
raise ValueError("尚未加载任何模型")
|
||||
|
||||
async with self._lock:
|
||||
await self.current_scorer.reload_async()
|
||||
logger.info("模型已重新加载")
|
||||
|
||||
def _get_latest_model(self) -> Path:
|
||||
"""获取最新的模型文件
|
||||
|
||||
Returns:
|
||||
最新模型文件路径
|
||||
"""
|
||||
model_files = list(self.model_dir.glob("semantic_interest_*.pkl"))
|
||||
|
||||
if not model_files:
|
||||
raise FileNotFoundError(f"在 {self.model_dir} 中未找到模型文件")
|
||||
|
||||
# 按修改时间排序
|
||||
latest = max(model_files, key=lambda p: p.stat().st_mtime)
|
||||
return latest
|
||||
|
||||
def get_scorer(self) -> SemanticInterestScorer:
|
||||
"""获取当前评分器
|
||||
|
||||
Returns:
|
||||
当前评分器实例
|
||||
"""
|
||||
if not self.current_scorer:
|
||||
raise ValueError("尚未加载任何模型")
|
||||
|
||||
return self.current_scorer
|
||||
|
||||
async def _get_persona_model(self, persona_info: dict[str, Any]) -> Path | None:
|
||||
"""根据人设信息获取或训练模型
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
|
||||
Returns:
|
||||
模型文件路径
|
||||
"""
|
||||
try:
|
||||
# 延迟导入避免循环依赖
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
|
||||
if self._auto_trainer is None:
|
||||
self._auto_trainer = get_auto_trainer()
|
||||
|
||||
# 检查是否需要训练
|
||||
trained, model_path = await self._auto_trainer.auto_train_if_needed(
|
||||
persona_info=persona_info,
|
||||
days=7,
|
||||
max_samples=1000, # 初始训练使用1000条消息
|
||||
)
|
||||
|
||||
if trained and model_path:
|
||||
logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}")
|
||||
return model_path
|
||||
|
||||
# 获取现有的人设模型
|
||||
model_path = self._auto_trainer.get_model_for_persona(persona_info)
|
||||
if model_path:
|
||||
return model_path
|
||||
|
||||
# 降级到 latest
|
||||
logger.warning("[模型管理器] 未找到人设模型,使用 latest")
|
||||
return self._get_latest_model()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 获取人设模型失败: {e}")
|
||||
return self._get_latest_model()
|
||||
|
||||
async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool:
|
||||
"""检查人设变化并重新加载模型
|
||||
|
||||
Args:
|
||||
persona_info: 当前人设信息
|
||||
|
||||
Returns:
|
||||
True 如果重新加载了模型
|
||||
"""
|
||||
# 检查人设是否变化
|
||||
if self.current_persona_info == persona_info:
|
||||
return False
|
||||
|
||||
logger.info("[模型管理器] 检测到人设变化,重新加载模型...")
|
||||
|
||||
try:
|
||||
await self.load_model(version="auto", persona_info=persona_info)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 重新加载模型失败: {e}")
|
||||
return False
|
||||
|
||||
async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24):
|
||||
"""启动自动训练任务
|
||||
|
||||
Args:
|
||||
persona_info: 人设信息
|
||||
interval_hours: 检查间隔(小时)
|
||||
"""
|
||||
# 使用锁防止并发启动
|
||||
async with self._lock:
|
||||
# 检查是否已经启动
|
||||
if self._auto_training_started:
|
||||
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
|
||||
|
||||
if self._auto_trainer is None:
|
||||
self._auto_trainer = get_auto_trainer()
|
||||
|
||||
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
|
||||
|
||||
# 标记为已启动
|
||||
self._auto_training_started = True
|
||||
|
||||
# 在后台任务中运行
|
||||
asyncio.create_task(
|
||||
self._auto_trainer.scheduled_train(persona_info, interval_hours)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
|
||||
self._auto_training_started = False # 失败时重置标志
|
||||
|
||||
|
||||
# 单例获取函数
|
||||
async def get_semantic_scorer(
|
||||
model_path: str | Path,
|
||||
force_reload: bool = False,
|
||||
use_async: bool = True
|
||||
) -> SemanticInterestScorer:
|
||||
"""获取语义兴趣度评分器实例(单例模式)
|
||||
|
||||
同一个模型路径只会创建一个评分器实例,避免重复加载模型。
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径
|
||||
force_reload: 是否强制重新加载模型
|
||||
use_async: 是否使用异步加载(推荐)
|
||||
|
||||
Returns:
|
||||
评分器实例(单例)
|
||||
|
||||
Example:
|
||||
>>> scorer = await get_semantic_scorer("data/semantic_interest/models/model.pkl")
|
||||
>>> score = await scorer.score_async("今天天气真好")
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve()) # 使用绝对路径作为键
|
||||
|
||||
async with _instance_lock:
|
||||
# 检查是否已存在实例
|
||||
if not force_reload and path_key in _scorer_instances:
|
||||
scorer = _scorer_instances[path_key]
|
||||
if scorer.is_loaded:
|
||||
logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
|
||||
return scorer
|
||||
else:
|
||||
logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}")
|
||||
|
||||
# 创建或重新加载实例
|
||||
if path_key not in _scorer_instances:
|
||||
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
|
||||
scorer = SemanticInterestScorer(model_path)
|
||||
_scorer_instances[path_key] = scorer
|
||||
else:
|
||||
scorer = _scorer_instances[path_key]
|
||||
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
|
||||
|
||||
# 加载模型
|
||||
if use_async:
|
||||
await scorer.load_async()
|
||||
else:
|
||||
scorer.load()
|
||||
|
||||
return scorer
|
||||
|
||||
|
||||
def get_semantic_scorer_sync(
|
||||
model_path: str | Path,
|
||||
force_reload: bool = False
|
||||
) -> SemanticInterestScorer:
|
||||
"""获取语义兴趣度评分器实例(同步版本,单例模式)
|
||||
|
||||
注意:这是同步版本,推荐使用异步版本 get_semantic_scorer()
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径
|
||||
force_reload: 是否强制重新加载模型
|
||||
|
||||
Returns:
|
||||
评分器实例(单例)
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
path_key = str(model_path.resolve())
|
||||
|
||||
# 检查是否已存在实例
|
||||
if not force_reload and path_key in _scorer_instances:
|
||||
scorer = _scorer_instances[path_key]
|
||||
if scorer.is_loaded:
|
||||
logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
|
||||
return scorer
|
||||
|
||||
# 创建或重新加载实例
|
||||
if path_key not in _scorer_instances:
|
||||
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
|
||||
scorer = SemanticInterestScorer(model_path)
|
||||
_scorer_instances[path_key] = scorer
|
||||
else:
|
||||
scorer = _scorer_instances[path_key]
|
||||
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
|
||||
|
||||
# 加载模型
|
||||
scorer.load()
|
||||
return scorer
|
||||
|
||||
|
||||
def clear_scorer_instances():
|
||||
"""清空所有评分器实例(释放内存)"""
|
||||
global _scorer_instances
|
||||
count = len(_scorer_instances)
|
||||
_scorer_instances.clear()
|
||||
logger.info(f"[单例] 已清空 {count} 个评分器实例")
|
||||
|
||||
|
||||
def get_all_scorer_instances() -> dict[str, SemanticInterestScorer]:
|
||||
"""获取所有已创建的评分器实例
|
||||
|
||||
Returns:
|
||||
{模型路径: 评分器实例} 的字典
|
||||
"""
|
||||
return _scorer_instances.copy()
|
||||
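A sketch of the manager in use. The model directory is a placeholder and must already contain at least one semantic_interest_*.pkl file:

import asyncio
from pathlib import Path

from src.chat.semantic_interest.runtime_scorer import ModelManager


async def demo_manager():
    manager = ModelManager(Path("data/semantic_interest/models"))  # placeholder dir
    # "latest" picks the newest semantic_interest_*.pkl by mtime;
    # "auto" additionally consults the persona-aware auto trainer.
    scorer = await manager.load_model(version="latest")
    print(await scorer.score_async("周末一起打游戏吗"))


asyncio.run(demo_manager())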
202
src/chat/semantic_interest/trainer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""训练器入口脚本
|
||||
|
||||
统一的训练流程入口,包含数据采样、标注、训练、评估
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import joblib
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
|
||||
from src.chat.semantic_interest.model_lr import train_semantic_model
|
||||
|
||||
logger = get_logger("semantic_interest.trainer")
|
||||
|
||||
|
||||
class SemanticInterestTrainer:
|
||||
"""语义兴趣度训练器
|
||||
|
||||
统一管理训练流程
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path | None = None,
|
||||
model_dir: Path | None = None,
|
||||
):
|
||||
"""初始化训练器
|
||||
|
||||
Args:
|
||||
data_dir: 数据集目录
|
||||
model_dir: 模型保存目录
|
||||
"""
|
||||
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
|
||||
self.model_dir = Path(model_dir or "data/semantic_interest/models")
|
||||
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def prepare_dataset(
|
||||
self,
|
||||
persona_info: dict[str, Any],
|
||||
days: int = 7,
|
||||
max_samples: int = 1000,
|
||||
model_name: str | None = None,
|
||||
dataset_name: str | None = None,
|
||||
generate_initial_keywords: bool = True,
|
||||
keyword_temperature: float = 0.7,
|
||||
keyword_iterations: int = 3,
|
||||
) -> Path:
|
||||
"""准备训练数据集
|
||||
|
||||
Args:
|
||||
persona_info: 人格信息
|
||||
days: 采样最近 N 天的消息
|
||||
max_samples: 最大采样数
|
||||
model_name: LLM 模型名称
|
||||
dataset_name: 数据集名称(默认使用时间戳)
|
||||
generate_initial_keywords: 是否生成初始关键词数据集
|
||||
keyword_temperature: 关键词生成温度
|
||||
keyword_iterations: 关键词生成迭代次数
|
||||
|
||||
Returns:
|
||||
数据集文件路径
|
||||
"""
|
||||
if dataset_name is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
dataset_name = f"dataset_{timestamp}"
|
||||
|
||||
output_path = self.data_dir / f"{dataset_name}.json"
|
||||
|
||||
logger.info(f"开始准备数据集: {dataset_name}")
|
||||
|
||||
await generate_training_dataset(
|
||||
output_path=output_path,
|
||||
persona_info=persona_info,
|
||||
days=days,
|
||||
max_samples=max_samples,
|
||||
model_name=model_name,
|
||||
generate_initial_keywords=generate_initial_keywords,
|
||||
keyword_temperature=keyword_temperature,
|
||||
keyword_iterations=keyword_iterations,
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
||||
def train_model(
|
||||
self,
|
||||
dataset_path: Path,
|
||||
model_version: str | None = None,
|
||||
tfidf_config: dict | None = None,
|
||||
model_config: dict | None = None,
|
||||
test_size: float = 0.1,
|
||||
) -> tuple[Path, dict]:
|
||||
"""训练模型
|
||||
|
||||
Args:
|
||||
dataset_path: 数据集文件路径
|
||||
model_version: 模型版本号(默认使用时间戳)
|
||||
tfidf_config: TF-IDF 配置
|
||||
model_config: 模型配置
|
||||
test_size: 验证集比例
|
||||
|
||||
Returns:
|
||||
(模型文件路径, 训练指标)
|
||||
"""
|
||||
logger.info(f"开始训练模型,数据集: {dataset_path}")
|
||||
|
||||
# 加载数据集
|
||||
from src.chat.semantic_interest.dataset import DatasetGenerator
|
||||
texts, labels = DatasetGenerator.load_dataset(dataset_path)
|
||||
|
||||
# 训练模型
|
||||
vectorizer, model, metrics = train_semantic_model(
|
||||
texts=texts,
|
||||
labels=labels,
|
||||
test_size=test_size,
|
||||
tfidf_config=tfidf_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# 保存模型
|
||||
if model_version is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_version = timestamp
|
||||
|
||||
model_path = self.model_dir / f"semantic_interest_{model_version}.pkl"
|
||||
|
||||
bundle = {
|
||||
"vectorizer": vectorizer,
|
||||
"model": model,
|
||||
"meta": {
|
||||
"version": model_version,
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
"dataset": str(dataset_path),
|
||||
"train_samples": len(texts),
|
||||
"metrics": metrics,
|
||||
"tfidf_config": vectorizer.get_config(),
|
||||
"model_config": model.get_config(),
|
||||
},
|
||||
}
|
||||
|
||||
joblib.dump(bundle, model_path)
|
||||
logger.info(f"模型已保存到: {model_path}")
|
||||
|
||||
return model_path, metrics
|
||||
|
||||
async def full_training_pipeline(
|
||||
self,
|
||||
persona_info: dict[str, Any],
|
||||
days: int = 7,
|
||||
max_samples: int = 1000,
|
||||
llm_model_name: str | None = None,
|
||||
tfidf_config: dict | None = None,
|
||||
model_config: dict | None = None,
|
||||
dataset_name: str | None = None,
|
||||
model_version: str | None = None,
|
||||
) -> tuple[Path, Path, dict]:
|
||||
"""完整训练流程
|
||||
|
||||
Args:
|
||||
persona_info: 人格信息
|
||||
days: 采样天数
|
||||
max_samples: 最大采样数
|
||||
llm_model_name: LLM 模型名称
|
||||
tfidf_config: TF-IDF 配置
|
||||
model_config: 模型配置
|
||||
dataset_name: 数据集名称
|
||||
model_version: 模型版本
|
||||
|
||||
Returns:
|
||||
(数据集路径, 模型路径, 训练指标)
|
||||
"""
|
||||
logger.info("开始完整训练流程")
|
||||
|
||||
# 1. 准备数据集
|
||||
dataset_path = await self.prepare_dataset(
|
||||
persona_info=persona_info,
|
||||
days=days,
|
||||
max_samples=max_samples,
|
||||
model_name=llm_model_name,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
|
||||
# 2. 训练模型
|
||||
model_path, metrics = self.train_model(
|
||||
dataset_path=dataset_path,
|
||||
model_version=model_version,
|
||||
tfidf_config=tfidf_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
logger.info("完整训练流程完成")
|
||||
logger.info(f"数据集: {dataset_path}")
|
||||
logger.info(f"模型: {model_path}")
|
||||
logger.info(f"指标: {metrics}")
|
||||
|
||||
return dataset_path, model_path, metrics
|
||||
|
||||
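A hedged usage sketch for the trainer above; the persona_info keys are illustrative placeholders, since the real schema is defined in dataset.py.

import asyncio

from src.chat.semantic_interest.trainer import SemanticInterestTrainer

async def run_training() -> None:
    trainer = SemanticInterestTrainer()  # defaults to data/semantic_interest/{datasets,models}
    persona_info = {"name": "MoFox", "description": "..."}  # hypothetical keys
    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
        persona_info=persona_info,
        days=7,
        max_samples=500,
    )
    print(dataset_path, model_path, metrics)

asyncio.run(run_training())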
@@ -96,7 +96,7 @@ class ChineseTypoGenerator:
|
||||
|
||||
# 🔧 内存优化:复用全局缓存的拼音字典和字频数据
|
||||
if _shared_pinyin_dict is None:
|
||||
_shared_pinyin_dict = self._create_pinyin_dict()
|
||||
_shared_pinyin_dict = self._load_or_create_pinyin_dict()
|
||||
logger.debug("拼音字典已创建并缓存")
|
||||
self.pinyin_dict = _shared_pinyin_dict
|
||||
|
||||
@@ -141,6 +141,35 @@ class ChineseTypoGenerator:
|
||||
|
||||
return normalized_freq
|
||||
|
||||
def _load_or_create_pinyin_dict(self):
|
||||
"""
|
||||
加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动)
|
||||
"""
|
||||
cache_file = Path("depends-data/pinyin_dict.json")
|
||||
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
# 恢复为 defaultdict(list) 以兼容旧逻辑
|
||||
restored = defaultdict(list)
|
||||
for py, chars in data.items():
|
||||
restored[py] = list(chars)
|
||||
return restored
|
||||
except Exception as e:
|
||||
logger.warning(f"读取拼音缓存失败,将重新生成: {e}")
|
||||
|
||||
pinyin_dict = self._create_pinyin_dict()
|
||||
|
||||
try:
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"写入拼音缓存失败(不影响使用): {e}")
|
||||
|
||||
return pinyin_dict
|
||||
|
||||
@staticmethod
|
||||
def _create_pinyin_dict():
|
||||
"""
|
||||
|
||||
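A small standalone sketch of the cache round-trip used above: orjson serialises a plain dict, and the defaultdict(list) is rebuilt on load so missing pinyin keys still return empty lists.

from collections import defaultdict

import orjson

pinyin = defaultdict(list)
pinyin["ma"].extend(["妈", "马"])

blob = orjson.dumps(dict(pinyin), option=orjson.OPT_INDENT_2)  # defaultdict -> plain dict -> JSON bytes
restored = defaultdict(list)
for key, chars in orjson.loads(blob).items():
    restored[key] = list(chars)

print(restored["ma"], restored["missing"])  # ['妈', '马'] []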
@@ -10,11 +10,6 @@ CoreSink 统一管理器
|
||||
3. 使用 MessageRuntime 进行消息路由和处理
|
||||
4. 提供统一的消息发送接口
|
||||
|
||||
架构说明(2025-11 重构):
|
||||
- 集成 mofox_wire.MessageRuntime 作为消息路由中心
|
||||
- 使用 @runtime.on_message() 装饰器注册消息处理器
|
||||
- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑
|
||||
- 简化消息处理链条,提高可扩展性
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -215,26 +215,25 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(paginated_stmt)
|
||||
# .all() 已经返回 list,无需再包装
|
||||
instances = result.scalars().all()
|
||||
|
||||
if not instances:
|
||||
# 没有更多数据
|
||||
break
|
||||
|
||||
# 在 session 内部转换为字典列表
|
||||
# 在 session 内部转换为字典列表,保证字段可用再释放连接
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
if as_dict:
|
||||
yield instances_dicts
|
||||
else:
|
||||
yield [_dict_to_model(self.model, row) for row in instances_dicts]
|
||||
if as_dict:
|
||||
yield instances_dicts
|
||||
else:
|
||||
yield [_dict_to_model(self.model, row) for row in instances_dicts]
|
||||
|
||||
# 如果返回的记录数小于 batch_size,说明已经是最后一批
|
||||
if len(instances) < batch_size:
|
||||
break
|
||||
# 如果返回的记录数小于 batch_size,说明已经是最后一批
|
||||
if len(instances) < batch_size:
|
||||
break
|
||||
|
||||
offset += batch_size
|
||||
offset += batch_size
|
||||
|
||||
async def iter_all(
|
||||
self,
|
||||
|
||||
@@ -20,6 +20,7 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("redis_cache")
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from redis.asyncio.connection import Connection, SSLConnection
|
||||
|
||||
|
||||
class RedisCache(CacheBackend):
|
||||
@@ -98,7 +99,11 @@ class RedisCache(CacheBackend):
|
||||
return self._client
|
||||
|
||||
try:
|
||||
# 创建连接池 (使用 aioredis 模块确保类型安全)
|
||||
# redis-py 7.x+ 使用 connection_class 来指定 SSL 连接
|
||||
# 不再支持直接传递 ssl=True/False 给 ConnectionPool
|
||||
connection_class = SSLConnection if self.ssl else Connection
|
||||
|
||||
# 创建连接池
|
||||
self._pool = aioredis.ConnectionPool(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
@@ -108,7 +113,7 @@ class RedisCache(CacheBackend):
|
||||
socket_timeout=self.socket_timeout,
|
||||
socket_connect_timeout=self.socket_timeout,
|
||||
decode_responses=False, # 我们自己处理序列化
|
||||
ssl=self.ssl,
|
||||
connection_class=connection_class,
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
|
||||
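A standalone sketch of the redis-py connection_class pattern the comment above describes; host, port and the use_ssl flag are placeholders, and whether ssl= is still accepted alongside it depends on the redis-py version.

import redis.asyncio as aioredis
from redis.asyncio.connection import Connection, SSLConnection

use_ssl = True  # placeholder flag
pool = aioredis.ConnectionPool(
    host="localhost",
    port=6379,
    decode_responses=False,
    connection_class=SSLConnection if use_ssl else Connection,  # instead of passing ssl=True/False
)
client = aioredis.Redis(connection_pool=pool)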
259
src/common/log_broadcaster.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
日志广播系统
|
||||
用于实时推送日志到多个订阅者(如WebSocket客户端)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
class LogBroadcaster:
|
||||
"""日志广播器,用于实时推送日志到订阅者"""
|
||||
|
||||
def __init__(self, max_buffer_size: int = 1000):
|
||||
"""
|
||||
初始化日志广播器
|
||||
|
||||
Args:
|
||||
max_buffer_size: 缓冲区最大大小,超过后会丢弃旧日志
|
||||
"""
|
||||
self.subscribers: set[Callable[[dict[str, Any]], None]] = set()
|
||||
self.buffer: deque[dict[str, Any]] = deque(maxlen=max_buffer_size)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def subscribe(self, callback: Callable[[dict[str, Any]], None]) -> None:
|
||||
"""
|
||||
订阅日志推送
|
||||
|
||||
Args:
|
||||
callback: 接收日志的回调函数,参数为日志字典
|
||||
"""
|
||||
async with self._lock:
|
||||
self.subscribers.add(callback)
|
||||
|
||||
async def unsubscribe(self, callback: Callable[[dict[str, Any]], None]) -> None:
|
||||
"""
|
||||
取消订阅
|
||||
|
||||
Args:
|
||||
callback: 要移除的回调函数
|
||||
"""
|
||||
async with self._lock:
|
||||
self.subscribers.discard(callback)
|
||||
|
||||
async def broadcast(self, log_record: dict[str, Any]) -> None:
|
||||
"""
|
||||
广播日志到所有订阅者
|
||||
|
||||
Args:
|
||||
log_record: 日志记录字典
|
||||
"""
|
||||
# 添加到缓冲区
|
||||
async with self._lock:
|
||||
self.buffer.append(log_record)
|
||||
# 创建订阅者列表的副本,避免在迭代时修改
|
||||
subscribers = list(self.subscribers)
|
||||
|
||||
# 异步发送到所有订阅者
|
||||
tasks = []
|
||||
for callback in subscribers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
tasks.append(asyncio.create_task(callback(log_record)))
|
||||
else:
|
||||
# 同步回调在线程池中执行(包装为 Task,asyncio.wait 在 Python 3.11+ 不再接受裸协程)
tasks.append(asyncio.create_task(asyncio.to_thread(callback, log_record)))
|
||||
except Exception:
|
||||
# 忽略单个订阅者的错误
|
||||
pass
|
||||
|
||||
# 等待所有发送完成(但不阻塞太久)
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, timeout=1.0)
|
||||
|
||||
def get_recent_logs(self, limit: int = 100) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取最近的日志记录
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志数量
|
||||
|
||||
Returns:
|
||||
日志记录列表
|
||||
"""
|
||||
return list(self.buffer)[-limit:]
|
||||
|
||||
def clear_buffer(self) -> None:
|
||||
"""清空日志缓冲区"""
|
||||
self.buffer.clear()
|
||||
|
||||
|
||||
class BroadcastLogHandler(logging.Handler):
|
||||
"""
|
||||
日志处理器,将日志推送到广播器
|
||||
"""
|
||||
|
||||
def __init__(self, broadcaster: LogBroadcaster):
|
||||
"""
|
||||
初始化处理器
|
||||
|
||||
Args:
|
||||
broadcaster: 日志广播器实例
|
||||
"""
|
||||
super().__init__()
|
||||
self.broadcaster = broadcaster
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def _get_logger_metadata(self, logger_name: str) -> dict[str, str | None]:
|
||||
"""
|
||||
获取logger的元数据(别名和颜色)
|
||||
|
||||
Args:
|
||||
logger_name: logger名称
|
||||
|
||||
Returns:
|
||||
包含alias和color的字典
|
||||
"""
|
||||
try:
|
||||
# 导入logger元数据获取函数
|
||||
from src.common.logger import get_logger_meta
|
||||
|
||||
return get_logger_meta(logger_name)
|
||||
except Exception:
|
||||
# 如果获取失败,返回空元数据
|
||||
return {"alias": None, "color": None}
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
"""
|
||||
处理日志记录
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
"""
|
||||
try:
|
||||
# 获取logger元数据(别名和颜色)
|
||||
logger_meta = self._get_logger_metadata(record.name)
|
||||
|
||||
# 转换日志记录为字典
|
||||
log_dict = {
|
||||
"timestamp": self.format_time(record),
|
||||
"level": record.levelname, # 保持大写,与前端筛选器一致
|
||||
"logger_name": record.name, # 原始logger名称
|
||||
"event": record.getMessage(),
|
||||
}
|
||||
|
||||
# 添加别名和颜色(如果存在)
|
||||
if logger_meta["alias"]:
|
||||
log_dict["alias"] = logger_meta["alias"]
|
||||
if logger_meta["color"]:
|
||||
log_dict["color"] = logger_meta["color"]
|
||||
|
||||
# 添加额外字段
|
||||
if hasattr(record, "__dict__"):
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in (
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"created",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"module",
|
||||
"msecs",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"thread",
|
||||
"threadName",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"stack_info",
|
||||
):
|
||||
try:
|
||||
# 尝试序列化以确保可以转为JSON
|
||||
orjson.dumps(value)
|
||||
log_dict[key] = value
|
||||
except (TypeError, ValueError):
|
||||
log_dict[key] = str(value)
|
||||
|
||||
# 获取或创建事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,创建新任务
|
||||
if self.loop is None:
|
||||
try:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
except Exception:
|
||||
return
|
||||
loop = self.loop
|
||||
|
||||
# 在事件循环中异步广播
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.broadcaster.broadcast(log_dict), loop
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# 忽略广播错误,避免影响日志系统
|
||||
pass
|
||||
|
||||
def format_time(self, record: logging.LogRecord) -> str:
|
||||
"""
|
||||
格式化时间戳
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
|
||||
Returns:
|
||||
格式化的时间字符串
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(record.created)
|
||||
return dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
||||
|
||||
|
||||
# 全局广播器实例
|
||||
_global_broadcaster: LogBroadcaster | None = None
|
||||
|
||||
|
||||
def get_log_broadcaster() -> LogBroadcaster:
|
||||
"""
|
||||
获取全局日志广播器实例
|
||||
|
||||
Returns:
|
||||
日志广播器实例
|
||||
"""
|
||||
global _global_broadcaster
|
||||
if _global_broadcaster is None:
|
||||
_global_broadcaster = LogBroadcaster()
|
||||
return _global_broadcaster
|
||||
|
||||
|
||||
def setup_log_broadcasting() -> LogBroadcaster:
|
||||
"""
|
||||
设置日志广播系统,将日志处理器添加到根日志记录器
|
||||
|
||||
Returns:
|
||||
日志广播器实例
|
||||
"""
|
||||
broadcaster = get_log_broadcaster()
|
||||
|
||||
# 创建并添加广播处理器到根日志记录器
|
||||
handler = BroadcastLogHandler(broadcaster)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
|
||||
# 添加到根日志记录器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
return broadcaster
|
||||
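A hedged end-to-end sketch of the broadcaster above: attach the handler, subscribe an async consumer (standing in for a WebSocket sender), then emit a log and read the buffer.

import asyncio
import logging

from src.common.log_broadcaster import setup_log_broadcasting

async def main() -> None:
    broadcaster = setup_log_broadcasting()  # adds BroadcastLogHandler to the root logger

    async def forward(record: dict) -> None:  # hypothetical consumer, e.g. a WebSocket send
        print(record["timestamp"], record["level"], record["event"])

    await broadcaster.subscribe(forward)
    logging.getLogger("demo").warning("hello from the log broadcaster")
    await asyncio.sleep(0.1)  # give the scheduled broadcast task time to run
    print(len(broadcaster.get_recent_logs(limit=5)))

asyncio.run(main())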
@@ -1,6 +1,7 @@
|
||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tarfile
|
||||
import threading
|
||||
import time
|
||||
@@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler):
|
||||
self.backup_count = backup_count
|
||||
self.encoding = encoding
|
||||
self._lock = threading.Lock()
|
||||
self._current_size = 0
|
||||
self._bytes_since_check = 0
|
||||
self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8"))
|
||||
self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024)
|
||||
|
||||
# 当前活跃的日志文件
|
||||
self.current_file = None
|
||||
@@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler):
|
||||
# 极低概率碰撞,稍作等待
|
||||
time.sleep(0.001)
|
||||
self.current_stream = open(self.current_file, "a", encoding=self.encoding)
|
||||
self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0
|
||||
self._bytes_since_check = 0
|
||||
|
||||
def _should_rollover(self):
|
||||
"""检查是否需要轮转"""
|
||||
if self.current_file and self.current_file.exists():
|
||||
return self.current_file.stat().st_size >= self.max_bytes
|
||||
def _should_rollover(self, incoming_size: int = 0) -> bool:
|
||||
"""检查是否需要轮转,使用内存缓存的大小信息减少磁盘stat次数。"""
|
||||
if not self.current_file:
|
||||
return False
|
||||
|
||||
projected = self._current_size + incoming_size
|
||||
if projected >= self.max_bytes:
|
||||
return True
|
||||
|
||||
self._bytes_since_check += incoming_size
|
||||
if self._bytes_since_check >= self._stat_refresh_threshold:
|
||||
try:
|
||||
if self.current_file.exists():
|
||||
self._current_size = self.current_file.stat().st_size
|
||||
else:
|
||||
self._current_size = 0
|
||||
except OSError:
|
||||
self._current_size = 0
|
||||
finally:
|
||||
self._bytes_since_check = 0
|
||||
return False
|
||||
|
||||
def _do_rollover(self):
|
||||
@@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
"""发出日志记录"""
|
||||
try:
|
||||
message = self.format(record)
|
||||
encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes
|
||||
|
||||
with self._lock:
|
||||
# 检查是否需要轮转
|
||||
if self._should_rollover():
|
||||
if self._should_rollover(encoded_len):
|
||||
self._do_rollover()
|
||||
|
||||
# 写入日志
|
||||
if self.current_stream:
|
||||
msg = self.format(record)
|
||||
self.current_stream.write(msg + "\n")
|
||||
self.current_stream.write(message + "\n")
|
||||
self.current_stream.flush()
|
||||
self._current_size += encoded_len
|
||||
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
@@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = {
|
||||
}
|
||||
|
||||
|
||||
# 创建全局 Rich Console 实例用于颜色渲染
|
||||
_rich_console = Console(force_terminal=True, color_system="truecolor")
|
||||
|
||||
|
||||
class ModuleColoredConsoleRenderer:
|
||||
"""自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色"""
|
||||
|
||||
@@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer:
|
||||
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
||||
self._colors = colors
|
||||
self._config = LOG_CONFIG
|
||||
self._render_console = Console(force_terminal=True, color_system="truecolor", width=999)
|
||||
|
||||
# 日志级别颜色 (#RRGGBB 格式)
|
||||
self._level_colors_hex = {
|
||||
@@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer:
|
||||
self._enable_level_colors = False
|
||||
self._enable_full_content_colors = False
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_markup(content: str) -> bool:
|
||||
"""快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。"""
|
||||
if not content:
|
||||
return False
|
||||
return "[" in content and "]" in content
|
||||
|
||||
def _render_content_text(self, content: str, *, style: str | None = None) -> Text:
|
||||
"""只在必要时解析 Rich 标记,降低CPU占用。"""
|
||||
if self._looks_like_markup(content):
|
||||
try:
|
||||
return Text.from_markup(content, style=style)
|
||||
except Exception:
|
||||
return Text(content, style=style)
|
||||
return Text(content, style=style)
|
||||
|
||||
def __call__(self, logger, method_name, event_dict):
|
||||
# sourcery skip: merge-duplicate-blocks
|
||||
"""渲染日志消息"""
|
||||
@@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer:
|
||||
if prefix:
|
||||
# 解析 prefix 中的 Rich 标记
|
||||
if module_hex_color:
|
||||
content_text.append(Text.from_markup(prefix, style=module_hex_color))
|
||||
content_text.append(self._render_content_text(prefix, style=module_hex_color))
|
||||
else:
|
||||
content_text.append(Text.from_markup(prefix))
|
||||
content_text.append(self._render_content_text(prefix))
|
||||
|
||||
# 与"内心思考"段落之间插入空行
|
||||
if prefix:
|
||||
@@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer:
|
||||
else:
|
||||
# 使用 Text.from_markup 解析 Rich 标记语言
|
||||
if module_hex_color:
|
||||
try:
|
||||
parts.append(Text.from_markup(event_content, style=module_hex_color))
|
||||
except Exception:
|
||||
# 如果标记解析失败,回退到普通文本
|
||||
parts.append(Text(event_content, style=module_hex_color))
|
||||
parts.append(self._render_content_text(event_content, style=module_hex_color))
|
||||
else:
|
||||
try:
|
||||
parts.append(Text.from_markup(event_content))
|
||||
except Exception:
|
||||
# 如果标记解析失败,回退到普通文本
|
||||
parts.append(Text(event_content))
|
||||
parts.append(self._render_content_text(event_content))
|
||||
else:
|
||||
# 即使在非 full 模式下,也尝试解析 Rich 标记(但不应用颜色)
|
||||
try:
|
||||
parts.append(Text.from_markup(event_content))
|
||||
except Exception:
|
||||
# 如果标记解析失败,使用普通文本
|
||||
parts.append(Text(event_content))
|
||||
parts.append(self._render_content_text(event_content))
|
||||
|
||||
# 处理其他字段
|
||||
extras = []
|
||||
@@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer:
|
||||
|
||||
# 使用 Rich 拼接并返回字符串
|
||||
result = Text(" ").join(parts)
|
||||
# 将 Rich Text 对象转换为带 ANSI 颜色码的字符串
|
||||
from io import StringIO
|
||||
string_io = StringIO()
|
||||
temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999)
|
||||
temp_console.print(result, end="")
|
||||
return string_io.getvalue()
|
||||
# 使用持久化 Console + capture 避免每条日志重复实例化
|
||||
with self._render_console.capture() as capture:
|
||||
self._render_console.print(result, end="")
|
||||
return capture.get()
|
||||
|
||||
|
||||
# 配置标准logging以支持文件输出和压缩
|
||||
|
||||
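A standalone illustration of the persistent-Console capture pattern adopted above; this is plain Rich API usage, nothing project-specific.

from rich.console import Console
from rich.text import Text

console = Console(force_terminal=True, color_system="truecolor", width=999)

def render(text: Text) -> str:
    # Reuse one Console and capture its ANSI output instead of building a StringIO-backed Console per call.
    with console.capture() as capture:
        console.print(text, end="")
    return capture.get()

print(repr(render(Text("hello", style="#ff8800"))))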
@@ -506,24 +506,16 @@ def load_config(config_path: str) -> Config:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
|
||||
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
|
||||
# 但在 Pydantic 严格模式下可能导致类型验证失败
|
||||
config_dict = config_data.unwrap()
|
||||
|
||||
# 创建Config对象(各个配置类会自动进行 Pydantic 验证)
|
||||
try:
|
||||
logger.debug("正在解析和验证配置文件...")
|
||||
config = Config.from_dict(config_data)
|
||||
config = Config.from_dict(config_dict)
|
||||
logger.debug("配置文件解析和验证完成")
|
||||
|
||||
# 【临时修复】在验证后,手动从原始数据重新加载 master_users
|
||||
try:
|
||||
# 先将 tomlkit 对象转换为纯 Python 字典
|
||||
config_dict = config_data.unwrap()
|
||||
if "permission" in config_dict and "master_users" in config_dict["permission"]:
|
||||
raw_master_users = config_dict["permission"]["master_users"]
|
||||
# 现在 raw_master_users 就是一个标准的 Python 列表了
|
||||
config.permission.master_users = raw_master_users
|
||||
logger.debug(f"【临时修复】已手动将 master_users 设置为: {config.permission.master_users}")
|
||||
except Exception as patch_exc:
|
||||
logger.error(f"【临时修复】手动设置 master_users 失败: {patch_exc}")
|
||||
|
||||
return config
|
||||
except Exception as e:
|
||||
logger.critical(f"配置文件解析失败: {e}")
|
||||
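A minimal sketch of the tomlkit-to-plain-dict conversion the hunk above relies on: unwrap() turns tomlkit container types into standard dicts and lists before Pydantic validation.

import tomlkit

doc = tomlkit.parse('master_users = [["qq", "12345"]]\n')
plain = doc.unwrap()  # tomlkit Array/String/Table -> plain list/str/dict
print(type(plain), type(plain["master_users"]))  # <class 'dict'> <class 'list'>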
@@ -581,4 +573,4 @@ def initialize_configs_once() -> tuple[Config, APIAdapterConfig]:
|
||||
# 同一进程只执行一次初始化,避免重复生成或覆盖配置
|
||||
global_config, model_config = initialize_configs_once()
|
||||
|
||||
logger.debug("非常的新鲜,非常的美味!")
|
||||
logger.debug("非常的新鲜,非常的美味!")
|
||||
|
||||
@@ -213,6 +213,12 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
default="classic",
|
||||
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
|
||||
)
|
||||
model_temperature: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=5.0,
|
||||
description="表达模型采样温度,0为贪婪,值越大越容易采样到低分表达"
|
||||
)
|
||||
expiration_days: int = Field(
|
||||
default=90,
|
||||
description="表达方式过期天数,超过此天数未激活的表达方式将被清理"
|
||||
@@ -508,6 +514,7 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
short_term_decay_factor: float = Field(default=0.98, description="衰减因子")
|
||||
|
||||
# 长期记忆层配置
|
||||
use_judge: bool = Field(default=True, description="使用评判模型决定是否检索长期记忆")
|
||||
long_term_batch_size: int = Field(default=10, description="批量转移大小")
|
||||
long_term_decay_factor: float = Field(default=0.95, description="衰减因子")
|
||||
long_term_auto_transfer_interval: int = Field(default=60, description="自动转移间隔(秒)")
|
||||
@@ -796,14 +803,6 @@ class AffinityFlowConfig(ValidatedConfigBase):
|
||||
# 兴趣评分系统参数
|
||||
reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
|
||||
non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
|
||||
high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
|
||||
medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
|
||||
low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
|
||||
high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
|
||||
medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
|
||||
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
|
||||
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
|
||||
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
|
||||
|
||||
# 回复决策系统参数
|
||||
no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值")
|
||||
@@ -1009,4 +1008,3 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
default_factory=KokoroFlowChatterProactiveConfig,
|
||||
description="私聊专属主动思考配置"
|
||||
)
|
||||
|
||||
|
||||
@@ -79,9 +79,6 @@ class Individuality:
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
|
||||
# 初始化智能兴趣系统
|
||||
await self._initialize_smart_interest_system(personality_result, identity_result)
|
||||
|
||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||
if personality_changed or identity_changed:
|
||||
logger.info("将清空数据库中原有的关键词缓存")
|
||||
@@ -93,20 +90,6 @@ class Individuality:
|
||||
}
|
||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||
|
||||
async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str):
|
||||
"""初始化智能兴趣系统"""
|
||||
# 组合完整的人设描述
|
||||
full_personality = f"{personality_result},{identity_result}"
|
||||
|
||||
# 使用统一的评分API初始化智能兴趣系统
|
||||
from src.plugin_system.apis import person_api
|
||||
|
||||
await person_api.initialize_smart_interests(
|
||||
personality_description=full_personality, personality_id=self.bot_person_id
|
||||
)
|
||||
|
||||
logger.info("智能兴趣系统初始化完成")
|
||||
|
||||
async def get_personality_block(self) -> str:
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
|
||||
99
src/main.py
@@ -33,7 +33,6 @@ from src.config.config import global_config
|
||||
from src.individuality.individuality import Individuality, get_individuality
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
@@ -120,93 +119,6 @@ class MainSystem:
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async def _initialize_interest_calculator(self) -> None:
|
||||
"""初始化兴趣值计算组件 - 通过插件系统自动发现和加载"""
|
||||
try:
|
||||
logger.debug("开始自动发现兴趣值计算组件...")
|
||||
|
||||
# 使用组件注册表自动发现兴趣计算器组件
|
||||
interest_calculators = {}
|
||||
try:
|
||||
from src.plugin_system.apis.component_manage_api import get_components_info_by_type
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR)
|
||||
logger.debug(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件")
|
||||
except Exception as e:
|
||||
logger.error(f"从组件注册表获取兴趣计算器失败: {e}")
|
||||
|
||||
if not interest_calculators:
|
||||
logger.warning("未发现任何兴趣计算器组件")
|
||||
return
|
||||
|
||||
# 初始化兴趣度管理器
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
|
||||
interest_manager = get_interest_manager()
|
||||
await interest_manager.initialize()
|
||||
|
||||
# 尝试注册所有可用的计算器
|
||||
registered_calculators = []
|
||||
|
||||
for calc_name, calc_info in interest_calculators.items():
|
||||
enabled = getattr(calc_info, "enabled", True)
|
||||
default_enabled = getattr(calc_info, "enabled_by_default", True)
|
||||
|
||||
if not enabled or not default_enabled:
|
||||
logger.debug(f"兴趣计算器 {calc_name} 未启用,跳过")
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.plugin_system.base.component_types import ComponentType as CT
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
component_class = component_registry.get_component_class(
|
||||
calc_name, CT.INTEREST_CALCULATOR
|
||||
)
|
||||
|
||||
if not component_class:
|
||||
logger.warning(f"无法找到 {calc_name} 的组件类")
|
||||
continue
|
||||
|
||||
logger.debug(f"成功获取 {calc_name} 的组件类: {component_class.__name__}")
|
||||
|
||||
# 确保组件是 BaseInterestCalculator 的子类
|
||||
if not issubclass(component_class, BaseInterestCalculator):
|
||||
logger.warning(f"{calc_name} 不是 BaseInterestCalculator 的有效子类")
|
||||
continue
|
||||
|
||||
# 显式转换类型以修复 Pyright 错误
|
||||
component_class = cast(type[BaseInterestCalculator], component_class)
|
||||
|
||||
# 创建组件实例
|
||||
calculator_instance = component_class()
|
||||
|
||||
# 初始化组件
|
||||
if not await calculator_instance.initialize():
|
||||
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
|
||||
continue
|
||||
|
||||
# 注册到兴趣管理器
|
||||
if await interest_manager.register_calculator(calculator_instance):
|
||||
registered_calculators.append(calculator_instance)
|
||||
logger.debug(f"成功注册兴趣计算器: {calc_name}")
|
||||
else:
|
||||
logger.error(f"兴趣计算器 {calc_name} 注册失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理兴趣计算器 {calc_name} 时出错: {e}")
|
||||
|
||||
if registered_calculators:
|
||||
logger.debug(f"成功注册了 {len(registered_calculators)} 个兴趣计算器")
|
||||
for calc in registered_calculators:
|
||||
logger.debug(f" - {calc.component_name} v{calc.component_version}")
|
||||
else:
|
||||
logger.error("未能成功注册任何兴趣计算器")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化兴趣度计算器失败: {e}")
|
||||
|
||||
async def _async_cleanup(self) -> None:
|
||||
"""异步清理资源"""
|
||||
if self._cleanup_started:
|
||||
@@ -474,6 +386,14 @@ class MainSystem:
|
||||
await mood_manager.start()
|
||||
logger.debug("情绪管理器初始化成功")
|
||||
|
||||
# 初始化日志广播系统
|
||||
try:
|
||||
from src.common.log_broadcaster import setup_log_broadcasting
|
||||
setup_log_broadcasting()
|
||||
logger.debug("日志广播系统初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"日志广播系统初始化失败: {e}")
|
||||
|
||||
# 启动聊天管理器的自动保存任务
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
task = asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||
@@ -499,9 +419,6 @@ class MainSystem:
|
||||
except Exception as e:
|
||||
logger.error(f"三层记忆系统初始化失败: {e}")
|
||||
|
||||
# 初始化消息兴趣值计算组件
|
||||
await self._initialize_interest_calculator()
|
||||
|
||||
# 初始化LPMM知识库
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# ruff: noqa: G004, BLE001
|
||||
# pylint: disable=logging-fstring-interpolation,broad-except,unused-argument
|
||||
# pyright: reportOptionalMemberAccess=false
|
||||
"""
|
||||
记忆管理器 - Phase 3
|
||||
|
||||
@@ -218,7 +221,7 @@ class MemoryManager:
|
||||
subject: str,
|
||||
memory_type: str,
|
||||
topic: str,
|
||||
object: str | None = None,
|
||||
obj: str | None = None,
|
||||
attributes: dict[str, str] | None = None,
|
||||
importance: float = 0.5,
|
||||
**kwargs,
|
||||
@@ -230,7 +233,7 @@ class MemoryManager:
|
||||
subject: 主体(谁)
|
||||
memory_type: 记忆类型(事件/观点/事实/关系)
|
||||
topic: 主题(做什么/想什么)
|
||||
object: 客体(对谁/对什么)
|
||||
obj: 客体(对谁/对什么)
|
||||
attributes: 属性字典(时间、地点、原因等)
|
||||
importance: 重要性 (0.0-1.0)
|
||||
**kwargs: 其他参数
|
||||
@@ -246,7 +249,7 @@ class MemoryManager:
|
||||
subject=subject,
|
||||
memory_type=memory_type,
|
||||
topic=topic,
|
||||
object=object,
|
||||
object=obj,
|
||||
attributes=attributes,
|
||||
importance=importance,
|
||||
**kwargs,
|
||||
@@ -775,6 +778,8 @@ class MemoryManager:
|
||||
logger.debug(f"传播激活到相关记忆 {related_id[:8]} 失败: {e}")
|
||||
|
||||
# 再次保存传播后的更新
|
||||
assert self.persistence is not None
|
||||
assert self.graph_store is not None
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
logger.debug(f"后台保存激活更新完成,处理了 {len(memories)} 条记忆")
|
||||
@@ -811,7 +816,6 @@ class MemoryManager:
|
||||
|
||||
# 批量执行传播任务
|
||||
if propagation_tasks:
|
||||
import asyncio
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*propagation_tasks, return_exceptions=True),
|
||||
@@ -837,6 +841,8 @@ class MemoryManager:
|
||||
Returns:
|
||||
相关记忆 ID 列表
|
||||
"""
|
||||
_ = max_depth # 保留参数以兼容旧调用
|
||||
|
||||
memory = self.graph_store.get_memory_by_id(memory_id)
|
||||
if not memory:
|
||||
return []
|
||||
@@ -997,7 +1003,7 @@ class MemoryManager:
|
||||
if memories_to_forget:
|
||||
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
|
||||
|
||||
for memory_id, activation in memories_to_forget:
|
||||
for memory_id, _ in memories_to_forget:
|
||||
# cleanup_orphans=False:暂不清理孤立节点
|
||||
success = await self.forget_memory(memory_id, cleanup_orphans=False)
|
||||
if success:
|
||||
@@ -1008,6 +1014,8 @@ class MemoryManager:
|
||||
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
|
||||
|
||||
# 保存最终更新
|
||||
assert self.persistence is not None
|
||||
assert self.graph_store is not None
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
|
||||
logger.info(
|
||||
@@ -1059,7 +1067,7 @@ class MemoryManager:
|
||||
# 2. 清理孤立边(指向已删除节点的边)
|
||||
edges_to_remove = []
|
||||
|
||||
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
|
||||
for source, target, _ in self.graph_store.graph.edges(data="edge_id"):
|
||||
# 检查边的源节点和目标节点是否还存在于node_to_memories中
|
||||
if source not in self.graph_store.node_to_memories or \
|
||||
target not in self.graph_store.node_to_memories:
|
||||
@@ -1096,7 +1104,7 @@ class MemoryManager:
|
||||
if not self._initialized or not self.graph_store:
|
||||
return {}
|
||||
|
||||
stats = self.graph_store.get_statistics()
|
||||
stats: dict[str, Any] = self.graph_store.get_statistics()
|
||||
|
||||
# 添加激活度统计
|
||||
all_memories = self.graph_store.get_all_memories()
|
||||
@@ -1152,7 +1160,7 @@ class MemoryManager:
|
||||
logger.info("开始记忆整理:检查遗忘 + 清理孤立节点...")
|
||||
|
||||
# 步骤1: 自动遗忘低激活度的记忆
|
||||
forgotten_count = await self.auto_forget()
|
||||
forgotten_count = await self.auto_forget_memories()
|
||||
|
||||
# 步骤2: 清理孤立节点和边(auto_forget内部已执行,这里再次确保)
|
||||
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
|
||||
@@ -1292,6 +1300,8 @@ class MemoryManager:
|
||||
result["orphan_edges_cleaned"] = consolidate_result.get("orphan_edges_cleaned", 0)
|
||||
|
||||
# 2. 保存数据
|
||||
assert self.persistence is not None
|
||||
assert self.graph_store is not None
|
||||
await self.persistence.save_graph_store(self.graph_store)
|
||||
result["saved"] = True
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -186,8 +187,8 @@ class ShortTermMemoryManager:
|
||||
"importance": 0.7,
|
||||
"attributes": {{
|
||||
"time": "时间信息",
|
||||
"attribute1": "其他属性1"
|
||||
"attribute2": "其他属性2"
|
||||
"attribute1": "其他属性1",
|
||||
"attribute2": "其他属性2",
|
||||
...
|
||||
}}
|
||||
}}
|
||||
@@ -530,7 +531,7 @@ class ShortTermMemoryManager:
|
||||
json_str = re.sub(r"//.*", "", json_str)
|
||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||
|
||||
data = json.loads(json_str)
|
||||
data = json_repair.loads(json_str)
|
||||
return data
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
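A small illustration of why the parser above switched from json.loads to json_repair.loads: tolerant parsing of slightly malformed LLM output (the sample string is invented).

import json_repair

broken = '{"subject": "user", "importance": 0.7, }'  # trailing comma that json.loads would reject
data = json_repair.loads(broken)
print(data["importance"])  # 0.7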
@@ -12,7 +12,6 @@ from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.plugin_system.services.interest_service import interest_service
|
||||
from src.plugin_system.services.relationship_service import relationship_service
|
||||
|
||||
logger = get_logger("person_api")
|
||||
@@ -169,37 +168,6 @@ async def update_user_relationship(user_id: str, relationship_score: float, rela
|
||||
await relationship_service.update_user_relationship(user_id, relationship_score, relationship_text, user_name)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 兴趣系统API
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def initialize_smart_interests(personality_description: str, personality_id: str = "default"):
|
||||
"""
|
||||
初始化智能兴趣系统
|
||||
|
||||
Args:
|
||||
personality_description: 机器人性格描述
|
||||
personality_id: 性格ID
|
||||
"""
|
||||
await interest_service.initialize_smart_interests(personality_description, personality_id)
|
||||
|
||||
|
||||
async def calculate_interest_match(
|
||||
content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
|
||||
):
|
||||
"""计算消息兴趣匹配,返回匹配结果"""
|
||||
if not content:
|
||||
logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空")
|
||||
return None
|
||||
|
||||
try:
|
||||
return await interest_service.calculate_interest_match(content, keywords, message_embedding)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 系统状态与缓存API
|
||||
# =============================================================================
|
||||
@@ -214,7 +182,6 @@ def get_system_stats() -> dict[str, Any]:
|
||||
"""
|
||||
return {
|
||||
"relationship_service": relationship_service.get_cache_stats(),
|
||||
"interest_service": interest_service.get_interest_stats(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, InterestCalculatorInfo
|
||||
|
||||
logger = get_logger("base_interest_calculator")
|
||||
|
||||
@@ -210,26 +209,6 @@ class BaseInterestCalculator(ABC):
|
||||
return default
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
def get_interest_calculator_info(cls) -> "InterestCalculatorInfo":
|
||||
"""从类属性生成InterestCalculatorInfo
|
||||
|
||||
遵循BaseCommand和BaseAction的设计模式,从类属性自动生成组件信息
|
||||
|
||||
Returns:
|
||||
InterestCalculatorInfo: 生成的兴趣计算器信息对象
|
||||
"""
|
||||
name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))
|
||||
if "." in name:
|
||||
logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
|
||||
return InterestCalculatorInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.INTEREST_CALCULATOR,
|
||||
description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"),
|
||||
enabled_by_default=getattr(cls, "enabled_by_default", True),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -7,7 +7,6 @@ from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
InterestCalculatorInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
@@ -17,7 +16,6 @@ from .base_action import BaseAction
|
||||
from .base_adapter import BaseAdapter
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_interest_calculator import BaseInterestCalculator
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .plugin_base import PluginBase
|
||||
@@ -59,15 +57,6 @@ class BasePlugin(PluginBase):
|
||||
logger.warning(f"Action组件 {component_class.__name__} 缺少 get_action_info 方法")
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.INTEREST_CALCULATOR:
|
||||
if hasattr(component_class, "get_interest_calculator_info"):
|
||||
return component_class.get_interest_calculator_info()
|
||||
else:
|
||||
logger.warning(
|
||||
f"InterestCalculator组件 {component_class.__name__} 缺少 get_interest_calculator_info 方法"
|
||||
)
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.PLUS_COMMAND:
|
||||
# PlusCommand组件的get_info方法尚未实现
|
||||
logger.warning("PlusCommand组件的get_info方法尚未实现")
|
||||
@@ -123,7 +112,6 @@ class BasePlugin(PluginBase):
|
||||
| tuple[PlusCommandInfo, type[PlusCommand]]
|
||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||
| tuple[ToolInfo, type[BaseTool]]
|
||||
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
||||
| tuple[PromptInfo, type[BasePrompt]]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
@@ -48,7 +48,6 @@ class ComponentType(Enum):
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||
CHATTER = "chatter" # 聊天处理器组件
|
||||
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
||||
PROMPT = "prompt" # Prompt组件
|
||||
ROUTER = "router" # 路由组件
|
||||
ADAPTER = "adapter" # 适配器组件
|
||||
@@ -298,17 +297,6 @@ class ChatterInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.CHATTER
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestCalculatorInfo(ComponentInfo):
|
||||
"""兴趣度计算组件信息(单例模式)"""
|
||||
|
||||
enabled_by_default: bool = True # 是否默认启用
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.INTEREST_CALCULATOR
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventInfo(ComponentInfo):
|
||||
"""事件组件信息"""
|
||||
|
||||
@@ -17,7 +17,6 @@ from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_http_component import BaseRouterComponent
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
@@ -28,7 +27,6 @@ from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
InterestCalculatorInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
@@ -48,7 +46,6 @@ ComponentClassType = (
|
||||
| type[BaseEventHandler]
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| type[BasePrompt]
|
||||
| type[BaseRouterComponent]
|
||||
| type[BaseAdapter]
|
||||
@@ -144,10 +141,6 @@ class ComponentRegistry:
|
||||
self._chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
|
||||
|
||||
# InterestCalculator 相关
|
||||
self._interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {}
|
||||
self._enabled_interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {}
|
||||
|
||||
# Prompt 相关
|
||||
self._prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||
self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||
@@ -283,7 +276,6 @@ class ComponentRegistry:
|
||||
ComponentType.TOOL: self._register_tool,
|
||||
ComponentType.EVENT_HANDLER: self._register_event_handler,
|
||||
ComponentType.CHATTER: self._register_chatter,
|
||||
ComponentType.INTEREST_CALCULATOR: self._register_interest_calculator,
|
||||
ComponentType.PROMPT: self._register_prompt,
|
||||
ComponentType.ROUTER: self._register_router,
|
||||
ComponentType.ADAPTER: self._register_adapter,
|
||||
@@ -344,9 +336,6 @@ class ComponentRegistry:
|
||||
case ComponentType.CHATTER:
|
||||
self._chatter_registry.pop(component_name, None)
|
||||
self._enabled_chatter_registry.pop(component_name, None)
|
||||
case ComponentType.INTEREST_CALCULATOR:
|
||||
self._interest_calculator_registry.pop(component_name, None)
|
||||
self._enabled_interest_calculator_registry.pop(component_name, None)
|
||||
case ComponentType.PROMPT:
|
||||
self._prompt_registry.pop(component_name, None)
|
||||
self._enabled_prompt_registry.pop(component_name, None)
|
||||
@@ -497,25 +486,6 @@ class ComponentRegistry:
|
||||
self._enabled_chatter_registry[info.name] = chatter_class
|
||||
return True
|
||||
|
||||
def _register_interest_calculator(self, info: ComponentInfo, cls: ComponentClassType) -> bool:
|
||||
"""
|
||||
注册 InterestCalculator 组件到特定注册表。
|
||||
|
||||
Args:
|
||||
info: InterestCalculator 组件的元数据信息
|
||||
cls: InterestCalculator 组件的类定义
|
||||
|
||||
Returns:
|
||||
注册成功返回 True
|
||||
"""
|
||||
calc_info = cast(InterestCalculatorInfo, info)
|
||||
calc_class = cast(type[BaseInterestCalculator], cls)
|
||||
_assign_plugin_attrs(calc_class, info.plugin_name, self.get_plugin_config(info.plugin_name) or {})
|
||||
self._interest_calculator_registry[info.name] = calc_class
|
||||
if calc_info.enabled:
|
||||
self._enabled_interest_calculator_registry[info.name] = calc_class
|
||||
return True
|
||||
|
||||
def _register_prompt(self, info: ComponentInfo, cls: ComponentClassType) -> bool:
|
||||
"""
|
||||
注册 Prompt 组件到 Prompt 特定注册表。
|
||||
@@ -950,26 +920,6 @@ class ComponentRegistry:
|
||||
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
|
||||
return info if isinstance(info, ChatterInfo) else None
|
||||
|
||||
# --- InterestCalculator ---
|
||||
def get_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]:
|
||||
"""获取所有已注册的 InterestCalculator 类。"""
|
||||
return self._interest_calculator_registry.copy()
|
||||
|
||||
def get_enabled_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]:
|
||||
"""
|
||||
获取所有已启用的 InterestCalculator 类。
|
||||
|
||||
会检查组件的全局启用状态。
|
||||
|
||||
Returns:
|
||||
可用的 InterestCalculator 名称到类的字典
|
||||
"""
|
||||
return {
|
||||
name: cls
|
||||
for name, cls in self._interest_calculator_registry.items()
|
||||
if self.is_component_available(name, ComponentType.INTEREST_CALCULATOR)
|
||||
}
|
||||
|
||||
# --- Prompt ---
|
||||
def get_prompt_registry(self) -> dict[str, type[BasePrompt]]:
|
||||
"""获取所有已注册的 Prompt 类。"""
|
||||
|
||||
@@ -110,8 +110,6 @@ class ComponentStateManager:
|
||||
)
|
||||
case ComponentType.CHATTER:
|
||||
self._registry._enabled_chatter_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.INTEREST_CALCULATOR:
|
||||
self._registry._enabled_interest_calculator_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.PROMPT:
|
||||
self._registry._enabled_prompt_registry[component_name] = target_class # type: ignore
|
||||
case ComponentType.ADAPTER:
|
||||
@@ -161,8 +159,6 @@ class ComponentStateManager:
|
||||
event_manager.remove_event_handler(component_name)
|
||||
case ComponentType.CHATTER:
|
||||
self._registry._enabled_chatter_registry.pop(component_name, None)
|
||||
case ComponentType.INTEREST_CALCULATOR:
|
||||
self._registry._enabled_interest_calculator_registry.pop(component_name, None)
|
||||
case ComponentType.PROMPT:
|
||||
self._registry._enabled_prompt_registry.pop(component_name, None)
|
||||
case ComponentType.ADAPTER:
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""
|
||||
兴趣系统服务
|
||||
提供独立的兴趣管理功能,不依赖任何插件
|
||||
"""
|
||||
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("interest_service")
|
||||
|
||||
|
||||
class InterestService:
|
||||
"""兴趣系统服务 - 独立于插件的兴趣管理"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = bot_interest_manager.is_initialized
|
||||
|
||||
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
|
||||
"""
|
||||
初始化智能兴趣系统
|
||||
|
||||
Args:
|
||||
personality_description: 机器人性格描述
|
||||
personality_id: 性格ID
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化智能兴趣系统...")
|
||||
await bot_interest_manager.initialize(personality_description, personality_id)
|
||||
self.is_initialized = True
|
||||
logger.info("智能兴趣系统初始化完成。")
|
||||
|
||||
# 显示初始化后的统计信息
|
||||
stats = bot_interest_manager.get_interest_stats()
|
||||
logger.debug(f"兴趣系统统计: {stats}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
self.is_initialized = False
|
||||
|
||||
async def calculate_interest_match(
|
||||
self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
|
||||
):
|
||||
"""
|
||||
计算消息与兴趣的匹配度
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
keywords: 关键字列表
|
||||
message_embedding: 已经生成的消息embedding,可选
|
||||
|
||||
Returns:
|
||||
匹配结果
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
logger.warning("兴趣系统未初始化,无法计算匹配度")
|
||||
return None
|
||||
|
||||
try:
|
||||
if not keywords:
|
||||
# 如果没有关键字,则从内容中提取
|
||||
keywords = self._extract_keywords_from_content(content)
|
||||
|
||||
return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding)
|
||||
except Exception as e:
|
||||
logger.error(f"计算兴趣匹配失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> list[str]:
|
||||
"""从内容中提取关键词"""
|
||||
import re
|
||||
|
||||
# 清理文本
|
||||
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||
words = content.split()
|
||||
|
||||
# 过滤和关键词提取
|
||||
keywords = []
|
||||
for word in words:
|
||||
word = word.strip()
|
||||
if (
|
||||
len(word) >= 2 # 至少2个字符
|
||||
and word.isalnum() # 字母数字
|
||||
and not word.isdigit()
|
||||
): # 不是纯数字
|
||||
keywords.append(word.lower())
|
||||
|
||||
# 去重并限制数量
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def get_interest_stats(self) -> dict:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.is_initialized:
|
||||
return {"initialized": False}
|
||||
|
||||
try:
|
||||
return {
|
||||
"initialized": True,
|
||||
**bot_interest_manager.get_interest_stats()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣系统统计失败: {e}")
|
||||
return {"initialized": True, "error": str(e)}
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
interest_service = InterestService()
|
||||
@@ -1,15 +1,22 @@
|
||||
"""AffinityFlow 风格兴趣值计算组件
|
||||
|
||||
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
|
||||
集成了语义兴趣度计算(TF-IDF + Logistic Regression)
|
||||
|
||||
2024.12 优化更新:
|
||||
- 使用 FastScorer 优化评分(绕过 sklearn,纯 Python 字典计算)
|
||||
- 支持批处理队列模式(高频群聊场景)
|
||||
- 全局线程池避免重复创建 executor
|
||||
- 更短的超时时间(2秒)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
|
||||
@@ -36,18 +43,21 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
# 从配置加载评分权重
|
||||
affinity_config = global_config.affinity_flow
|
||||
self.score_weights = {
|
||||
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
|
||||
"semantic": 0.5, # 语义兴趣度权重(核心维度)
|
||||
"relationship": affinity_config.relationship_weight, # 关系分权重
|
||||
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
|
||||
}
|
||||
|
||||
# 语义兴趣度评分器(替代原有的 embedding 兴趣匹配)
|
||||
self.semantic_scorer = None
|
||||
self.use_semantic_scoring = True # 必须启用
|
||||
self._semantic_initialized = False # 防止重复初始化
|
||||
self.model_manager = None
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||
|
||||
# 兴趣匹配系统配置
|
||||
self.use_smart_matching = True
|
||||
|
||||
# 连续不回复概率提升
|
||||
self.no_reply_count = 0
|
||||
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||
@@ -69,14 +79,17 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count
|
||||
self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate
|
||||
|
||||
logger.info("[Affinity兴趣计算器] 初始化完成:")
|
||||
logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR):")
|
||||
logger.info(f" - 权重配置: {self.score_weights}")
|
||||
logger.info(f" - 回复阈值: {self.reply_threshold}")
|
||||
logger.info(f" - 智能匹配: {self.use_smart_matching}")
|
||||
logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)")
|
||||
logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}")
|
||||
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}")
|
||||
logger.info(f" - 最大不回复计数: {self.max_no_reply_count}")
|
||||
|
||||
# 异步初始化语义评分器
|
||||
asyncio.create_task(self._initialize_semantic_scorer())
|
||||
|
||||
async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
||||
"""执行AffinityFlow风格的兴趣值计算"""
|
||||
try:
|
||||
@@ -93,10 +106,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
|
||||
logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}")
|
||||
|
||||
# 1. 计算兴趣匹配分
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
|
||||
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
|
||||
# 1. 计算语义兴趣度(核心维度,替代原 embedding 兴趣匹配)
|
||||
semantic_score = await self._calculate_semantic_score(content)
|
||||
logger.debug(f"[Affinity兴趣计算] 语义兴趣度(TF-IDF+LR): {semantic_score}")
|
||||
|
||||
# 2. 计算关系分
|
||||
relationship_score = await self._calculate_relationship_score(user_id)
|
||||
@@ -108,12 +120,12 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
# 4. 综合评分
|
||||
# 确保所有分数都是有效的 float 值
|
||||
interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0
|
||||
semantic_score = float(semantic_score) if semantic_score is not None else 0.0
|
||||
relationship_score = float(relationship_score) if relationship_score is not None else 0.0
|
||||
mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0
|
||||
|
||||
raw_total_score = (
|
||||
interest_match_score * self.score_weights["interest_match"]
|
||||
semantic_score * self.score_weights["semantic"]
|
||||
+ relationship_score * self.score_weights["relationship"]
|
||||
+ mentioned_score * self.score_weights["mentioned"]
|
||||
)
|
||||
@@ -122,7 +134,8 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
total_score = min(raw_total_score, 1.0)
|
||||
|
||||
logger.debug(
|
||||
f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
|
||||
f"[Affinity兴趣计算] 综合得分计算: "
|
||||
f"{semantic_score:.3f}*{self.score_weights['semantic']} + "
|
||||
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
|
||||
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
|
||||
)
|
||||
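A worked number for the weighted sum logged above, using assumed weights (semantic is hard-coded to 0.5 in this diff; the relationship and mention weights come from config and are illustrative here).

score_weights = {"semantic": 0.5, "relationship": 0.3, "mentioned": 0.4}  # illustrative values
semantic_score, relationship_score, mentioned_score = 0.8, 0.6, 1.0

raw_total_score = (
    semantic_score * score_weights["semantic"]            # 0.40
    + relationship_score * score_weights["relationship"]  # 0.18
    + mentioned_score * score_weights["mentioned"]        # 0.40
)
total_score = min(raw_total_score, 1.0)
print(total_score)  # 0.98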
@@ -153,7 +166,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
logger.debug(
|
||||
f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
|
||||
f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
|
||||
f"(语义:{semantic_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
|
||||
)
|
||||
|
||||
return InterestCalculationResult(
|
||||
@@ -172,55 +185,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(
|
||||
self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
|
||||
) -> float:
|
||||
"""计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)"""
|
||||
|
||||
# 调试日志:检查各个条件
|
||||
if not content:
|
||||
logger.debug("兴趣匹配返回0.0: 内容为空")
|
||||
return 0.0
|
||||
if not self.use_smart_matching:
|
||||
logger.debug("兴趣匹配返回0.0: 智能匹配未启用")
|
||||
return 0.0
|
||||
if not bot_interest_manager.is_initialized:
|
||||
logger.debug("兴趣匹配返回0.0: bot_interest_manager未初始化")
|
||||
return 0.0
|
||||
|
||||
logger.debug(f"开始兴趣匹配计算,内容: {content[:50]}...")
|
||||
|
||||
try:
|
||||
# 使用机器人的兴趣标签系统进行智能匹配(5秒超时保护)
|
||||
match_result = await asyncio.wait_for(
|
||||
bot_interest_manager.calculate_interest_match(
|
||||
content, keywords or [], getattr(message, "semantic_embedding", None)
|
||||
),
|
||||
timeout=5.0
|
||||
)
|
||||
logger.debug(f"兴趣匹配结果: {match_result}")
|
||||
|
||||
if match_result:
|
||||
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||
affinity_config = global_config.affinity_flow
|
||||
match_count_bonus = min(
|
||||
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
|
||||
)
|
||||
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||
# 移除兴趣匹配分数上限,允许超过1.0,最终分数会被整体限制
|
||||
logger.debug(f"兴趣匹配最终得分: {final_score:.3f} (原始: {match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus:.3f})")
|
||||
return final_score
|
||||
else:
|
||||
logger.debug("兴趣匹配返回0.0: match_result为None")
|
||||
return 0.0
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("[超时] 兴趣匹配计算超时(>5秒),返回默认分值0.5以保留其他分数")
|
||||
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
|
||||
except Exception as e:
|
||||
logger.warning(f"智能兴趣匹配失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
"""计算用户关系分"""
|
||||
if not user_id:
|
||||
@@ -316,60 +280,204 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
return adjusted_reply_threshold, adjusted_action_threshold
|
||||
|
||||
def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]:
"""从数据库消息中提取关键词"""
keywords = []
async def _initialize_semantic_scorer(self):
"""异步初始化语义兴趣度评分器(使用单例 + FastScorer优化)"""
# 检查是否已初始化
if self._semantic_initialized:
logger.debug("[语义评分] 评分器已初始化,跳过")
return

if not self.use_semantic_scoring:
logger.debug("[语义评分] 未启用语义兴趣度评分")
return

# 防止并发初始化(使用锁)
if not hasattr(self, '_init_lock'):
self._init_lock = asyncio.Lock()

async with self._init_lock:
# 双重检查
if self._semantic_initialized:
logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过")
return

# 尝试从 key_words 字段提取(存储的是JSON字符串)
key_words = getattr(message, "key_words", "")
if key_words:
try:
extracted = orjson.loads(key_words)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
from src.chat.semantic_interest import get_semantic_scorer
from src.chat.semantic_interest.runtime_scorer import ModelManager
# 如果没有 keywords,尝试从 key_words_lite 提取
if not keywords:
key_words_lite = getattr(message, "key_words_lite", "")
if key_words_lite:
# 查找最新的模型文件
model_dir = Path("data/semantic_interest/models")
if not model_dir.exists():
logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
model_dir.mkdir(parents=True, exist_ok=True)

# 使用模型管理器(支持人设感知)
if self.model_manager is None:
self.model_manager = ModelManager(model_dir)
logger.debug("[语义评分] 模型管理器已创建")

# 获取人设信息
persona_info = self._get_current_persona_info()

# 先检查是否已有可用模型
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
auto_trainer = get_auto_trainer()
existing_model = auto_trainer.get_model_for_persona(persona_info)

# 加载模型(自动选择合适的版本,使用单例 + FastScorer)
try:
extracted = orjson.loads(key_words_lite)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
if existing_model and existing_model.exists():
# 直接加载已有模型
logger.info(f"[语义评分] 使用已有模型: {existing_model.name}")
scorer = await get_semantic_scorer(existing_model, use_async=True)
else:
# 使用 ModelManager 自动选择或训练
scorer = await self.model_manager.load_model(
version="auto", # 自动选择或训练
persona_info=persona_info
)

self.semantic_scorer = scorer

logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")

# 设置初始化标志
self._semantic_initialized = True

# 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动
if not existing_model or not existing_model.exists():
await self.model_manager.start_auto_training(
persona_info=persona_info,
interval_hours=24
)
else:
logger.debug("[语义评分] 已有模型,跳过自动训练启动")
except FileNotFoundError:
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
# 触发首次训练
trained, model_path = await auto_trainer.auto_train_if_needed(
persona_info=persona_info,
force=True # 强制训练
)
if trained and model_path:
# 使用单例获取评分器(默认启用 FastScorer)
self.semantic_scorer = await get_semantic_scorer(model_path)
logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
else:
logger.error("[语义评分] 首次训练失败")
self.use_semantic_scoring = False

# 如果还是没有,从消息内容中提取(降级方案)
if not keywords:
content = getattr(message, "processed_plain_text", "") or ""
keywords = self._extract_keywords_from_content(content)
except ImportError:
logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
self.use_semantic_scoring = False
except Exception as e:
logger.error(f"[语义评分] 初始化失败: {e}")
self.use_semantic_scoring = False

return keywords[:15] # 返回前15个关键词
def _get_current_persona_info(self) -> dict[str, Any]:
"""获取当前人设信息

Returns:
人设信息字典
"""
# 默认信息(至少包含名字)
persona_info = {
"name": global_config.bot.nickname,
"interests": [],
"dislikes": [],
"personality": "",
}
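The added initializer guards itself with a double-checked `asyncio.Lock`, so concurrent messages cannot trigger duplicate model loads or trainings. A minimal standalone illustration of that pattern (class and method names here are illustrative, not the project's):

```python
import asyncio

class LazyScorer:
    """Toy stand-in for the semantic scorer's lazy, lock-guarded initialization."""

    def __init__(self):
        self._initialized = False
        self._init_lock = asyncio.Lock()
        self.scorer = None

    async def ensure_initialized(self):
        if self._initialized:              # fast path, no lock needed
            return
        async with self._init_lock:
            if self._initialized:          # double check inside the lock
                return
            self.scorer = await self._load_model()
            self._initialized = True

    async def _load_model(self):
        await asyncio.sleep(0.1)           # stand-in for model loading/training
        return object()

async def main():
    s = LazyScorer()
    await asyncio.gather(*(s.ensure_initialized() for _ in range(5)))
    print("initialized once:", s.scorer is not None)

asyncio.run(main())
```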
def _extract_keywords_from_content(self, content: str) -> list[str]:
"""从内容中提取关键词(降级方案)"""
import re
# 优先从已生成的人设文件获取(Individuality 初始化时会生成)
try:
persona_file = Path("data/personality/personality_data.json")
if persona_file.exists():
data = orjson.loads(persona_file.read_bytes())
personality_parts = [data.get("personality", ""), data.get("identity", "")]
persona_info["personality"] = ",".join([p for p in personality_parts if p]).strip(",")
if persona_info["personality"]:
return persona_info
except Exception as e:
logger.debug(f"[语义评分] 从文件获取人设信息失败: {e}")

# 清理文本
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
words = content.split()
# 退化为配置中的人设描述
try:
personality_parts = []
personality_core = getattr(global_config.personality, "personality_core", "")
personality_side = getattr(global_config.personality, "personality_side", "")
identity = getattr(global_config.personality, "identity", "")

# 过滤和关键词提取
keywords = []
for word in words:
word = word.strip()
if (
len(word) >= 2 # 至少2个字符
and word.isalnum() # 字母数字
and not word.isdigit()
): # 不是纯数字
keywords.append(word.lower())
if personality_core:
personality_parts.append(personality_core)
if personality_side:
personality_parts.append(personality_side)
if identity:
personality_parts.append(identity)

# 去重并限制数量
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
persona_info["personality"] = ",".join(personality_parts) or "默认人设"
except Exception as e:
logger.debug(f"[语义评分] 使用配置获取人设信息失败: {e}")
persona_info["personality"] = "默认人设"

return persona_info
async def _calculate_semantic_score(self, content: str) -> float:
"""计算语义兴趣度分数(优化版:FastScorer + 可选批处理 + 超时保护)

Args:
content: 消息文本

Returns:
语义兴趣度分数 [0.0, 1.0]
"""
# 检查是否启用
if not self.use_semantic_scoring:
return 0.0

# 检查评分器是否已加载
if not self.semantic_scorer:
return 0.0

# 检查内容是否为空
if not content or not content.strip():
return 0.0

try:
score = await self.semantic_scorer.score_async(content, timeout=2.0)

logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
return score

except Exception as e:
logger.warning(f"[语义评分] 计算失败: {e}")
return 0.0
async def reload_semantic_model(self):
"""重新加载语义兴趣度模型(支持热更新和人设检查)"""
if not self.use_semantic_scoring:
logger.info("[语义评分] 语义评分未启用,无需重载")
return

logger.info("[语义评分] 开始重新加载模型...")

# 检查人设是否变化
if hasattr(self, 'model_manager') and self.model_manager:
persona_info = self._get_current_persona_info()
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
if reloaded:
self.semantic_scorer = self.model_manager.get_scorer()

logger.info("[语义评分] 模型重载完成(人设已更新)")
else:
logger.info("[语义评分] 人设未变化,无需重载")
else:
# 降级:简单重新初始化
self._semantic_initialized = False
await self._initialize_semantic_scorer()
logger.info("[语义评分] 模型重载完成")

def update_no_reply_count(self, replied: bool):
"""更新连续不回复计数"""

@@ -415,3 +523,5 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(
f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
)

afc_interest_calculator = AffinityInterestCalculator()
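A module-level instance, `afc_interest_calculator`, is now exported so the planner can reuse one shared calculator instead of constructing its own. A short usage sketch (the import path is from the diff; the `initialize()` contract is assumed to be idempotent and to return True on success):

```python
# Hypothetical call site reusing the shared calculator instance.
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
    afc_interest_calculator,
)

async def get_shared_calculator():
    if await afc_interest_calculator.initialize():  # assumed safe to call repeatedly
        return afc_interest_calculator
    return None
```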
@@ -7,8 +7,6 @@ import asyncio
from dataclasses import asdict
from typing import TYPE_CHECKING, Any

from src.chat.interest_system import bot_interest_manager
from src.chat.interest_system.interest_manager import get_interest_manager
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.config.config import global_config

@@ -52,6 +50,8 @@ class ChatterActionPlanner:
self.action_manager = action_manager
self.generator = ChatterPlanGenerator(chat_id, action_manager)
self.executor = ChatterPlanExecutor(action_manager)
self._interest_calculator = None
self._interest_calculator_lock = asyncio.Lock()

# 使用新的统一兴趣度管理系统
@@ -130,60 +130,32 @@ class ChatterActionPlanner:
if not pending_messages:
return

calculator = await self._get_interest_calculator()
if not calculator:
logger.debug("未获取到兴趣计算器,跳过批量兴趣计算")
return

logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")

if not bot_interest_manager.is_initialized:
logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
return

try:
interest_manager = get_interest_manager()
except Exception as exc:
logger.warning(f"获取兴趣管理器失败: {exc}")
return

if not interest_manager or not interest_manager.has_calculator():
logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
return

text_map: dict[str, str] = {}
for message in pending_messages:
text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
text_map[str(message.message_id)] = text

try:
embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
except Exception as exc:
logger.error(f"批量获取消息embedding失败: {exc}")
embeddings = {}
interest_updates: dict[str, float] = {}
reply_updates: dict[str, bool] = {}

for message in pending_messages:
message_id = str(message.message_id)
if message_id in embeddings:
message.semantic_embedding = embeddings[message_id]

try:
result = await interest_manager.calculate_interest(message)
result = await calculator._safe_execute(message) # 使用带统计的安全执行
except Exception as exc:
logger.error(f"批量计算消息兴趣失败: {exc}")
continue

if result.success:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = result.success

message_id = str(getattr(message, "message_id", ""))
if message_id:
interest_updates[message_id] = result.interest_value
reply_updates[message_id] = result.should_reply

# 批量处理后清理 embeddings 字典
embeddings.clear()
text_map.clear()
else:
message.interest_calculated = False

if interest_updates:
try:

@@ -191,6 +163,32 @@ class ChatterActionPlanner:
except Exception as exc:
logger.error(f"批量更新消息兴趣值失败: {exc}")
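Reading past the split-view interleaving, the reworked batch path is: collect the pending texts, fetch embeddings in one batched call, score each message through the calculator, then write all interest values back in a single update. A self-contained sketch of that shape with placeholder scoring (the real calls are the project APIs shown above):

```python
import asyncio
from dataclasses import dataclass, field

@dataclass
class Msg:
    message_id: str
    text: str
    interest_value: float = 0.0
    interest_calculated: bool = False
    embedding: list[float] = field(default_factory=list)

async def embed_batch(texts: dict[str, str]) -> dict[str, list[float]]:
    return {mid: [float(len(t))] for mid, t in texts.items()}  # placeholder embeddings

async def score(msg: Msg) -> float:
    return min(1.0, len(msg.text) / 100)  # placeholder for calculator._safe_execute

async def process_batch(pending: list[Msg]) -> dict[str, float]:
    texts = {m.message_id: m.text for m in pending}
    embeddings = await embed_batch(texts)             # one batched embedding call
    updates: dict[str, float] = {}
    for m in pending:
        m.embedding = embeddings.get(m.message_id, [])
        m.interest_value = await score(m)             # per-message interest calculation
        m.interest_calculated = True
        updates[m.message_id] = m.interest_value
    return updates                                    # then persisted in one bulk update

print(asyncio.run(process_batch([Msg("1", "hi"), Msg("2", "a longer message")])))
```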
async def _get_interest_calculator(self):
"""懒加载兴趣计算器,直接使用计算器实例进行兴趣计算"""
if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False):
return self._interest_calculator

async with self._interest_calculator_lock:
if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False):
return self._interest_calculator

try:
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
afc_interest_calculator,
)

calculator = afc_interest_calculator
if not await calculator.initialize():
logger.warning("AffinityInterestCalculator 初始化失败")
return None

self._interest_calculator = calculator
logger.debug("AffinityInterestCalculator 已就绪")
return self._interest_calculator
except Exception as exc:
logger.warning(f"创建 AffinityInterestCalculator 失败: {exc}")
return None
async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
"""Focus模式下的完整plan流程

@@ -589,13 +587,11 @@ class ChatterActionPlanner:
replied: 是否回复了消息
"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
AffinityInterestCalculator,
)

interest_manager = get_interest_manager()
calculator = interest_manager.get_current_calculator()
calculator = await self._get_interest_calculator()

if calculator and isinstance(calculator, AffinityInterestCalculator):
calculator.on_message_processed(replied)
@@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin):
except Exception as e:
logger.error(f"加载 AffinityChatter 时出错: {e}")

try:
# 延迟导入 AffinityInterestCalculator(从 core 子模块)
from .core.affinity_interest_calculator import AffinityInterestCalculator

components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
except Exception as e:
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")

try:
# 延迟导入 UserProfileTool(从 tools 子模块)
from .tools.user_profile_tool import UserProfileTool
@@ -158,6 +158,9 @@ class KokoroFlowChatterConfig:
# LLM 配置
llm: LLMConfig = field(default_factory=LLMConfig)

# 自定义决策提示词
custom_decision_prompt: str = ""

# 调试模式
debug: bool = False

@@ -256,6 +259,10 @@ def load_config() -> KokoroFlowChatterConfig:
timeout=getattr(llm_cfg, "timeout", 60.0),
)

# 自定义决策提示词配置
if hasattr(kfc_cfg, "custom_decision_prompt"):
config.custom_decision_prompt = str(kfc_cfg.custom_decision_prompt)

except Exception as e:
from src.common.logger import get_logger
logger = get_logger("kfc_config")
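The new key is read defensively with `hasattr`, so older config files without `custom_decision_prompt` still load with the empty-string default. A small standalone sketch of that pattern (the loader here is simplified; only the names mirror the diff):

```python
from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class KokoroFlowChatterConfig:
    custom_decision_prompt: str = ""  # empty string means the feature is off

def load_config(kfc_cfg) -> KokoroFlowChatterConfig:
    config = KokoroFlowChatterConfig()
    if hasattr(kfc_cfg, "custom_decision_prompt"):  # tolerate configs written before 8.0.0
        config.custom_decision_prompt = str(kfc_cfg.custom_decision_prompt)
    return config

old_cfg = SimpleNamespace()  # key missing -> default ""
new_cfg = SimpleNamespace(custom_decision_prompt="回复前先判断话题是否适合继续")
print(load_config(old_cfg).custom_decision_prompt, "|", load_config(new_cfg).custom_decision_prompt)
```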
@@ -235,7 +235,7 @@ class KFCContextBuilder:

search_result = await unified_manager.search_memories(
query_text=query_text,
use_judge=True,
use_judge=config.memory.use_judge,
recent_chat_history=chat_history,
)
@@ -72,6 +72,9 @@ class PromptBuilder:
# 1.5. 构建安全互动准则块
safety_guidelines_block = self._build_safety_guidelines_block()

# 1.6. 构建自定义决策提示词块
custom_decision_block = self._build_custom_decision_block()

# 2. 使用 context_builder 获取关系、记忆、工具、表达习惯等
context_data = await self._build_context_data(user_name, chat_stream, user_id)
relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。")

@@ -102,6 +105,7 @@ class PromptBuilder:
user_name=user_name,
persona_block=persona_block,
safety_guidelines_block=safety_guidelines_block,
custom_decision_block=custom_decision_block,
relation_block=relation_block,
memory_block=memory_block or "(暂无相关记忆)",
tool_info=tool_info or "(暂无工具信息)",

@@ -232,6 +236,23 @@ class PromptBuilder:
{guidelines_text}
如果遇到违反上述原则的请求,请在保持你核心人设的同时,以合适的方式进行回应。"""
def _build_custom_decision_block(self) -> str:
"""
构建自定义决策提示词块

从配置中读取 custom_decision_prompt,用于指导KFC的决策行为
类似于AFC的planner_custom_prompt_content
"""
from ..config import get_config

kfc_config = get_config()
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")

if not custom_prompt or not custom_prompt.strip():
return ""

return custom_prompt.strip()

def _build_combined_expression_block(self, learned_habits: str) -> str:
"""
构建合并后的表达习惯块

@@ -693,8 +714,22 @@ class PromptBuilder:
# 添加真正追问次数警告(只有真正发了消息才算追问)
followup_count = extra_context.get("followup_count", 0)
if followup_count > 0:
timeout_context_parts.append(f"⚠️ 你已经连续追问了 {followup_count} 次,对方仍未回复。再追问可能会显得太急躁,请三思。")
if followup_count >= 2:
timeout_context_parts.append(
f"⚠️ **强烈建议**: 你已经连续追问了 {followup_count} 次,对方仍未回复。"
"**极度推荐选择 `do_nothing` 或主动结束话题**。"
"对方可能在忙或需要空间,不是所有人都能一直在线。给彼此一些空间会更好。"
)
elif followup_count == 1:
timeout_context_parts.append(
"📝 温馨提醒:这是你第2次等待回复(已追问1次)。"
"可以再试着追问一次,但如果对方还是没回复,**强烈建议**之后选择 `do_nothing` 或结束话题。"
)
elif followup_count == 0:
timeout_context_parts.append(
"💭 追问提示:如果对方一段时间未回复,可以适当追问一次。"
"但要记住对方可能在忙,建议最多追问2次左右,之后给对方一些空间。"
)

# 添加距离用户上次回复的时间
time_since_user_reply_str = extra_context.get("time_since_user_reply_str")
@@ -24,6 +24,9 @@ kfc_MAIN_PROMPT = Prompt(
# 安全互动准则
{safety_guidelines_block}

# 决策指导
{custom_decision_block}

# 相关信息
{memory_block}
@@ -100,6 +100,33 @@ def build_rules_module() -> str:
{kfc_rules}"""

# ============================================================
# 模块2.5: 自定义决策提示词 - 用户自定义的决策行为指导
# ============================================================

def build_custom_decision_module() -> str:
"""
构建自定义决策提示词模块

允许用户通过配置文件自定义KFC的决策行为指导
类似于AFC的planner_custom_prompt_content
"""
from .config import get_config

kfc_config = get_config()
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")

# 调试输出
logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}")

if not custom_prompt or not custom_prompt.strip():
logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过")
return ""

logger.info(f"[自定义决策提示词] 已加载: {custom_prompt.strip()}")
return custom_prompt.strip()
# ============================================================
# 模块3: 情境上下文 - 时间/场景/内在状态/关系/记忆
# ============================================================

@@ -354,6 +381,9 @@ def build_system_prompt(
# 角色框架(不提AI,直接进入角色)
role_frame = "你是一个真实的人,有自己的生活、想法和情绪。像平时和朋友聊天一样自然地交流。"

# 获取自定义决策提示词
custom_decision_block = build_custom_decision_module()

# 组装各模块
modules = [
role_frame,

@@ -363,6 +393,17 @@ def build_system_prompt(
"",
"## 2. 规则",
build_rules_module(),
]

# 如果有自定义决策提示词,插入到规则后面
if custom_decision_block:
modules.extend([
"",
"## 2.5. 决策指导",
custom_decision_block,
])

modules.extend([
"",
"## 3. 现在的情况",
build_context_module(session, chat_stream, context_data),

@@ -372,7 +413,7 @@ def build_system_prompt(
"",
"## 5. 怎么回复",
build_output_module(context_data),
]
])

return "\n".join(modules)
@@ -147,17 +147,35 @@ class UnifiedPromptGenerator:

# 生成连续追问警告(使用 followup_count 作为追问计数,只有真正发消息才算)
followup_count = session.waiting_config.followup_count
max_followups = 3 # 最多追问3次
max_followups = 2 # 建议最多追问2次

if followup_count >= max_followups:
followup_warning = f"""⚠️ **重要提醒**:
followup_warning = f"""⚠️ **强烈建议**:
你已经连续追问了 {followup_count} 次,对方都没有回复。
**强烈建议不要再发消息了**——继续追问会显得很缠人、很不尊重对方的空间。
对方可能真的在忙,或者暂时不想回复,这都是正常的。
请选择 `do_nothing` 继续等待,或者直接结束对话(设置 `max_wait_seconds: 0`)。"""
elif followup_count > 0:
followup_warning = f"""📝 提示:这已经是你第 {followup_count + 1} 次等待对方回复了。
如果对方持续没有回应,可能真的在忙或不方便,不需要急着追问。"""
**极度推荐选择 `do_nothing` 或设置 `max_wait_seconds: 0` 结束这个话题**。

对方很可能:
- 正在忙自己的事情,没有时间回复
- 需要一些个人空间和独处时间
- 暂时不方便或不想聊天

这些都是完全正常的。不是所有人都能一直在线,每个人都有自己的生活节奏。
继续追问很可能会让对方感到压力和不适,不如给彼此一些空间。

**最好的选择**:
1. 选择 `do_nothing` 安静等待对方主动联系
2. 或者主动结束这个话题(`max_wait_seconds: 0`),让对方知道你理解他们可能在忙"""
elif followup_count == 1:
followup_warning = """📝 温馨提醒:
这是你第2次等待对方回复(已追问1次)。
可以再试着温柔地追问一次,但要做好对方可能真的在忙的心理准备。
如果这次之后对方还是没回复,**强烈建议**不要再继续追问了——
选择 `do_nothing` 给对方空间,或者主动结束话题,都是尊重对方的表现。"""
elif followup_count == 0:
followup_warning = """💭 追问提示:
如果对方一段时间没回复,可以适当追问一次,用轻松的语气提醒一下。
但要记住:不是所有人都能一直在线,对方可能在忙。
建议最多追问2次左右,之后就给对方一些空间吧。"""
else:
followup_warning = ""
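The warning text is now tiered purely on `followup_count`, with `max_followups = 2` as the hard ceiling. A compact sketch of that mapping (the wording is abbreviated; the real strings are in the diff above):

```python
def followup_warning(followup_count: int, max_followups: int = 2) -> str:
    # Tier 0: gentle hint; tier 1: soft reminder; tier >= 2: strong advice to stop.
    if followup_count >= max_followups:
        return "强烈建议:不要再追问了,选择 do_nothing 或结束话题。"
    if followup_count == 1:
        return "温馨提醒:最多再温柔追问一次,之后给对方空间。"
    if followup_count == 0:
        return "追问提示:可以适当追问一次,但总共最多两次左右。"
    return ""

for n in range(4):
    print(n, "->", followup_warning(n))
```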
@@ -619,7 +619,6 @@ class SystemCommand(PlusCommand):
# 禁用保护
if not enabled:
protected_types = [
ComponentType.INTEREST_CALCULATOR,
ComponentType.PROMPT,
ComponentType.ROUTER,
]

@@ -736,7 +735,6 @@ class SystemCommand(PlusCommand):
if not enabled: # 如果是禁用操作
# 定义不可禁用的核心组件类型
protected_types = [
ComponentType.INTEREST_CALCULATOR,
ComponentType.PROMPT,
ComponentType.ROUTER,
]
@@ -1,5 +1,5 @@
[inner]
version = "7.9.8"
version = "8.0.0"

#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
#如果你想要修改配置文件,请递增version的值

@@ -103,7 +103,7 @@ command_prefixes = ['/']

[personality]
# 建议50字以内,描述人格的核心特质
personality_core = "是一个积极向上的女大学生"
personality_core = "是一个积极向上的女大学生"
# 人格的细节,描述人格的一些侧面
personality_side = "用一句话或几句话描述人格的侧面特质"
#アイデンティティがない 生まれないらららら

@@ -134,6 +134,8 @@ compress_identity = false # 是否压缩身份,压缩后会精简身份信息
# - "classic": 经典模式,随机抽样 + LLM选择
# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达
mode = "classic"
# model_temperature: 机器预测模式下的“温度”,0 为贪婪,越大越爱探索(更容易选到低分表达)
model_temperature = 1.0
# expiration_days: 表达方式过期天数,超过此天数未激活的表达方式将被清理
expiration_days = 1

@@ -311,6 +313,7 @@ short_term_search_top_k = 5 # 搜索时返回的最大数量
short_term_decay_factor = 0.98 # 衰减因子

# 长期记忆层配置
use_judge = true # 使用评判模型决定是否检索长期记忆
long_term_batch_size = 10 # 批量转移大小
long_term_decay_factor = 0.95 # 衰减因子
long_term_auto_transfer_interval = 180 # 自动转移间隔(秒)

@@ -425,7 +428,7 @@ auto_install = true #it can work now!
auto_install_timeout = 300
# 是否使用PyPI镜像源(推荐,可加速下载)
use_mirror = true
mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL,如: "https://pypi.tuna.tsinghua.edu.cn/simple"
mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL,如: "https://pypi.tuna.tsinghua.edu.cn/simple"
# 依赖安装日志级别
install_log_level = "INFO"
@@ -536,14 +539,6 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在
# 兴趣评分系统参数
reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
high_match_interest_threshold = 0.6 # 高匹配兴趣阈值
medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值
low_match_interest_threshold = 0.2 # 低匹配兴趣阈值
high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率
medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率
low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率
match_count_bonus = 0.01 # 匹配数关键词加成值
max_match_bonus = 0.1 # 最大匹配数加成值

# 回复决策系统参数
no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值

@@ -636,6 +631,13 @@ mode = "split"
max_wait_seconds_default = 300 # 默认的最大等待秒数(AI发送消息后愿意等待用户回复的时间)
enable_continuous_thinking = true # 是否在等待期间启用心理活动更新
# --- 自定义决策提示词 ---
# 类似于AFC的planner_custom_prompt_content,允许用户自定义KFC的决策行为指导
# 在unified模式下:会插入到完整提示词中,影响整体思考和回复生成
# 在split模式下:只会插入到planner提示词中,影响决策规划,不影响replyer的回复生成
# 留空则不生效
custom_decision_prompt = ""

# --- 等待策略 ---
[kokoro_flow_chatter.waiting]
default_max_wait_seconds = 300 # LLM 未给出等待时间时的默认值
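To show how the new `[kokoro_flow_chatter]` key might be filled in, a small sketch that parses an example block with the standard-library TOML reader (Python 3.11+); the prompt text itself is made up:

```python
import tomllib  # standard library in Python 3.11+

example = """
[kokoro_flow_chatter]
mode = "split"
max_wait_seconds_default = 300
custom_decision_prompt = "优先照顾对方情绪;拿不准时倾向选择 do_nothing。"
"""

cfg = tomllib.loads(example)["kokoro_flow_chatter"]
# A non-empty value means the planner prompt gains a "决策指导" section (see the PromptBuilder diff above).
print(cfg["custom_decision_prompt"])
```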