Compare commits

...

31 Commits

Author SHA1 Message Date
1aa09ee340 feat: add ffmpeg
All checks were successful
Build and Push Docker Image / build-and-push (push) Successful in 2m56s
2025-12-13 02:56:13 +08:00
25bd23ad3f chore: add local build configuration 2025-12-13 02:56:08 +08:00
minecraft1024a
179b5b7222 feat(log): add a log broadcast system that streams logs to multiple subscribers in real time 2025-12-12 21:56:25 +08:00
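A typical shape for such a broadcast system is a set of per-subscriber queues that each log line is fanned out to. The sketch below is illustrative only; the class and method names are assumptions, not the repository's actual API:

```python
import asyncio

class LogBroadcaster:
    """Fan out log lines to any number of subscribers (illustrative sketch only)."""

    def __init__(self) -> None:
        self._subscribers: set[asyncio.Queue[str]] = set()

    def subscribe(self) -> asyncio.Queue[str]:
        # Each subscriber gets its own bounded queue.
        q: asyncio.Queue[str] = asyncio.Queue(maxsize=1000)
        self._subscribers.add(q)
        return q

    def unsubscribe(self, q: asyncio.Queue[str]) -> None:
        self._subscribers.discard(q)

    def publish(self, line: str) -> None:
        # Never block the logging path: drop lines for slow consumers instead.
        for q in list(self._subscribers):
            try:
                q.put_nowait(line)
            except asyncio.QueueFull:
                pass
```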
minecraft1024a
f39b0eaa44 Revert "Refactor exception handling and remove the orjson dependency"
This reverts commit 70217d7df8.
2025-12-12 20:40:56 +08:00
LuiKlee
b55df150d4 Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev 2025-12-12 15:43:26 +08:00
LuiKlee
70217d7df8 Refactor exception handling and remove the orjson dependency 2025-12-12 15:41:48 +08:00
tt-P607
f1bfcd1cff Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev 2025-12-12 15:13:18 +08:00
tt-P607
5a1d5052ca feat(kfc): introduce a custom decision prompt and refine follow-up warnings
Adds a new configuration option, `custom_decision_prompt`, that lets users supply concrete instructions to guide Kokoro Flow Chatter's decision making. The prompt is merged into both the split and unified prompt generation, so it influences overall behavior.

In addition, the follow-up warning logic has been significantly strengthened. Warnings are now more granular and explicit, offering clearer recommendations such as choosing `do_nothing` or ending the conversation after repeated unanswered follow-ups, which encourages healthier interaction patterns. The internal version in the config template has been bumped.
2025-12-12 15:12:28 +08:00
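For illustration, splicing such an option into the decision prompt might look roughly like the minimal sketch below; the function name and prompt layout are assumptions based on the commit description, not the actual Kokoro Flow Chatter code:

```python
def build_decision_prompt(base_prompt: str, custom_decision_prompt: str = "") -> str:
    """Append a user-supplied decision instruction block, if one is configured."""
    custom = custom_decision_prompt.strip()
    if not custom:
        return base_prompt
    # Hypothetical layout: the custom instructions follow the generated prompt.
    return f"{base_prompt}\n\n## Additional decision instructions\n{custom}"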
Windpicker-owo
35502914a7 Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev 2025-12-12 15:09:29 +08:00
Windpicker-owo
7d547b7b80 feat: fix JSON parsing issues and raise the batch annotation size to 50 2025-12-12 15:09:00 +08:00
LuiKlee
700cf477fb Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev 2025-12-12 15:06:36 +08:00
LuiKlee
1f0b8fa04d Correct parameter names and type annotations
www
2025-12-12 15:06:33 +08:00
Windpicker-owo
1087d46ce2 chore: bump MMC_VERSION to 0.13.1-alpha.1 2025-12-12 15:02:16 +08:00
Windpicker-owo
da3752725e chore: bump versions to 0.13.1-alpha.2 and 8.0.0, adjust interest scoring thresholds 2025-12-12 14:59:44 +08:00
Windpicker-owo
e5e552df65 feat: update the auto trainer and dataset generator, add initial keyword generation 2025-12-12 14:56:11 +08:00
Windpicker-owo
0193913841 refactor: remove interest calculator code and configuration, streamline the system management plugin 2025-12-12 14:38:15 +08:00
Windpicker-owo
e6a4f855a2 feat: improve semantic interest scoring and typo generation
- Implemented background prewarming for the Chinese typo generator to improve first-use performance.
- Updated MessageStorageBatcher to support a configurable commit batch size and interval, improving database write performance.
- Enhanced the dataset generator with a hard cap on sample size and more efficient sampling.
- Increased the maximum sample count in AutoTrainer to 1000 to make better use of training data.
- Refactored the affinity interest calculator to avoid concurrent initialization and streamline model loading.
- Introduced a batching mechanism for semantic interest scoring to handle high-frequency chats (see the sketch after this commit).
- Updated the config template to reflect the new scoring parameters and removed the deprecated interest thresholds.
2025-12-12 14:11:36 +08:00
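As referenced in the batching bullet above, one common way to structure such a mechanism is to buffer incoming messages briefly and score them in a single call. This is a minimal illustrative sketch; the real BatchScoringQueue in optimized_scorer.py may differ:

```python
import asyncio
from typing import Callable

class BatchScoringQueue:
    """Collect messages briefly and score them together (illustrative sketch)."""

    def __init__(self, score_batch: Callable[[list[str]], list[float]],
                 max_batch: int = 32, max_wait: float = 0.05) -> None:
        self._score_batch = score_batch
        self._max_batch = max_batch
        self._max_wait = max_wait
        self._pending: list[tuple[str, asyncio.Future]] = []
        self._flush_handle: asyncio.TimerHandle | None = None

    async def score(self, text: str) -> float:
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        self._pending.append((text, fut))
        if len(self._pending) >= self._max_batch:
            self._flush()                      # full batch: score immediately
        elif self._flush_handle is None:
            self._flush_handle = loop.call_later(self._max_wait, self._flush)
        return await fut

    def _flush(self) -> None:
        if self._flush_handle is not None:
            self._flush_handle.cancel()
            self._flush_handle = None
        batch, self._pending = self._pending, []
        if not batch:
            return
        # Score all buffered messages in one call and resolve their futures.
        scores = self._score_batch([text for text, _ in batch])
        for (_, fut), s in zip(batch, scores):
            fut.set_result(s)
```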
Windpicker-owo
9d01b81cef feat: enhance the affinity interest calculator with FastScorer and batching
- Integrated FastScorer for optimized scoring, bypassing sklearn for better performance.
- Added batch processing to handle high-frequency chat scenarios.
- Implemented a global thread pool to avoid repeatedly creating executors.
- Reduced the scoring timeout to 2 seconds.
- Refactored ChatterActionPlanner to use the new interest calculator.
- Added a benchmark script comparing the original sklearn path against FastScorer.
- Built an optimized scorer with weight pruning and asynchronous scoring.
2025-12-12 12:14:21 +08:00
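The core trick, as described in the optimized_scorer docstring later in this diff, is to fold the TF-IDF idf values into the logistic-regression weights so that online scoring reduces to an n-gram count plus a dictionary lookup. A minimal sketch of that inference path, with illustrative parameter names:

```python
import math
from collections import Counter

def fast_score(text: str, token_weights: dict[str, float], bias: float = 0.0,
               ngram_range: tuple[int, int] = (2, 4), alpha: float = 1.0) -> float:
    """Score one message with a pre-fused token->weight table (no sklearn at inference)."""
    lo, hi = ngram_range
    # 1. character n-gram tokenization
    tokens = [text[i:i + n] for n in range(lo, hi + 1) for i in range(len(text) - n + 1)]
    # 2. term frequencies
    tf = Counter(tokens)
    # 3. dot product against fused weights w'_i = w_i * idf_i, plus bias
    z = sum(token_weights.get(tok, 0.0) * cnt for tok, cnt in tf.items()) + bias
    # 4. squash to [0, 1]
    return 1.0 / (1.0 + math.exp(-alpha * z))
```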
Windpicker-owo
ef0c569348 fix(query_builder): refine paginated query logic so fields are materialized before the database connection is released 2025-12-11 21:50:28 +08:00
Windpicker-owo
e8bffe4a87 feat: implement a TF-IDF feature extractor and logistic regression model for semantic interest scoring
- Added TfidfFeatureExtractor for character-level n-gram TF-IDF vectorization, suitable for Chinese and multilingual text.
- Built a logistic-regression-based semantic interest model that predicts multi-class interest labels (-1, 0, 1).
- Created a runtime scorer for online inference, enabling fast evaluation of message interest.
- Established an end-to-end training pipeline covering model training, evaluation, and dataset preparation.
- Integrated model management with hot reloading and persona-specific model selection.
2025-12-11 21:28:27 +08:00
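Under the hood this is a standard sklearn pipeline: a character n-gram TfidfVectorizer feeding a multi-class LogisticRegression. A self-contained toy sketch with made-up data, not taken from the repository:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

texts = ["今晚一起打游戏吗", "转发抽奖链接", "这个模型的损失函数怎么调"]
labels = [1, -1, 1]  # -1 = not interested, 0 = neutral, 1 = interested

vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(2, 4), sublinear_tf=True)
X = vectorizer.fit_transform(texts)
clf = LogisticRegression(class_weight="balanced", max_iter=1000).fit(X, labels)

print(clf.predict(vectorizer.transform(["周末有人打排位吗"])))  # e.g. [1]
```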
拾风
59e7a1a846 Merge pull request #28 from Gardelll/dev
Fix several LLM response parsing issues and add the memory.use_judge config option
2025-12-11 15:46:22 +08:00
Windpicker-owo
633585e6af Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev 2025-12-11 13:57:34 +08:00
Windpicker-owo
c75cc88fb5 feat(expression_selector): add temperature sampling to improve expression selection
feat(official_configs): add a model temperature config option to support sampling for the expression model
chore(bot_config_template): bump the version and document the model temperature option
2025-12-11 13:57:17 +08:00
拾风
2d02bf4631 Merge pull request #27 from Gardelll/fix-memory-extract-prompt
Fix memory extraction issues
2025-12-10 22:07:10 +08:00
雅诺狐
4592e37c10 fix(config): fix type validation during config loading to avoid errors under Pydantic strict mode 2025-12-10 15:11:25 +08:00
雅诺狐
c870af768d fix(redis): update Redis connection pool initialization for compatibility with redis-py 7.x
Creates the connection pool with the connection_class parameter instead of the deprecated ssl parameter, for redis-py 7.x and later.
2025-12-10 15:06:01 +08:00
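In redis-py this typically means constructing the pool with an SSL-capable connection class rather than the deprecated ssl argument. A minimal sketch; host and port are placeholders and the adapter's real code may differ:

```python
import redis

pool = redis.ConnectionPool(
    host="redis.example.com",              # placeholder
    port=6380,
    connection_class=redis.SSLConnection,  # replaces the deprecated ssl=... argument
)
client = redis.Redis(connection_pool=pool)
```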
7735b161c8 feat: add an option to force long-term memory retrieval 2025-12-10 12:52:41 +08:00
016c8647f7 fix: fix reply splitting 2025-12-10 12:52:41 +08:00
f269034b6a fix: fix VLM parsing 2025-12-10 12:52:35 +08:00
eac1ef2869 fix: fix handling of quoted content during memory analysis 2025-12-10 00:01:59 +08:00
8f3338f845 fix: add trailing comma handling in memory extraction 2025-12-10 00:00:53 +08:00
50 changed files with 4303 additions and 2126 deletions

View File

@@ -0,0 +1,32 @@
name: Build and Push Docker Image
on:
push:
branches:
- dev
- gitea
jobs:
build-and-push:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Registry
uses: docker/login-action@v3
with:
registry: docker.gardel.top
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: docker.gardel.top/gardel/mofox:dev
build-args: |
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
VCS_REF=${{ github.sha }}

View File

@@ -34,7 +34,6 @@ MoFox_Bot 是基于 MaiCore 的增强型 QQ 聊天机器人,集成了 LLM、
- `PLUS_COMMAND`: 增强命令(支持参数解析、权限检查)
- `TOOL`: LLM 工具调用(函数调用集成)
- `EVENT_HANDLER`: 事件订阅处理器
- `INTEREST_CALCULATOR`: 兴趣值计算器
- `PROMPT`: 自定义提示词注入
**插件开发流程**:

View File

@@ -1,149 +0,0 @@
name: Docker Build and Push
on:
push:
branches:
- master
- dev
tags:
- "v*.*.*"
- "v*"
- "*.*.*"
- "*.*.*-*"
workflow_dispatch: # 允许手动触发工作流
# Workflow's jobs
jobs:
build-amd64:
name: Build AMD64 Image
runs-on: ubuntu-24.04
outputs:
digest: ${{ steps.build.outputs.digest }}
steps:
- name: Check out git repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
buildkitd-flags: --debug
# Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Generate metadata for Docker images
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
# Build and push AMD64 image by digest
- name: Build and push AMD64
id: build
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
file: ./Dockerfile
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:amd64-buildcache,mode=max
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
build-args: |
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
VCS_REF=${{ github.sha }}
build-arm64:
name: Build ARM64 Image
runs-on: ubuntu-24.04-arm
outputs:
digest: ${{ steps.build.outputs.digest }}
steps:
- name: Check out git repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
buildkitd-flags: --debug
# Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Generate metadata for Docker images
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
# Build and push ARM64 image by digest
- name: Build and push ARM64
id: build
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/arm64/v8
labels: ${{ steps.meta.outputs.labels }}
file: ./Dockerfile
cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache
cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/mofox:arm64-buildcache,mode=max
outputs: type=image,name=${{ secrets.DOCKERHUB_USERNAME }}/mofox,push-by-digest=true,name-canonical=true,push=true
build-args: |
BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')
VCS_REF=${{ github.sha }}
create-manifest:
name: Create Multi-Arch Manifest
runs-on: ubuntu-24.04
needs:
- build-amd64
- build-arm64
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# Log in docker hub
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# Generate metadata for Docker images
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/mofox
tags: |
type=ref,event=branch
type=ref,event=tag
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
- name: Create and Push Manifest
run: |
# 为每个标签创建多架构镜像
for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr '\n' ' '); do
echo "Creating manifest for $tag"
docker buildx imagetools create -t $tag \
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-amd64.outputs.digest }} \
${{ secrets.DOCKERHUB_USERNAME }}/mofox@${{ needs.build-arm64.outputs.digest }}
done

1
.gitignore vendored
View File

@@ -342,3 +342,4 @@ package.json
/backup
mofox_bot_statistics.html
src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
depends-data/pinyin_dict.json

View File

@@ -9,6 +9,10 @@ RUN apt-get update && apt-get install -y build-essential
# 复制依赖列表和锁文件
COPY pyproject.toml uv.lock ./
COPY --from=mwader/static-ffmpeg:latest /ffmpeg /usr/local/bin/ffmpeg
COPY --from=mwader/static-ffmpeg:latest /ffprobe /usr/local/bin/ffprobe
RUN ldconfig && ffmpeg -version
# 安装依赖(使用 --frozen 确保使用锁文件中的版本)
RUN uv sync --frozen --no-dev

21
bot.py
View File

@@ -35,7 +35,6 @@ class StartupStageReporter:
else:
self._logger.info(title)
startup_stage = StartupStageReporter(logger)
# 常量定义
@@ -567,6 +566,7 @@ class MaiBotMain:
def __init__(self):
self.main_system = None
self._typo_prewarm_task = None
def setup_timezone(self):
"""设置时区"""
@@ -663,6 +663,25 @@ class MaiBotMain:
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 后台预热中文错别字生成器,避免首次使用阻塞主流程
try:
from src.chat.utils.typo_generator import get_typo_generator
typo_cfg = getattr(global_config, "chinese_typo", None)
self._typo_prewarm_task = asyncio.create_task(
asyncio.to_thread(
get_typo_generator,
error_rate=getattr(typo_cfg, "error_rate", 0.3),
min_freq=getattr(typo_cfg, "min_freq", 5),
tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2),
word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3),
max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200),
)
)
logger.debug("已启动 ChineseTypoGenerator 后台预热任务")
except Exception as e:
logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}")
# 初始化数据库表结构
await self.initialize_database_async()

View File

@@ -4,6 +4,7 @@ import binascii
import hashlib
import io
import json
import json_repair
import os
import random
import re
@@ -1022,6 +1023,15 @@ class EmojiManager:
- 必须是表情包,非普通截图。
- 图中文字不超过5个。
请确保你的最终输出是严格的JSON对象不要添加任何额外解释或文本。
输出格式:
```json
{{
"detailed_description": "",
"keywords": [],
"refined_sentence": "",
"is_compliant": true
}}
```
"""
image_data_for_vlm, image_format_for_vlm = image_base64, image_format
@@ -1041,16 +1051,14 @@ class EmojiManager:
if not vlm_response_str:
continue
match = re.search(r"\{.*\}", vlm_response_str, re.DOTALL)
if match:
vlm_response_json = json.loads(match.group(0))
description = vlm_response_json.get("detailed_description", "")
emotions = vlm_response_json.get("keywords", [])
refined_description = vlm_response_json.get("refined_sentence", "")
is_compliant = vlm_response_json.get("is_compliant", False)
if description and emotions and refined_description:
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
break
vlm_response_json = self._parse_json_response(vlm_response_str)
description = vlm_response_json.get("detailed_description", "")
emotions = vlm_response_json.get("keywords", [])
refined_description = vlm_response_json.get("refined_sentence", "")
is_compliant = vlm_response_json.get("is_compliant", False)
if description and emotions and refined_description:
logger.info("[VLM分析] 成功解析VLM返回的JSON数据。")
break
logger.warning("[VLM分析] VLM返回的JSON数据不完整或格式错误准备重试。")
except (json.JSONDecodeError, AttributeError) as e:
logger.error(f"VLM JSON解析失败 (第 {i+1}/3 次): {e}")
@@ -1195,6 +1203,29 @@ class EmojiManager:
logger.error(f"[错误] 删除异常处理文件时出错: {remove_error}")
return False
@classmethod
def _parse_json_response(cls, response: str) -> dict[str, Any] | None:
"""解析 LLM 的 JSON 响应"""
try:
# 尝试提取 JSON 代码块
json_match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 尝试直接解析
json_str = response.strip()
# 移除可能的注释
json_str = re.sub(r"//.*", "", json_str)
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
data = json_repair.loads(json_str)
return data
except json.JSONDecodeError as e:
logger.warning(f"JSON 解析失败: {e}, 响应: {response[:200]}")
return None
emoji_manager = None

View File

@@ -1,5 +1,6 @@
import asyncio
import hashlib
import math
import random
import time
from typing import Any
@@ -76,6 +77,45 @@ def weighted_sample(population: list[dict], weights: list[float], k: int) -> lis
class ExpressionSelector:
@staticmethod
def _sample_with_temperature(
candidates: list[tuple[Any, float, float, str]],
max_num: int,
temperature: float,
) -> list[tuple[Any, float, float, str]]:
"""
对候选表达按温度采样,温度越高越均匀。
Args:
candidates: (expr, similarity, count, best_predicted) 列表
max_num: 需要返回的数量
temperature: 温度参数0 表示贪婪选择
"""
if max_num <= 0 or not candidates:
return []
if temperature <= 0:
return candidates[:max_num]
adjusted_temp = max(temperature, 1e-6)
# 使用与排序相同的打分,但通过 softmax/temperature 放大尾部概率
scores = [max(c[1] * (c[2] ** 0.5), 1e-8) for c in candidates]
max_score = max(scores)
weights = [math.exp((s - max_score) / adjusted_temp) for s in scores]
# 始终保留最高分一个,剩余的按温度采样,避免过度集中
best_idx = scores.index(max_score)
selected = [candidates[best_idx]]
remaining_indices = [i for i in range(len(candidates)) if i != best_idx]
while remaining_indices and len(selected) < max_num:
current_weights = [weights[i] for i in remaining_indices]
picked_pos = random.choices(range(len(remaining_indices)), weights=current_weights, k=1)[0]
picked_idx = remaining_indices.pop(picked_pos)
selected.append(candidates[picked_idx])
return selected
def __init__(self, chat_id: str = ""):
self.chat_id = chat_id
if model_config is None:
@@ -517,12 +557,21 @@ class ExpressionSelector:
)
return []
# 按照相似度*count排序选择最佳匹配
# 按照相似度*count排序并根据温度采样,避免过度集中
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
temperature = getattr(global_config.expression, "model_temperature", 0.0)
sampled_matches = self._sample_with_temperature(
candidates=matched_expressions,
max_num=max_num,
temperature=temperature,
)
expressions_objs = [e[0] for e in sampled_matches]
# 显示最佳匹配的详细信息
logger.debug(f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式")
logger.debug(
f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式 "
f"(候选 {len(matched_expressions)}temperature={temperature})"
)
# 转换为字典格式
expressions = [

View File

@@ -1,21 +1,15 @@
"""
兴趣度系统模块
提供机器人兴趣标签和智能匹配功能,以及消息兴趣计算功能
目前仅保留兴趣计算器管理入口
"""
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from src.common.data_models.bot_interest_data_model import InterestMatchResult
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from .interest_manager import InterestManager, get_interest_manager
__all__ = [
# 机器人兴趣标签管理
"BotInterestManager",
"BotInterestTag",
"BotPersonalityInterests",
# 消息兴趣值计算管理
"InterestManager",
"InterestMatchResult",
"bot_interest_manager",
"get_interest_manager",
]

File diff suppressed because it is too large

View File

@@ -3,10 +3,10 @@ import re
import time
import traceback
from collections import deque
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Any, cast
import orjson
from sqlalchemy import desc, select, update
from sqlalchemy import desc, insert, select, update
from sqlalchemy.engine import CursorResult
from src.common.data_models.database_data_model import DatabaseMessages
@@ -25,29 +25,55 @@ class MessageStorageBatcher:
消息存储批处理器
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。
"""
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
def __init__(
self,
batch_size: int = 50,
flush_interval: float = 5.0,
*,
commit_batch_size: int | None = None,
commit_interval: float | None = None,
db_chunk_size: int = 200,
):
"""
初始化批处理器
Args:
batch_size: 批量大小,达到此数量立即写入
flush_interval: 自动刷新间隔(秒)
batch_size: 写入队列中触发准备阶段的消息条数
flush_interval: 自动刷新/检查间隔(秒)
commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size至少100
commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s)
db_chunk_size: 单次SQL语句批量写入数量上限
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100)
self.commit_interval = commit_interval or max(flush_interval * 2, 10.0)
self.db_chunk_size = max(50, db_chunk_size)
self.pending_messages: deque = deque()
self._prepared_buffer: list[dict[str, Any]] = []
self._lock = asyncio.Lock()
self._flush_barrier = asyncio.Lock()
self._flush_task = None
self._running = False
self._last_commit_ts = time.monotonic()
async def start(self):
"""启动自动刷新任务"""
if self._flush_task is None and not self._running:
self._running = True
self._last_commit_ts = time.monotonic()
self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
logger.info(
"消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)",
self.batch_size,
self.flush_interval,
self.commit_batch_size,
self.commit_interval,
)
async def stop(self):
"""停止批处理器"""
@@ -62,7 +88,7 @@ class MessageStorageBatcher:
self._flush_task = None
# 刷新剩余的消息
await self.flush()
await self.flush(force=True)
logger.info("消息存储批处理器已停止")
async def add_message(self, message_data: dict):
@@ -76,61 +102,82 @@ class MessageStorageBatcher:
'chat_stream': ChatStream
}
"""
should_force_flush = False
async with self._lock:
self.pending_messages.append(message_data)
# 如果达到批量大小,立即刷新
if len(self.pending_messages) >= self.batch_size:
logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
await self.flush()
should_force_flush = True
async def flush(self):
"""执行批量写入"""
async with self._lock:
if not self.pending_messages:
return
if should_force_flush:
logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新")
await self.flush(force=True)
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
async def flush(self, force: bool = False):
"""执行批量写入, 支持强制落库和延迟提交策略。"""
async with self._flush_barrier:
async with self._lock:
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
if not messages_to_store:
if messages_to_store:
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
await self._maybe_commit_buffer(force=force)
async def _maybe_commit_buffer(self, *, force: bool = False) -> None:
"""根据阈值/时间窗口判断是否需要真正写库。"""
if not self._prepared_buffer:
return
now = time.monotonic()
enough_rows = len(self._prepared_buffer) >= self.commit_batch_size
waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval
if not (force or enough_rows or waited_long_enough):
return
await self._write_buffer_to_database()
async def _write_buffer_to_database(self) -> None:
payload = self._prepared_buffer
if not payload:
return
self._prepared_buffer = []
start_time = time.time()
success_count = 0
total = len(payload)
try:
# 🔧 优化准备字典数据而不是ORM对象使用批量INSERT
messages_dicts = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"]
)
if message_dict:
messages_dicts.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
continue
# 批量写入数据库 - 使用高效的批量INSERT
if messages_dicts:
from sqlalchemy import insert
async with get_db_session() as session:
stmt = insert(Messages).values(messages_dicts)
await session.execute(stmt)
await session.commit()
success_count = len(messages_dicts)
async with get_db_session() as session:
for start in range(0, total, self.db_chunk_size):
chunk = payload[start : start + self.db_chunk_size]
if chunk:
await session.execute(insert(Messages), chunk)
await session.commit()
elapsed = time.time() - start_time
self._last_commit_ts = time.monotonic()
per_item = (elapsed / total) * 1000 if total else 0
logger.info(
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})"
)
except Exception as e:
# 回滚到缓冲区, 等待下一次尝试
self._prepared_buffer = payload + self._prepared_buffer
logger.error(f"批量存储消息失败: {e}")
async def _prepare_message_dict(self, message, chat_stream):

View File

@@ -614,7 +614,7 @@ class DefaultReplyer:
# 使用统一管理器的智能检索Judge模型决策
search_result = await unified_manager.search_memories(
query_text=query_text,
use_judge=True,
use_judge=global_config.memory.use_judge,
recent_chat_history=chat_history, # 传递最近聊天历史
)
@@ -1799,8 +1799,9 @@ class DefaultReplyer:
)
if content:
# 移除 [SPLIT] 标记,防止消息被分割
content = content.replace("[SPLIT]", "")
if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
# 移除 [SPLIT] 标记,防止消息被分割
content = content.replace("[SPLIT]", "")
# 应用统一的格式过滤器
from src.chat.utils.utils import filter_system_format_content

View File

@@ -0,0 +1,67 @@
"""语义兴趣度计算模块
基于 TF-IDF + Logistic Regression 的语义兴趣度计算系统
支持人设感知的自动训练和模型切换
2024.12 优化更新:
- 新增 FastScorer绕过 sklearn使用 token→weight 字典直接计算
- 全局线程池:避免重复创建 ThreadPoolExecutor
- 批处理队列:攒消息一起算,提高 CPU 利用率
- TF-IDF 降维max_features 10000, ngram_range (2,3)
- 权重剪枝:只保留高贡献 token
"""
from .auto_trainer import AutoTrainer, get_auto_trainer
from .dataset import DatasetGenerator, generate_training_dataset
from .features_tfidf import TfidfFeatureExtractor
from .model_lr import SemanticInterestModel, train_semantic_model
from .optimized_scorer import (
BatchScoringQueue,
FastScorer,
FastScorerConfig,
clear_fast_scorer_instances,
convert_sklearn_to_fast,
get_fast_scorer,
get_global_executor,
shutdown_global_executor,
)
from .runtime_scorer import (
ModelManager,
SemanticInterestScorer,
clear_scorer_instances,
get_all_scorer_instances,
get_semantic_scorer,
get_semantic_scorer_sync,
)
from .trainer import SemanticInterestTrainer
__all__ = [
# 运行时评分
"SemanticInterestScorer",
"ModelManager",
"get_semantic_scorer", # 单例获取(异步)
"get_semantic_scorer_sync", # 单例获取(同步)
"clear_scorer_instances", # 清空单例
"get_all_scorer_instances", # 查看所有实例
# 优化评分器(推荐用于高频场景)
"FastScorer",
"FastScorerConfig",
"BatchScoringQueue",
"get_fast_scorer",
"convert_sklearn_to_fast",
"clear_fast_scorer_instances",
"get_global_executor",
"shutdown_global_executor",
# 训练组件
"TfidfFeatureExtractor",
"SemanticInterestModel",
"train_semantic_model",
# 数据集生成
"DatasetGenerator",
"generate_training_dataset",
# 训练器
"SemanticInterestTrainer",
# 自动训练
"AutoTrainer",
"get_auto_trainer",
]

View File

@@ -0,0 +1,375 @@
"""自动训练调度器
监控人设变化,自动触发模型训练和切换
"""
import asyncio
import hashlib
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.semantic_interest.trainer import SemanticInterestTrainer
logger = get_logger("semantic_interest.auto_trainer")
class AutoTrainer:
"""自动训练调度器
功能:
1. 监控人设变化
2. 自动构建训练数据集
3. 定期重新训练模型
4. 管理多个人设的模型
"""
def __init__(
self,
data_dir: Path | None = None,
model_dir: Path | None = None,
min_train_interval_hours: int = 720, # 最小训练间隔小时30天
min_samples_for_training: int = 100, # 最小训练样本数
):
"""初始化自动训练器
Args:
data_dir: 数据集目录
model_dir: 模型目录
min_train_interval_hours: 最小训练间隔(小时)
min_samples_for_training: 触发训练的最小样本数
"""
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
self.model_dir = Path(model_dir or "data/semantic_interest/models")
self.min_train_interval = timedelta(hours=min_train_interval_hours)
self.min_samples = min_samples_for_training
# 人设状态缓存
self.persona_cache_file = self.data_dir / "persona_cache.json"
self.last_persona_hash: str | None = None
self.last_train_time: datetime | None = None
# 训练器实例
self.trainer = SemanticInterestTrainer(
data_dir=self.data_dir,
model_dir=self.model_dir,
)
# 确保目录存在
self.data_dir.mkdir(parents=True, exist_ok=True)
self.model_dir.mkdir(parents=True, exist_ok=True)
# 加载缓存的人设状态
self._load_persona_cache()
# 定时任务标志(防止重复启动)
self._scheduled_task_running = False
self._scheduled_task = None
logger.info("[自动训练器] 初始化完成")
logger.info(f" - 数据目录: {self.data_dir}")
logger.info(f" - 模型目录: {self.model_dir}")
logger.info(f" - 最小训练间隔: {min_train_interval_hours}小时")
def _load_persona_cache(self):
"""加载缓存的人设状态"""
if self.persona_cache_file.exists():
try:
with open(self.persona_cache_file, "r", encoding="utf-8") as f:
cache = json.load(f)
self.last_persona_hash = cache.get("persona_hash")
last_train_str = cache.get("last_train_time")
if last_train_str:
self.last_train_time = datetime.fromisoformat(last_train_str)
logger.info(f"[自动训练器] 加载人设缓存: hash={self.last_persona_hash[:8] if self.last_persona_hash else 'None'}")
except Exception as e:
logger.warning(f"[自动训练器] 加载人设缓存失败: {e}")
def _save_persona_cache(self, persona_hash: str):
"""保存人设状态到缓存"""
cache = {
"persona_hash": persona_hash,
"last_train_time": datetime.now().isoformat(),
}
try:
with open(self.persona_cache_file, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
logger.debug(f"[自动训练器] 保存人设缓存: hash={persona_hash[:8]}")
except Exception as e:
logger.error(f"[自动训练器] 保存人设缓存失败: {e}")
def _calculate_persona_hash(self, persona_info: dict[str, Any]) -> str:
"""计算人设信息的哈希值
Args:
persona_info: 人设信息字典
Returns:
SHA256 哈希值
"""
# 只关注影响模型的关键字段
key_fields = {
"name": persona_info.get("name", ""),
"interests": sorted(persona_info.get("interests", [])),
"dislikes": sorted(persona_info.get("dislikes", [])),
"personality": persona_info.get("personality", ""),
# 可选的更完整人设字段(存在则纳入哈希)
"personality_core": persona_info.get("personality_core", ""),
"personality_side": persona_info.get("personality_side", ""),
"identity": persona_info.get("identity", ""),
}
# 转为JSON并计算哈希
json_str = json.dumps(key_fields, sort_keys=True, ensure_ascii=False)
return hashlib.sha256(json_str.encode()).hexdigest()
def check_persona_changed(self, persona_info: dict[str, Any]) -> bool:
"""检查人设是否发生变化
Args:
persona_info: 当前人设信息
Returns:
True 如果人设发生变化
"""
current_hash = self._calculate_persona_hash(persona_info)
if self.last_persona_hash is None:
logger.info("[自动训练器] 首次检测人设")
return True
if current_hash != self.last_persona_hash:
logger.info(f"[自动训练器] 检测到人设变化")
logger.info(f" - 旧哈希: {self.last_persona_hash[:8]}")
logger.info(f" - 新哈希: {current_hash[:8]}")
return True
return False
def should_train(self, persona_info: dict[str, Any], force: bool = False) -> tuple[bool, str]:
"""判断是否应该训练模型
Args:
persona_info: 人设信息
force: 强制训练
Returns:
(是否应该训练, 原因说明)
"""
# 强制训练
if force:
return True, "强制训练"
# 检查人设是否变化
persona_changed = self.check_persona_changed(persona_info)
if persona_changed:
return True, "人设发生变化"
# 检查训练间隔
if self.last_train_time is None:
return True, "从未训练过"
time_since_last_train = datetime.now() - self.last_train_time
if time_since_last_train >= self.min_train_interval:
return True, f"距上次训练已{time_since_last_train.total_seconds() / 3600:.1f}小时"
return False, "无需训练"
async def auto_train_if_needed(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
force: bool = False,
) -> tuple[bool, Path | None]:
"""自动训练(如果需要)
Args:
persona_info: 人设信息
days: 采样天数
max_samples: 最大采样数默认1000条
force: 强制训练
Returns:
(是否训练了, 模型路径)
"""
# 检查是否需要训练
should_train, reason = self.should_train(persona_info, force)
if not should_train:
logger.debug(f"[自动训练器] {reason},跳过训练")
return False, None
logger.info(f"[自动训练器] 开始自动训练: {reason}")
try:
# 计算人设哈希作为版本标识
persona_hash = self._calculate_persona_hash(persona_info)
model_version = f"auto_{persona_hash[:8]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# 执行训练
dataset_path, model_path, metrics = await self.trainer.full_training_pipeline(
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_version=model_version,
tfidf_config={
"analyzer": "char",
"ngram_range": (2, 4),
"max_features": 10000,
"min_df": 3,
},
model_config={
"class_weight": "balanced",
"max_iter": 1000,
},
)
# 更新缓存
self.last_persona_hash = persona_hash
self.last_train_time = datetime.now()
self._save_persona_cache(persona_hash)
# 创建"latest"符号链接
self._create_latest_link(model_path)
logger.info(f"[自动训练器] 训练完成!")
logger.info(f" - 模型: {model_path.name}")
logger.info(f" - 准确率: {metrics.get('test_accuracy', 0):.4f}")
return True, model_path
except Exception as e:
logger.error(f"[自动训练器] 训练失败: {e}")
import traceback
traceback.print_exc()
return False, None
def _create_latest_link(self, model_path: Path):
"""创建指向最新模型的符号链接
Args:
model_path: 模型文件路径
"""
latest_path = self.model_dir / "semantic_interest_latest.pkl"
try:
# 删除旧链接
if latest_path.exists() or latest_path.is_symlink():
latest_path.unlink()
# 创建新链接Windows 需要管理员权限,使用复制代替)
import shutil
shutil.copy2(model_path, latest_path)
logger.info(f"[自动训练器] 已更新 latest 模型")
except Exception as e:
logger.warning(f"[自动训练器] 创建 latest 链接失败: {e}")
async def scheduled_train(
self,
persona_info: dict[str, Any],
interval_hours: int = 24,
):
"""定时训练任务
Args:
persona_info: 人设信息
interval_hours: 检查间隔(小时)
"""
# 检查是否已经有任务在运行
if self._scheduled_task_running:
logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
return
self._scheduled_task_running = True
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")
while True:
try:
# 检查并训练
trained, model_path = await self.auto_train_if_needed(persona_info)
if trained:
logger.info(f"[自动训练器] 定时训练完成: {model_path}")
# 等待下次检查
await asyncio.sleep(interval_hours * 3600)
except Exception as e:
logger.error(f"[自动训练器] 定时训练出错: {e}")
# 出错后等待较短时间再试
await asyncio.sleep(300) # 5分钟
def get_model_for_persona(self, persona_info: dict[str, Any]) -> Path | None:
"""获取当前人设对应的模型
Args:
persona_info: 人设信息
Returns:
模型文件路径,如果不存在则返回 None
"""
persona_hash = self._calculate_persona_hash(persona_info)
# 查找匹配的模型
pattern = f"semantic_interest_auto_{persona_hash[:8]}_*.pkl"
matching_models = list(self.model_dir.glob(pattern))
if matching_models:
# 返回最新的
latest = max(matching_models, key=lambda p: p.stat().st_mtime)
logger.debug(f"[自动训练器] 找到人设模型: {latest.name}")
return latest
# 没有找到,返回 latest
latest_path = self.model_dir / "semantic_interest_latest.pkl"
if latest_path.exists():
logger.debug(f"[自动训练器] 使用 latest 模型")
return latest_path
logger.warning(f"[自动训练器] 未找到可用模型")
return None
def cleanup_old_models(self, keep_count: int = 5):
"""清理旧模型文件
Args:
keep_count: 保留最新的 N 个模型
"""
try:
# 获取所有自动训练的模型
all_models = list(self.model_dir.glob("semantic_interest_auto_*.pkl"))
if len(all_models) <= keep_count:
return
# 按修改时间排序
all_models.sort(key=lambda p: p.stat().st_mtime, reverse=True)
# 删除旧模型
for old_model in all_models[keep_count:]:
old_model.unlink()
logger.info(f"[自动训练器] 清理旧模型: {old_model.name}")
logger.info(f"[自动训练器] 模型清理完成,保留 {keep_count}")
except Exception as e:
logger.error(f"[自动训练器] 清理模型失败: {e}")
# 全局单例
_auto_trainer: AutoTrainer | None = None
def get_auto_trainer() -> AutoTrainer:
"""获取自动训练器单例"""
global _auto_trainer
if _auto_trainer is None:
_auto_trainer = AutoTrainer()
return _auto_trainer

View File

@@ -0,0 +1,818 @@
"""数据集生成与 LLM 标注
从数据库采样消息并使用 LLM 进行兴趣度标注
"""
import asyncio
import json
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("semantic_interest.dataset")
class DatasetGenerator:
"""训练数据集生成器
从历史消息中采样并使用 LLM 进行标注
"""
# 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
HARD_MAX_SAMPLES = 2000
# 标注提示词模板(单条)
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 消息内容
{message_text}
## 标注规则
请判断角色对这条消息的兴趣程度,返回以下之一:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
只需返回数字 -1、0 或 1不要其他内容。"""
# 批量标注提示词模板
BATCH_ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断每条消息是否会引起角色的兴趣。
## 人格信息
{persona_info}
## 标注规则
对每条消息判断角色的兴趣程度:
- **-1**: 完全不感兴趣或排斥(话题不相关、违背价值观、无聊重复等)
- **0**: 中立(可以回应但不特别感兴趣)
- **1**: 感兴趣(话题相关、符合兴趣点、能产生深度对话)
## 消息列表
{messages_list}
## 输出格式
请严格按照以下JSON格式返回每条消息一个标签
```json
{example_output}
```
只返回JSON不要其他内容。"""
# 关键词生成提示词模板
KEYWORD_GENERATION_PROMPT = """你是一个帮助生成训练数据的专家。请根据人格设定生成感兴趣和不感兴趣的关键词/短语列表。
## 人格信息
{persona_info}
## 任务说明
请分别生成该角色**感兴趣**和**不感兴趣**的关键词或短语:
1. **感兴趣的关键词**包括但不限于该角色喜欢的话题、活动、领域、价值观相关词汇等约30-50个
2. **不感兴趣的关键词**包括该角色不关心、反感、无聊的话题、价值观冲突的内容等约30-50个
## 输出格式
请严格按照以下JSON格式返回
```json
{{
"interested": ["关键词1", "关键词2", "关键词3", ...],
"not_interested": ["关键词1", "关键词2", "关键词3", ...]
}}
```
注意:
- 关键词可以是单个词语或短语2-10个字
- 尽量覆盖多样化的话题和场景
- 确保关键词与人格设定高度相关
只返回JSON不要其他内容。"""
def __init__(
self,
model_name: str | None = None,
max_samples_per_batch: int = 50,
):
"""初始化数据集生成器
Args:
model_name: LLM 模型名称None 则使用默认模型)
max_samples_per_batch: 每批次最大采样数
"""
self.model_name = model_name
self.max_samples_per_batch = max_samples_per_batch
self.model_client = None
async def initialize(self):
"""初始化 LLM 客户端"""
try:
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
# 使用 utilities 模型配置(标注更偏工具型)
if hasattr(model_config.model_task_config, 'utils'):
self.model_client = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="semantic_annotation"
)
logger.info(f"数据集生成器初始化完成,使用 utils 模型")
else:
logger.error("未找到 utils 模型配置")
self.model_client = None
except ImportError as e:
logger.warning(f"无法导入 LLM 模块: {e},标注功能将不可用")
self.model_client = None
except Exception as e:
logger.error(f"LLM 客户端初始化失败: {e}")
self.model_client = None
async def sample_messages(
self,
days: int = 7,
min_length: int = 5,
max_samples: int = 1000,
priority_ranges: list[tuple[float, float]] | None = None,
) -> list[dict[str, Any]]:
"""从数据库采样消息(优化版:减少查询量和内存使用)
Args:
days: 采样最近 N 天的消息
min_length: 最小消息长度
max_samples: 最大采样数量
priority_ranges: 优先采样的兴趣分范围列表,如 [(0.4, 0.6)]
Returns:
消息样本列表
"""
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import Messages
from sqlalchemy import func, or_
logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")
# 限制采样数量硬上限
requested_max_samples = max_samples
if max_samples is None:
max_samples = self.HARD_MAX_SAMPLES
else:
max_samples = int(max_samples)
if max_samples <= 0:
logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
return []
if max_samples > self.HARD_MAX_SAMPLES:
logger.warning(
f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES}"
f"已截断为 {self.HARD_MAX_SAMPLES}"
)
max_samples = self.HARD_MAX_SAMPLES
# 查询条件
cutoff_time = datetime.now() - timedelta(days=days)
cutoff_ts = cutoff_time.timestamp()
# 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
# 这样可以在保证足够样本的同时减少查询量
prefetch_limit = int(max_samples * 1.5)
# 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
query_builder = QueryBuilder(Messages)
# 过滤条件:时间范围 + 消息文本不为空
messages = await query_builder.filter(
time__gte=cutoff_ts,
).order_by(
"-time" # 按时间倒序,优先采样最新消息
).limit(
prefetch_limit # 限制预取数量
).all(as_dict=True)
logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit}")
# 过滤消息长度和提取文本
filtered = []
for msg in messages:
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
text = text.strip()
if text and len(text) >= min_length:
filtered.append({**msg, "message_text": text})
# 达到目标数量即可停止
if len(filtered) >= max_samples:
break
logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples}")
# 如果过滤后数量不足,记录警告
if len(filtered) < max_samples:
logger.warning(
f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples})"
f"可能需要扩大采样范围(增加 days 参数或降低 min_length"
)
# 随机打乱样本顺序(避免时间偏向)
if len(filtered) > 0:
random.shuffle(filtered)
# 转换为标准格式
result = []
for msg in filtered:
result.append({
"message_id": msg.get("message_id"),
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"message_text": msg.get("message_text", ""),
"timestamp": msg.get("time"),
"platform": msg.get("chat_info_platform"),
})
logger.info(f"采样完成,共 {len(result)} 条消息")
return result
async def generate_initial_keywords(
self,
persona_info: dict[str, Any],
temperature: float = 0.7,
num_iterations: int = 3,
) -> list[dict[str, Any]]:
"""使用 LLM 生成初始关键词数据集
根据人设信息生成感兴趣和不感兴趣的关键词,重复多次以增加多样性。
Args:
persona_info: 人格信息
temperature: 生成温度默认0.7,较高温度增加多样性)
num_iterations: 重复生成次数默认3次
Returns:
初始数据集列表,每个元素包含 {"message_text": str, "label": int}
"""
if not self.model_client:
await self.initialize()
logger.info(f"开始生成初始关键词数据集,温度={temperature},迭代{num_iterations}")
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造提示词
prompt = self.KEYWORD_GENERATION_PROMPT.format(
persona_info=persona_desc,
)
all_keywords_data = []
# 重复生成多次
for iteration in range(num_iterations):
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,跳过关键词生成")
break
logger.info(f"{iteration + 1}/{num_iterations} 次生成关键词...")
# 调用 LLM使用较高温度
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=1000, # 关键词列表需要较多token
temperature=temperature,
)
# 解析响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
keywords_data = self._parse_keywords_response(response_text)
if keywords_data:
interested = keywords_data.get("interested", [])
not_interested = keywords_data.get("not_interested", [])
logger.info(f" 生成 {len(interested)} 个感兴趣关键词,{len(not_interested)} 个不感兴趣关键词")
# 转换为训练格式(标签 1 表示感兴趣,-1 表示不感兴趣)
for keyword in interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": 1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
for keyword in not_interested:
if keyword and keyword.strip():
all_keywords_data.append({
"message_text": keyword.strip(),
"label": -1,
"source": "llm_generated_initial",
"iteration": iteration + 1,
})
else:
logger.warning(f"{iteration + 1} 次生成失败,未能解析关键词")
except Exception as e:
logger.error(f"{iteration + 1} 次关键词生成失败: {e}")
import traceback
traceback.print_exc()
logger.info(f"初始关键词数据集生成完成,共 {len(all_keywords_data)} 条(不去重)")
# 统计标签分布
label_counts = {}
for item in all_keywords_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标签分布: {label_counts}")
return all_keywords_data
def _parse_keywords_response(self, response: str) -> dict | None:
"""解析关键词生成的JSON响应
Args:
response: LLM 响应文本
Returns:
解析后的字典,包含 interested 和 not_interested 列表
"""
try:
# 提取JSON部分去除markdown代码块标记
response = response.strip()
if "```json" in response:
response = response.split("```json")[1].split("```")[0].strip()
elif "```" in response:
response = response.split("```")[1].split("```")[0].strip()
# 解析JSON
import json_repair
response = json_repair.repair_json(response)
data = json.loads(response)
# 验证格式
if isinstance(data, dict) and "interested" in data and "not_interested" in data:
if isinstance(data["interested"], list) and isinstance(data["not_interested"], list):
return data
logger.warning(f"关键词响应格式不正确: {data}")
return None
except json.JSONDecodeError as e:
logger.error(f"解析关键词JSON失败: {e}")
logger.debug(f"响应内容: {response}")
return None
except Exception as e:
logger.error(f"解析关键词响应失败: {e}")
return None
async def annotate_message(
self,
message_text: str,
persona_info: dict[str, Any],
) -> int:
"""使用 LLM 标注单条消息
Args:
message_text: 消息文本
persona_info: 人格信息
Returns:
标签 (-1, 0, 1)
"""
if not self.model_client:
await self.initialize()
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造提示词
prompt = self.ANNOTATION_PROMPT.format(
persona_info=persona_desc,
message_text=message_text,
)
try:
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return 0
# 调用 LLM
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=10,
temperature=0.1, # 低温度保证一致性
)
# 解析响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
label = self._parse_label(response_text)
return label
except Exception as e:
logger.error(f"LLM 标注失败: {e}")
return 0 # 默认返回中立
async def annotate_batch(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
save_path: Path | None = None,
batch_size: int = 50,
) -> list[dict[str, Any]]:
"""批量标注消息(真正的批量模式)
Args:
messages: 消息列表
persona_info: 人格信息
save_path: 保存路径(可选)
batch_size: 每次LLM请求处理的消息数默认20
Returns:
标注后的数据集
"""
logger.info(f"开始批量标注,共 {len(messages)} 条消息,每批 {batch_size}")
annotated_data = []
for i in range(0, len(messages), batch_size):
batch = messages[i : i + batch_size]
# 批量标注一次LLM请求处理多条消息
labels = await self._annotate_batch_llm(batch, persona_info)
# 保存结果
for msg, label in zip(batch, labels):
annotated_data.append({
"message_id": msg["message_id"],
"message_text": msg["message_text"],
"label": label,
"user_id": msg.get("user_id"),
"chat_id": msg.get("chat_id"),
"timestamp": msg.get("timestamp"),
})
logger.info(f"已标注 {len(annotated_data)}/{len(messages)}")
# 统计标签分布
label_counts = {}
for item in annotated_data:
label = item["label"]
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f"标注完成,标签分布: {label_counts}")
# 保存到文件
if save_path:
save_path.parent.mkdir(parents=True, exist_ok=True)
with open(save_path, "w", encoding="utf-8") as f:
json.dump(annotated_data, f, ensure_ascii=False, indent=2)
logger.info(f"数据集已保存到: {save_path}")
return annotated_data
async def _annotate_batch_llm(
self,
messages: list[dict[str, Any]],
persona_info: dict[str, Any],
) -> list[int]:
"""使用一次LLM请求标注多条消息
Args:
messages: 消息列表通常20条
persona_info: 人格信息
Returns:
标签列表
"""
if not self.model_client:
logger.warning("LLM 客户端未初始化,返回默认标签")
return [0] * len(messages)
# 构造人格描述
persona_desc = self._format_persona_info(persona_info)
# 构造消息列表
messages_list = ""
for idx, msg in enumerate(messages, 1):
messages_list += f"{idx}. {msg['message_text']}\n"
# 构造示例输出
example_output = json.dumps(
{str(i): 0 for i in range(1, len(messages) + 1)},
ensure_ascii=False,
indent=2
)
# 构造提示词
prompt = self.BATCH_ANNOTATION_PROMPT.format(
persona_info=persona_desc,
messages_list=messages_list,
example_output=example_output,
)
try:
# 调用 LLM使用更大的token限制
response = await self.model_client.generate_response_async(
prompt=prompt,
max_tokens=500, # 批量标注需要更多token
temperature=0.1,
)
# 解析批量响应generate_response_async 返回元组)
response_text = response[0] if isinstance(response, tuple) else response
labels = self._parse_batch_labels(response_text, len(messages))
return labels
except Exception as e:
logger.error(f"批量LLM标注失败: {e},返回默认值")
return [0] * len(messages)
def _format_persona_info(self, persona_info: dict[str, Any]) -> str:
"""格式化人格信息
Args:
persona_info: 人格信息字典
Returns:
格式化后的人格描述
"""
def _stringify(value: Any) -> str:
if value is None:
return ""
if isinstance(value, (list, tuple, set)):
return "".join([str(v) for v in value if v is not None and str(v).strip()])
if isinstance(value, dict):
try:
return json.dumps(value, ensure_ascii=False, sort_keys=True)
except Exception:
return str(value)
return str(value).strip()
parts: list[str] = []
name = _stringify(persona_info.get("name"))
if name:
parts.append(f"角色名称: {name}")
# 核心/侧面/身份等完整人设信息
personality_core = _stringify(persona_info.get("personality_core"))
if personality_core:
parts.append(f"核心人设: {personality_core}")
personality_side = _stringify(persona_info.get("personality_side"))
if personality_side:
parts.append(f"侧面特质: {personality_side}")
identity = _stringify(persona_info.get("identity"))
if identity:
parts.append(f"身份特征: {identity}")
# 追加其他未覆盖字段(保持信息完整)
known_keys = {
"name",
"personality_core",
"personality_side",
"identity",
}
for key, value in persona_info.items():
if key in known_keys:
continue
value_str = _stringify(value)
if value_str:
parts.append(f"{key}: {value_str}")
return "\n".join(parts) if parts else "无特定人格设定"
def _parse_label(self, response: str) -> int:
"""解析 LLM 响应为标签
Args:
response: LLM 响应文本
Returns:
标签 (-1, 0, 1)
"""
# 部分 LLM 客户端可能返回 (text, meta) 的 tuple这里取首元素并转为字符串
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response).strip()
# 尝试直接解析数字
if response in ["-1", "0", "1"]:
return int(response)
# 尝试提取数字
if "-1" in response:
return -1
elif "1" in response:
return 1
elif "0" in response:
return 0
# 默认返回中立
logger.warning(f"无法解析 LLM 响应: {response},返回默认值 0")
return 0
def _parse_batch_labels(self, response: str, expected_count: int) -> list[int]:
"""解析批量LLM响应为标签列表
Args:
response: LLM 响应文本JSON格式
expected_count: 期望的标签数量
Returns:
标签列表
"""
try:
# 兼容 tuple/list 返回格式
if isinstance(response, (tuple, list)):
response = response[0] if response else ""
response = str(response)
# 提取JSON内容
import re
json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
if json_match:
json_str = json_match.group(1)
else:
# 尝试直接解析
json_str = response
import json_repair
# 解析JSON
labels_json = json_repair.repair_json(json_str)
labels_dict = json.loads(labels_json) # 验证是否为有效JSON
# 转换为列表
labels = []
for i in range(1, expected_count + 1):
key = str(i)
# 检查是否为字典且包含该键
if isinstance(labels_dict, dict) and key in labels_dict:
label = labels_dict[key]
# 确保标签值有效
if label in [-1, 0, 1]:
labels.append(label)
else:
logger.warning(f"无效标签值 {label},使用默认值 0")
labels.append(0)
else:
# 尝试从值列表或数组中顺序取值
if isinstance(labels_dict, list) and len(labels_dict) >= i:
label = labels_dict[i - 1]
labels.append(label if label in [-1, 0, 1] else 0)
else:
labels.append(0)
if len(labels) != expected_count:
logger.warning(
f"标签数量不匹配:期望 {expected_count},实际 {len(labels)}"
f"补齐为 {expected_count}"
)
# 补齐或截断
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
else:
labels = labels[:expected_count]
return labels
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e},响应内容: {response[:200]}")
return [0] * expected_count
except Exception as e:
# 兜底:尝试直接提取所有标签数字
try:
import re
numbers = re.findall(r"-?1|0", response)
labels = [int(n) for n in numbers[:expected_count]]
if len(labels) < expected_count:
labels.extend([0] * (expected_count - len(labels)))
return labels
except Exception:
logger.error(f"批量标签解析失败: {e}")
return [0] * expected_count
@staticmethod
def load_dataset(path: Path) -> tuple[list[str], list[int]]:
"""加载训练数据集
Args:
path: 数据集文件路径
Returns:
(文本列表, 标签列表)
"""
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
texts = [item["message_text"] for item in data]
labels = [item["label"] for item in data]
logger.info(f"加载数据集: {len(texts)} 条样本")
return texts, labels
async def generate_training_dataset(
output_path: Path,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
generate_initial_keywords: bool = True,
keyword_temperature: float = 0.7,
keyword_iterations: int = 3,
) -> Path:
"""生成训练数据集(主函数)
Args:
output_path: 输出文件路径
persona_info: 人格信息
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
generate_initial_keywords: 是否生成初始关键词数据集默认True
keyword_temperature: 关键词生成温度默认0.7
keyword_iterations: 关键词生成迭代次数默认3
Returns:
保存的文件路径
"""
generator = DatasetGenerator(model_name=model_name)
await generator.initialize()
# 第一步:生成初始关键词数据集(如果启用)
initial_keywords_data = []
if generate_initial_keywords:
logger.info("=" * 60)
logger.info("步骤 1/3: 生成初始关键词数据集")
logger.info("=" * 60)
initial_keywords_data = await generator.generate_initial_keywords(
persona_info=persona_info,
temperature=keyword_temperature,
num_iterations=keyword_iterations,
)
logger.info(f"✓ 初始关键词数据集已生成: {len(initial_keywords_data)}")
else:
logger.info("跳过初始关键词生成")
# 第二步:采样真实消息
logger.info("=" * 60)
logger.info(f"步骤 2/3: 采样真实消息(最近 {days} 天,最多 {max_samples} 条)")
logger.info("=" * 60)
messages = await generator.sample_messages(
days=days,
max_samples=max_samples,
)
logger.info(f"✓ 消息采样完成: {len(messages)}")
# 第三步:批量标注真实消息
logger.info("=" * 60)
logger.info("步骤 3/3: LLM 标注真实消息")
logger.info("=" * 60)
# 注意:不保存到文件,返回标注后的数据
annotated_messages = await generator.annotate_batch(
messages=messages,
persona_info=persona_info,
save_path=None, # 暂不保存
)
logger.info(f"✓ 消息标注完成: {len(annotated_messages)}")
# 第四步:合并数据集
logger.info("=" * 60)
logger.info("步骤 4/4: 合并数据集")
logger.info("=" * 60)
# 合并初始关键词和标注后的消息(不去重,保持所有重复项)
combined_dataset = []
# 添加初始关键词数据
if initial_keywords_data:
combined_dataset.extend(initial_keywords_data)
logger.info(f" + 初始关键词: {len(initial_keywords_data)}")
# 添加标注后的消息
combined_dataset.extend(annotated_messages)
logger.info(f" + 标注消息: {len(annotated_messages)}")
logger.info(f"✓ 合并后总计: {len(combined_dataset)} 条(不去重)")
# 统计标签分布
label_counts = {}
for item in combined_dataset:
label = item.get("label", 0)
label_counts[label] = label_counts.get(label, 0) + 1
logger.info(f" 最终标签分布: {label_counts}")
# 保存合并后的数据集
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(combined_dataset, f, ensure_ascii=False, indent=2)
logger.info("=" * 60)
logger.info(f"✓ 训练数据集已保存: {output_path}")
logger.info("=" * 60)
return output_path

View File

@@ -0,0 +1,147 @@
"""TF-IDF 特征向量化器
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
"""
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from src.common.logger import get_logger
logger = get_logger("semantic_interest.features")
class TfidfFeatureExtractor:
"""TF-IDF 特征提取器
使用字符级 n-gram 策略,适合中文/多语言场景
优化说明2024.12
- max_features 从 20000 降到 10000减少计算量
- ngram_range 默认 (2, 3),对于兴趣任务足够
- min_df 提高到 3过滤低频噪声
"""
def __init__(
self,
analyzer: str = "char", # type: ignore
ngram_range: tuple[int, int] = (2, 4), # 优化:缩小 n-gram 范围
max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半
min_df: int = 3, # 优化:过滤低频 n-gram
max_df: float = 0.95,
):
"""初始化特征提取器
Args:
analyzer: 分析器类型 ('char''word')
ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram
max_features: 词表最大大小,防止特征爆炸
min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表
max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词)
"""
self.vectorizer = TfidfVectorizer(
analyzer=analyzer,
ngram_range=ngram_range,
max_features=max_features,
min_df=min_df,
max_df=max_df,
lowercase=True,
strip_accents=None, # 保留中文字符
sublinear_tf=True, # 使用对数 TF 缩放
norm="l2", # L2 归一化
)
self.is_fitted = False
logger.info(
f"TF-IDF 特征提取器初始化: analyzer={analyzer}, "
f"ngram_range={ngram_range}, max_features={max_features}"
)
def fit(self, texts: list[str]) -> "TfidfFeatureExtractor":
"""训练向量化器
Args:
texts: 训练文本列表
Returns:
self
"""
logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
self.vectorizer.fit(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")
return self
def transform(self, texts: list[str]):
"""将文本转换为 TF-IDF 向量
Args:
texts: 待转换文本列表
Returns:
稀疏矩阵
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练,请先调用 fit() 方法")
return self.vectorizer.transform(texts)
def fit_transform(self, texts: list[str]):
"""训练并转换文本
Args:
texts: 训练文本列表
Returns:
稀疏矩阵
"""
logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
result = self.vectorizer.fit_transform(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")
return result
def get_feature_names(self) -> list[str]:
"""获取特征名称列表
Returns:
特征名称列表
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练")
return self.vectorizer.get_feature_names_out().tolist()
def get_vocabulary_size(self) -> int:
"""获取词表大小
Returns:
词表大小
"""
if not self.is_fitted:
return 0
return len(self.vectorizer.vocabulary_)
def get_config(self) -> dict:
"""获取配置信息
Returns:
配置字典
"""
params = self.vectorizer.get_params()
return {
"analyzer": params["analyzer"],
"ngram_range": params["ngram_range"],
"max_features": params["max_features"],
"min_df": params["min_df"],
"max_df": params["max_df"],
"vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0,
"is_fitted": self.is_fitted,
}

View File

@@ -0,0 +1,263 @@
"""Logistic Regression 模型训练与推理
使用多分类 Logistic Regression 预测消息的兴趣度标签 (-1, 0, 1)
"""
import time
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
logger = get_logger("semantic_interest.model")
class SemanticInterestModel:
"""语义兴趣度模型
使用 Logistic Regression 进行多分类(-1: 不感兴趣, 0: 中立, 1: 感兴趣)
"""
def __init__(
self,
class_weight: str | dict | None = "balanced",
max_iter: int = 1000,
solver: str = "lbfgs", # type: ignore
n_jobs: int = -1,
):
"""初始化模型
Args:
class_weight: 类别权重配置
- "balanced": 自动平衡类别权重
- dict: 自定义权重,如 {-1: 0.8, 0: 0.6, 1: 1.6}
- None: 不使用权重
max_iter: 最大迭代次数
solver: 求解器 ('lbfgs', 'saga', 'liblinear' 等)
n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
"""
self.clf = LogisticRegression(
solver=solver,
max_iter=max_iter,
class_weight=class_weight,
n_jobs=n_jobs,
random_state=42,
)
self.is_fitted = False
self.label_mapping = {-1: 0, 0: 1, 1: 2} # 内部类别映射
self.training_metrics = {}
logger.info(
f"Logistic Regression 模型初始化: class_weight={class_weight}, "
f"max_iter={max_iter}, solver={solver}"
)
def train(
self,
X_train,
y_train,
X_val=None,
y_val=None,
verbose: bool = True,
) -> dict[str, Any]:
"""训练模型
Args:
X_train: 训练集特征矩阵
y_train: 训练集标签(-1, 0, 1
X_val: 验证集特征矩阵(可选)
y_val: 验证集标签(可选)
verbose: 是否输出详细日志
Returns:
训练指标字典
"""
start_time = time.time()
logger.info(f"开始训练模型,训练样本数: {len(y_train)}")
# 训练模型
self.clf.fit(X_train, y_train)
self.is_fitted = True
training_time = time.time() - start_time
logger.info(f"模型训练完成,耗时: {training_time:.2f}")
# 计算训练集指标
y_train_pred = self.clf.predict(X_train)
train_accuracy = (y_train_pred == y_train).mean()
metrics = {
"training_time": training_time,
"train_accuracy": train_accuracy,
"train_samples": len(y_train),
}
if verbose:
logger.info(f"训练集准确率: {train_accuracy:.4f}")
logger.info(f"类别分布: {dict(zip(*np.unique(y_train, return_counts=True)))}")
# 如果提供了验证集,计算验证指标
if X_val is not None and y_val is not None:
val_metrics = self.evaluate(X_val, y_val, verbose=verbose)
metrics.update(val_metrics)
self.training_metrics = metrics
return metrics
def evaluate(
self,
X_test,
y_test,
verbose: bool = True,
) -> dict[str, Any]:
"""评估模型
Args:
X_test: 测试集特征矩阵
y_test: 测试集标签
verbose: 是否输出详细日志
Returns:
评估指标字典
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
y_pred = self.clf.predict(X_test)
accuracy = (y_pred == y_test).mean()
metrics = {
"test_accuracy": accuracy,
"test_samples": len(y_test),
}
if verbose:
logger.info(f"测试集准确率: {accuracy:.4f}")
logger.info("\n分类报告:")
report = classification_report(
y_test,
y_pred,
labels=[-1, 0, 1],
target_names=["不感兴趣(-1)", "中立(0)", "感兴趣(1)"],
zero_division=0,
)
logger.info(f"\n{report}")
logger.info("\n混淆矩阵:")
cm = confusion_matrix(y_test, y_pred, labels=[-1, 0, 1])
logger.info(f"\n{cm}")
return metrics
def predict_proba(self, X) -> np.ndarray:
"""预测概率分布
Args:
X: 特征矩阵
Returns:
概率矩阵,形状为 (n_samples, 3),对应 [-1, 0, 1] 的概率
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
proba = self.clf.predict_proba(X)
# 确保类别顺序为 [-1, 0, 1]
classes = self.clf.classes_
if not np.array_equal(classes, [-1, 0, 1]):
# 需要重新排序
sorted_proba = np.zeros_like(proba)
for i, cls in enumerate([-1, 0, 1]):
idx = np.where(classes == cls)[0]
if len(idx) > 0:
sorted_proba[:, i] = proba[:, idx[0]]
return sorted_proba
return proba
def predict(self, X) -> np.ndarray:
"""预测类别
Args:
X: 特征矩阵
Returns:
预测标签数组
"""
if not self.is_fitted:
raise ValueError("模型尚未训练")
return self.clf.predict(X)
def get_config(self) -> dict:
"""获取模型配置
Returns:
配置字典
"""
params = self.clf.get_params()
return {
"solver": params["solver"],
"max_iter": params["max_iter"],
"class_weight": params["class_weight"],
"is_fitted": self.is_fitted,
"classes": self.clf.classes_.tolist() if self.is_fitted else None,
}
def train_semantic_model(
texts: list[str],
labels: list[int],
test_size: float = 0.1,
random_state: int = 42,
tfidf_config: dict | None = None,
model_config: dict | None = None,
) -> tuple[TfidfFeatureExtractor, SemanticInterestModel, dict]:
"""训练完整的语义兴趣度模型
Args:
texts: 消息文本列表
labels: 对应的标签列表 (-1, 0, 1)
test_size: 验证集比例
random_state: 随机种子
tfidf_config: TF-IDF 配置
model_config: 模型配置
Returns:
(特征提取器, 模型, 训练指标)
"""
logger.info(f"开始训练语义兴趣度模型,总样本数: {len(texts)}")
# 划分训练集和验证集
X_train_texts, X_val_texts, y_train, y_val = train_test_split(
texts,
labels,
test_size=test_size,
stratify=labels,
random_state=random_state,
)
logger.info(f"训练集: {len(X_train_texts)}, 验证集: {len(X_val_texts)}")
# 初始化并训练 TF-IDF 向量化器
tfidf_config = tfidf_config or {}
feature_extractor = TfidfFeatureExtractor(**tfidf_config)
X_train = feature_extractor.fit_transform(X_train_texts)
X_val = feature_extractor.transform(X_val_texts)
# 初始化并训练模型
model_config = model_config or {}
model = SemanticInterestModel(**model_config)
metrics = model.train(X_train, y_train, X_val, y_val)
logger.info("语义兴趣度模型训练完成")
return feature_extractor, model, metrics

View File

@@ -0,0 +1,641 @@
"""优化的语义兴趣度评分器
实现关键优化:
1. TF-IDF + LR 权重融合为 token→weight 字典
2. 稀疏权重剪枝(只保留高贡献 token)
3. 全局线程池 + 异步调度
4. 批处理队列系统
5. 绕过 sklearn 的纯 Python scorer
"""
import asyncio
import math
import re
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
import numpy as np
from src.common.logger import get_logger
logger = get_logger("semantic_interest.optimized")
# ============================================================================
# 全局线程池(避免每次创建新的 executor)
# ============================================================================
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
_EXECUTOR_LOCK = asyncio.Lock()
def get_global_executor(max_workers: int = 4) -> ThreadPoolExecutor:
"""获取全局线程池(单例)"""
global _GLOBAL_EXECUTOR
if _GLOBAL_EXECUTOR is None:
_GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="semantic_scorer")
logger.info(f"[优化评分器] 创建全局线程池workers={max_workers}")
return _GLOBAL_EXECUTOR
def shutdown_global_executor():
"""关闭全局线程池"""
global _GLOBAL_EXECUTOR
if _GLOBAL_EXECUTOR is not None:
_GLOBAL_EXECUTOR.shutdown(wait=False)
_GLOBAL_EXECUTOR = None
logger.info("[优化评分器] 全局线程池已关闭")
# ============================================================================
# 快速评分器(绕过 sklearn)
# ============================================================================
@dataclass
class FastScorerConfig:
"""快速评分器配置"""
# n-gram 参数
analyzer: str = "char"
ngram_range: tuple[int, int] = (2, 4)
lowercase: bool = True
# 权重剪枝阈值(绝对值小于此值的权重视为 0)
weight_prune_threshold: float = 1e-4
# 只保留 top-k 权重(0 表示不限制)
top_k_weights: int = 0
# sigmoid 缩放因子
sigmoid_alpha: float = 1.0
# 评分超时(秒)
score_timeout: float = 2.0
class FastScorer:
"""快速语义兴趣度评分器
将 TF-IDF + LR 融合成一个纯 Python 的 token→weight 字典 scorer。
核心公式:
- TF-IDF: x_i = tf_i * idf_i
- LR: z = Σ_i (w_i * x_i) + b = Σ_i (w_i * idf_i * tf_i) + b
- 定义 w'_i = w_i * idf_i,则 z = Σ_i (w'_i * tf_i) + b
这样在线评分只需要:
1. 手动做 n-gram tokenize
2. 统计 tf
3. 查表 w'_i,累加求和
4. sigmoid 转 [0, 1]
"""
def __init__(self, config: FastScorerConfig | None = None):
"""初始化快速评分器"""
self.config = config or FastScorerConfig()
# 融合后的权重字典: {token: combined_weight}
# 对于三分类,我们计算 z_interest = z_pos - z_neg
# 所以 combined_weight = (w_pos - w_neg) * idf
self.token_weights: dict[str, float] = {}
# 偏置项: bias_pos - bias_neg
self.bias: float = 0.0
# 元信息
self.meta: dict[str, Any] = {}
self.is_loaded = False
# 统计
self.total_scores = 0
self.total_time = 0.0
# n-gram 正则(预编译)
self._tokenize_pattern = re.compile(r'\s+')
@classmethod
def from_sklearn_model(
cls,
vectorizer, # TfidfVectorizer 或 TfidfFeatureExtractor
model, # SemanticInterestModel 或 LogisticRegression
config: FastScorerConfig | None = None,
) -> "FastScorer":
"""从 sklearn 模型创建快速评分器
Args:
vectorizer: TF-IDF 向量化器
model: Logistic Regression 模型
config: 配置
Returns:
FastScorer 实例
"""
scorer = cls(config)
scorer._extract_weights(vectorizer, model)
return scorer
def _extract_weights(self, vectorizer, model):
"""从 sklearn 模型提取并融合权重
将 TF-IDF 的 idf 和 LR 的权重合并为单一的 token→weight 字典
"""
# 获取底层 sklearn 对象
if hasattr(vectorizer, 'vectorizer'):
# TfidfFeatureExtractor 包装类
tfidf = vectorizer.vectorizer
else:
tfidf = vectorizer
if hasattr(model, 'clf'):
# SemanticInterestModel 包装类
clf = model.clf
else:
clf = model
# 获取词表和 IDF
vocabulary = tfidf.vocabulary_ # {token: index}
idf = tfidf.idf_ # numpy array, shape (n_features,)
# 获取 LR 权重
# clf.coef_ shape: (n_classes, n_features) 对于多分类
# classes_ 顺序应该是 [-1, 0, 1]
coef = clf.coef_ # shape (3, n_features)
intercept = clf.intercept_ # shape (3,)
classes = clf.classes_
# 找到 -1 和 1 的索引
idx_neg = np.where(classes == -1)[0][0]
idx_pos = np.where(classes == 1)[0][0]
# 计算 z_interest = z_pos - z_neg 的权重
w_interest = coef[idx_pos] - coef[idx_neg] # shape (n_features,)
b_interest = intercept[idx_pos] - intercept[idx_neg]
# 融合: combined_weight = w_interest * idf
combined_weights = w_interest * idf
# 构建 token→weight 字典
token_weights = {}
for token, idx in vocabulary.items():
weight = combined_weights[idx]
# 权重剪枝
if abs(weight) >= self.config.weight_prune_threshold:
token_weights[token] = weight
# 如果设置了 top-k 限制
if self.config.top_k_weights > 0 and len(token_weights) > self.config.top_k_weights:
# 按绝对值排序,保留 top-k
sorted_items = sorted(token_weights.items(), key=lambda x: abs(x[1]), reverse=True)
token_weights = dict(sorted_items[:self.config.top_k_weights])
self.token_weights = token_weights
self.bias = float(b_interest)
self.is_loaded = True
# 更新元信息
self.meta = {
"original_vocab_size": len(vocabulary),
"pruned_vocab_size": len(token_weights),
"prune_ratio": 1 - len(token_weights) / len(vocabulary) if vocabulary else 0,
"weight_prune_threshold": self.config.weight_prune_threshold,
"top_k_weights": self.config.top_k_weights,
"bias": self.bias,
"ngram_range": self.config.ngram_range,
}
logger.info(
f"[FastScorer] 权重提取完成: "
f"原始词表={len(vocabulary)}, 剪枝后={len(token_weights)}, "
f"剪枝率={self.meta['prune_ratio']:.2%}"
)
def _tokenize(self, text: str) -> list[str]:
"""将文本转换为 n-gram tokens
与 sklearn 的 char n-gram 保持一致
"""
if self.config.lowercase:
text = text.lower()
# 字符级 n-gram
min_n, max_n = self.config.ngram_range
tokens = []
for n in range(min_n, max_n + 1):
for i in range(len(text) - n + 1):
tokens.append(text[i:i + n])
return tokens
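# Example: with ngram_range=(2, 3), _tokenize("天气好") yields
# ["天气", "气好", "天气好"], matching sklearn's analyzer="char" n-grams for the same input.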
def _compute_tf(self, tokens: list[str]) -> dict[str, float]:
"""计算词频TF
注意: sklearn 使用 sublinear_tf=True 时是 1 + log(tf)
这里简化为原始计数,因为对于短消息差异不大
"""
return dict(Counter(tokens))
def score(self, text: str) -> float:
"""计算单条消息的语义兴趣度
Args:
text: 消息文本
Returns:
兴趣分 [0.0, 1.0]
"""
if not self.is_loaded:
raise ValueError("评分器尚未加载,请先调用 from_sklearn_model() 或 load()")
start_time = time.time()
try:
# 1. Tokenize
tokens = self._tokenize(text)
if not tokens:
return 0.5 # 空文本返回中立值
# 2. 计算 TF
tf = self._compute_tf(tokens)
# 3. 加权求和: z = Σ (w'_i * tf_i) + b
z = self.bias
for token, count in tf.items():
if token in self.token_weights:
z += self.token_weights[token] * count
# 4. Sigmoid 转换
# interest = 1 / (1 + exp(-α * z))
alpha = self.config.sigmoid_alpha
try:
interest = 1.0 / (1.0 + math.exp(-alpha * z))
except OverflowError:
interest = 0.0 if z < 0 else 1.0
# 统计
self.total_scores += 1
self.total_time += time.time() - start_time
return interest
except Exception as e:
logger.error(f"[FastScorer] 评分失败: {e}, 消息: {text[:50]}")
return 0.5
def score_batch(self, texts: list[str]) -> list[float]:
"""批量计算兴趣度"""
if not texts:
return []
return [self.score(text) for text in texts]
async def score_async(self, text: str, timeout: float | None = None) -> float:
"""异步计算兴趣度(使用全局线程池)"""
timeout = timeout or self.config.score_timeout
executor = get_global_executor()
loop = asyncio.get_running_loop()
try:
return await asyncio.wait_for(
loop.run_in_executor(executor, self.score, text),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"[FastScorer] 评分超时({timeout}s): {text[:30]}...")
return 0.5
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
"""异步批量计算兴趣度"""
if not texts:
return []
timeout = timeout or self.config.score_timeout * len(texts)
executor = get_global_executor()
loop = asyncio.get_running_loop()
try:
return await asyncio.wait_for(
loop.run_in_executor(executor, self.score_batch, texts),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"[FastScorer] 批量评分超时({timeout}s), 批次大小: {len(texts)}")
return [0.5] * len(texts)
def get_statistics(self) -> dict[str, Any]:
"""获取统计信息"""
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
return {
"is_loaded": self.is_loaded,
"total_scores": self.total_scores,
"total_time": self.total_time,
"avg_score_time_ms": avg_time * 1000,
"vocab_size": len(self.token_weights),
"meta": self.meta,
}
def save(self, path: Path | str):
"""保存快速评分器"""
import joblib
path = Path(path)
bundle = {
"token_weights": self.token_weights,
"bias": self.bias,
"config": {
"analyzer": self.config.analyzer,
"ngram_range": self.config.ngram_range,
"lowercase": self.config.lowercase,
"weight_prune_threshold": self.config.weight_prune_threshold,
"top_k_weights": self.config.top_k_weights,
"sigmoid_alpha": self.config.sigmoid_alpha,
"score_timeout": self.config.score_timeout,
},
"meta": self.meta,
}
joblib.dump(bundle, path)
logger.info(f"[FastScorer] 已保存到: {path}")
@classmethod
def load(cls, path: Path | str) -> "FastScorer":
"""加载快速评分器"""
import joblib
path = Path(path)
bundle = joblib.load(path)
config = FastScorerConfig(**bundle["config"])
scorer = cls(config)
scorer.token_weights = bundle["token_weights"]
scorer.bias = bundle["bias"]
scorer.meta = bundle.get("meta", {})
scorer.is_loaded = True
logger.info(f"[FastScorer] 已从 {path} 加载,词表大小: {len(scorer.token_weights)}")
return scorer
# ============================================================================
# 批处理评分队列
# ============================================================================
@dataclass
class ScoringRequest:
"""评分请求"""
text: str
future: asyncio.Future
timestamp: float = field(default_factory=time.time)
class BatchScoringQueue:
"""批处理评分队列
攒一小撮消息一起算,提高 CPU 利用率
"""
def __init__(
self,
scorer: FastScorer,
batch_size: int = 16,
flush_interval_ms: float = 50.0,
):
"""初始化批处理队列
Args:
scorer: 评分器实例
batch_size: 批次大小,达到后立即处理
flush_interval_ms: 刷新间隔(毫秒),超过后强制处理
"""
self.scorer = scorer
self.batch_size = batch_size
self.flush_interval = flush_interval_ms / 1000.0
self._pending: list[ScoringRequest] = []
self._lock = asyncio.Lock()
self._flush_task: asyncio.Task | None = None
self._running = False
# 统计
self.total_batches = 0
self.total_requests = 0
async def start(self):
"""启动批处理队列"""
if self._running:
return
self._running = True
self._flush_task = asyncio.create_task(self._flush_loop())
logger.info(f"[BatchQueue] 启动batch_size={self.batch_size}, interval={self.flush_interval*1000}ms")
async def stop(self):
"""停止批处理队列"""
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
# 处理剩余请求
await self._flush()
logger.info("[BatchQueue] 已停止")
async def score(self, text: str) -> float:
"""提交评分请求并等待结果
Args:
text: 消息文本
Returns:
兴趣分
"""
loop = asyncio.get_running_loop()
future = loop.create_future()
request = ScoringRequest(text=text, future=future)
async with self._lock:
self._pending.append(request)
self.total_requests += 1
# 达到批次大小,立即处理
if len(self._pending) >= self.batch_size:
asyncio.create_task(self._flush())
return await future
async def _flush_loop(self):
"""定时刷新循环"""
while self._running:
await asyncio.sleep(self.flush_interval)
await self._flush()
async def _flush(self):
"""处理当前待处理的请求"""
async with self._lock:
if not self._pending:
return
batch = self._pending.copy()
self._pending.clear()
if not batch:
return
self.total_batches += 1
try:
# 批量评分
texts = [req.text for req in batch]
scores = await self.scorer.score_batch_async(texts)
# 分发结果
for req, score in zip(batch, scores):
if not req.future.done():
req.future.set_result(score)
except Exception as e:
logger.error(f"[BatchQueue] 批量评分失败: {e}")
# 返回默认值
for req in batch:
if not req.future.done():
req.future.set_result(0.5)
def get_statistics(self) -> dict[str, Any]:
"""获取统计信息"""
avg_batch_size = self.total_requests / self.total_batches if self.total_batches > 0 else 0
return {
"total_batches": self.total_batches,
"total_requests": self.total_requests,
"avg_batch_size": avg_batch_size,
"pending_count": len(self._pending),
"batch_size": self.batch_size,
"flush_interval_ms": self.flush_interval * 1000,
}
# ============================================================================
# 优化评分器工厂
# ============================================================================
_fast_scorer_instances: dict[str, FastScorer] = {}
_batch_queue_instances: dict[str, BatchScoringQueue] = {}
async def get_fast_scorer(
model_path: str | Path,
use_batch_queue: bool = False,
batch_size: int = 16,
flush_interval_ms: float = 50.0,
force_reload: bool = False,
) -> FastScorer | BatchScoringQueue:
"""获取快速评分器实例(单例)
Args:
model_path: 模型文件路径(.pkl 格式,可以是 sklearn 模型或 FastScorer 保存的)
use_batch_queue: 是否使用批处理队列
batch_size: 批处理大小
flush_interval_ms: 批处理刷新间隔(毫秒)
force_reload: 是否强制重新加载
Returns:
FastScorer 或 BatchScoringQueue 实例
"""
import joblib
model_path = Path(model_path)
path_key = str(model_path.resolve())
# 检查是否已存在
if not force_reload:
if use_batch_queue and path_key in _batch_queue_instances:
return _batch_queue_instances[path_key]
elif not use_batch_queue and path_key in _fast_scorer_instances:
return _fast_scorer_instances[path_key]
# 加载模型
logger.info(f"[优化评分器] 加载模型: {model_path}")
bundle = joblib.load(model_path)
# 检查是 FastScorer 还是 sklearn 模型
if "token_weights" in bundle:
# FastScorer 格式
scorer = FastScorer.load(model_path)
else:
# sklearn 模型格式,需要转换
vectorizer = bundle["vectorizer"]
model = bundle["model"]
config = FastScorerConfig(
ngram_range=vectorizer.get_config().get("ngram_range", (2, 4)),
weight_prune_threshold=1e-4,
)
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
_fast_scorer_instances[path_key] = scorer
# 如果需要批处理队列
if use_batch_queue:
queue = BatchScoringQueue(scorer, batch_size, flush_interval_ms)
await queue.start()
_batch_queue_instances[path_key] = queue
return queue
return scorer
def convert_sklearn_to_fast(
sklearn_model_path: str | Path,
output_path: str | Path | None = None,
config: FastScorerConfig | None = None,
) -> FastScorer:
"""将 sklearn 模型转换为 FastScorer 格式
Args:
sklearn_model_path: sklearn 模型路径
output_path: 输出路径(可选)
config: FastScorer 配置
Returns:
FastScorer 实例
"""
import joblib
sklearn_model_path = Path(sklearn_model_path)
bundle = joblib.load(sklearn_model_path)
vectorizer = bundle["vectorizer"]
model = bundle["model"]
# 从 vectorizer 配置推断 n-gram range
if config is None:
vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
config = FastScorerConfig(
ngram_range=vconfig.get("ngram_range", (2, 4)),
weight_prune_threshold=1e-4,
)
scorer = FastScorer.from_sklearn_model(vectorizer, model, config)
# 保存转换后的模型
if output_path:
output_path = Path(output_path)
scorer.save(output_path)
return scorer
def clear_fast_scorer_instances():
"""清空所有快速评分器实例"""
global _fast_scorer_instances, _batch_queue_instances
# 停止所有批处理队列
for queue in _batch_queue_instances.values():
asyncio.create_task(queue.stop())
_fast_scorer_instances.clear()
_batch_queue_instances.clear()
logger.info("[优化评分器] 已清空所有实例")

View File

@@ -0,0 +1,744 @@
"""运行时语义兴趣度评分器
在线推理时使用,提供快速的兴趣度评分
支持异步加载、超时保护、批量优化、模型预热
2024.12 优化更新:
- 新增 FastScorer 模式,绕过 sklearn 直接使用 token→weight 字典
- 全局线程池避免每次创建新的 executor
- 可选的批处理队列模式
"""
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
import joblib
import numpy as np
from src.common.logger import get_logger
from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
from src.chat.semantic_interest.model_lr import SemanticInterestModel
logger = get_logger("semantic_interest.scorer")
# 全局配置
DEFAULT_SCORE_TIMEOUT = 2.0 # 评分超时(秒),从 5.0 降低到 2.0
# 全局线程池(避免每次创建新的 executor)
_GLOBAL_EXECUTOR: ThreadPoolExecutor | None = None
_EXECUTOR_MAX_WORKERS = 4
def _get_global_executor() -> ThreadPoolExecutor:
"""获取全局线程池(单例)"""
global _GLOBAL_EXECUTOR
if _GLOBAL_EXECUTOR is None:
_GLOBAL_EXECUTOR = ThreadPoolExecutor(
max_workers=_EXECUTOR_MAX_WORKERS,
thread_name_prefix="semantic_scorer"
)
logger.info(f"[评分器] 创建全局线程池workers={_EXECUTOR_MAX_WORKERS}")
return _GLOBAL_EXECUTOR
# 单例管理
_scorer_instances: dict[str, "SemanticInterestScorer"] = {} # 模型路径 -> 评分器实例
_instance_lock = asyncio.Lock() # 创建实例的锁
class SemanticInterestScorer:
"""语义兴趣度评分器
加载训练好的模型,在运行时快速计算消息的语义兴趣度
优化特性:
- 异步加载支持(非阻塞)
- 批量评分优化
- 超时保护
- 模型预热
- 全局线程池(避免重复创建 executor)
- 可选的 FastScorer 模式(绕过 sklearn)
"""
def __init__(self, model_path: str | Path, use_fast_scorer: bool = True):
"""初始化评分器
Args:
model_path: 模型文件路径 (.pkl)
use_fast_scorer: 是否使用快速评分器模式(推荐)
"""
self.model_path = Path(model_path)
self.vectorizer: TfidfFeatureExtractor | None = None
self.model: SemanticInterestModel | None = None
self.meta: dict[str, Any] = {}
self.is_loaded = False
# 快速评分器模式
self._use_fast_scorer = use_fast_scorer
self._fast_scorer = None # FastScorer 实例
# 统计信息
self.total_scores = 0
self.total_time = 0.0
def load(self):
"""同步加载模型(阻塞)"""
if not self.model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
logger.info(f"开始加载模型: {self.model_path}")
start_time = time.time()
try:
bundle = joblib.load(self.model_path)
self.vectorizer = bundle["vectorizer"]
self.model = bundle["model"]
self.meta = bundle.get("meta", {})
# 如果启用快速评分器模式,创建 FastScorer
if self._use_fast_scorer:
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
config = FastScorerConfig(
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
weight_prune_threshold=1e-4,
)
self._fast_scorer = FastScorer.from_sklearn_model(
self.vectorizer, self.model, config
)
logger.info(
f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
f"剪枝到 {len(self._fast_scorer.token_weights)}"
)
self.is_loaded = True
load_time = time.time() - start_time
logger.info(
f"模型加载成功,耗时: {load_time:.3f}秒, "
f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore
)
if self.meta:
logger.info(f"模型元信息: {self.meta}")
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise
async def load_async(self):
"""异步加载模型(非阻塞)"""
if not self.model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
logger.info(f"开始异步加载模型: {self.model_path}")
start_time = time.time()
try:
# 在全局线程池中执行 I/O 密集型操作
executor = _get_global_executor()
loop = asyncio.get_running_loop()
bundle = await loop.run_in_executor(executor, joblib.load, self.model_path)
self.vectorizer = bundle["vectorizer"]
self.model = bundle["model"]
self.meta = bundle.get("meta", {})
# 如果启用快速评分器模式,创建 FastScorer
if self._use_fast_scorer:
from src.chat.semantic_interest.optimized_scorer import FastScorer, FastScorerConfig
config = FastScorerConfig(
ngram_range=self.vectorizer.get_config().get("ngram_range", (2, 3)),
weight_prune_threshold=1e-4,
)
self._fast_scorer = FastScorer.from_sklearn_model(
self.vectorizer, self.model, config
)
logger.info(
f"[FastScorer] 已启用,词表从 {self.vectorizer.get_vocabulary_size()} "
f"剪枝到 {len(self._fast_scorer.token_weights)}"
)
self.is_loaded = True
load_time = time.time() - start_time
logger.info(
f"模型异步加载成功,耗时: {load_time:.3f}秒, "
f"词表大小: {self.vectorizer.get_vocabulary_size()}" # type: ignore
)
if self.meta:
logger.info(f"模型元信息: {self.meta}")
# 预热模型
await self._warmup_async()
except Exception as e:
logger.error(f"模型异步加载失败: {e}")
raise
def reload(self):
"""重新加载模型(热更新)"""
logger.info("重新加载模型...")
self.is_loaded = False
self.load()
async def reload_async(self):
"""异步重新加载模型"""
logger.info("异步重新加载模型...")
self.is_loaded = False
await self.load_async()
def score(self, text: str) -> float:
"""计算单条消息的语义兴趣度
Args:
text: 消息文本
Returns:
兴趣分 [0.0, 1.0],越高表示越感兴趣
"""
if not self.is_loaded:
raise ValueError("模型尚未加载,请先调用 load() 或 load_async() 方法")
start_time = time.time()
try:
# 优先使用 FastScorer绕过 sklearn更快
if self._fast_scorer is not None:
interest = self._fast_scorer.score(text)
else:
# 回退到原始 sklearn 路径
# 向量化
X = self.vectorizer.transform([text])
# 预测概率
proba = self.model.predict_proba(X)[0]
# proba 顺序为 [-1, 0, 1]
p_neg, p_neu, p_pos = proba
# 兴趣分计算策略:
# interest = P(1) + 0.5 * P(0)
# 这样:纯正向(1)=1.0, 纯中立(0)=0.5, 纯负向(-1)=0.0
interest = float(p_pos + 0.5 * p_neu)
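# Example: proba = [0.10, 0.30, 0.60] gives interest = 0.60 + 0.5 * 0.30 = 0.75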
# 确保在 [0, 1] 范围内
interest = max(0.0, min(1.0, interest))
# 统计
self.total_scores += 1
self.total_time += time.time() - start_time
return interest
except Exception as e:
logger.error(f"兴趣度计算失败: {e}, 消息: {text[:50]}")
return 0.5 # 默认返回中立值
async def score_async(self, text: str, timeout: float = DEFAULT_SCORE_TIMEOUT) -> float:
"""异步计算兴趣度(带超时保护)
Args:
text: 消息文本
timeout: 超时时间(秒),超时返回中立值 0.5
Returns:
兴趣分 [0.0, 1.0]
"""
# 使用全局线程池,避免每次创建新的 executor
executor = _get_global_executor()
loop = asyncio.get_running_loop()
try:
return await asyncio.wait_for(
loop.run_in_executor(executor, self.score, text),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"兴趣度计算超时({timeout}秒),消息: {text[:50]}")
return 0.5 # 默认中立值
def score_batch(self, texts: list[str]) -> list[float]:
"""批量计算兴趣度
Args:
texts: 消息文本列表
Returns:
兴趣分列表
"""
if not self.is_loaded:
raise ValueError("模型尚未加载")
if not texts:
return []
start_time = time.time()
try:
# 优先使用 FastScorer
if self._fast_scorer is not None:
interests = self._fast_scorer.score_batch(texts)
# 统计
self.total_scores += len(texts)
self.total_time += time.time() - start_time
return interests
else:
# 回退到原始 sklearn 路径
# 批量向量化
X = self.vectorizer.transform(texts)
# 批量预测
proba = self.model.predict_proba(X)
# 计算兴趣分
interests = []
for p_neg, p_neu, p_pos in proba:
interest = float(p_pos + 0.5 * p_neu)
interest = max(0.0, min(1.0, interest))
interests.append(interest)
# 统计
self.total_scores += len(texts)
self.total_time += time.time() - start_time
return interests
except Exception as e:
logger.error(f"批量兴趣度计算失败: {e}")
return [0.5] * len(texts)
async def score_batch_async(self, texts: list[str], timeout: float | None = None) -> list[float]:
"""异步批量计算兴趣度
Args:
texts: 消息文本列表
timeout: 超时时间,None 则使用单条超时*文本数
Returns:
兴趣分列表
"""
if not texts:
return []
# 计算动态超时
if timeout is None:
timeout = DEFAULT_SCORE_TIMEOUT * len(texts)
# 使用全局线程池
executor = _get_global_executor()
loop = asyncio.get_running_loop()
try:
return await asyncio.wait_for(
loop.run_in_executor(executor, self.score_batch, texts),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"批量兴趣度计算超时({timeout}秒),批次大小: {len(texts)}")
return [0.5] * len(texts)
def _warmup(self, sample_texts: list[str] | None = None):
"""预热模型(执行几次推理以优化性能)
Args:
sample_texts: 预热用的样本文本,None 则使用默认样本
"""
if not self.is_loaded:
return
if sample_texts is None:
sample_texts = [
"你好",
"今天天气怎么样?",
"我对这个话题很感兴趣"
]
logger.debug(f"开始预热模型,样本数: {len(sample_texts)}")
start_time = time.time()
for text in sample_texts:
try:
self.score(text)
except Exception:
pass # 忽略预热错误
warmup_time = time.time() - start_time
logger.debug(f"模型预热完成,耗时: {warmup_time:.3f}")
async def _warmup_async(self, sample_texts: list[str] | None = None):
"""异步预热模型"""
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._warmup, sample_texts)
def get_detailed_score(self, text: str) -> dict[str, Any]:
"""获取详细的兴趣度评分信息
Args:
text: 消息文本
Returns:
包含概率分布和最终分数的详细信息
"""
if not self.is_loaded:
raise ValueError("模型尚未加载")
X = self.vectorizer.transform([text])
proba = self.model.predict_proba(X)[0]
pred_label = self.model.predict(X)[0]
p_neg, p_neu, p_pos = proba
interest = float(p_pos + 0.5 * p_neu)
return {
"interest_score": max(0.0, min(1.0, interest)),
"proba_distribution": {
"dislike": float(p_neg),
"neutral": float(p_neu),
"like": float(p_pos),
},
"predicted_label": int(pred_label),
"text_preview": text[:100],
}
def get_statistics(self) -> dict[str, Any]:
"""获取评分器统计信息
Returns:
统计信息字典
"""
avg_time = self.total_time / self.total_scores if self.total_scores > 0 else 0
stats = {
"is_loaded": self.is_loaded,
"model_path": str(self.model_path),
"total_scores": self.total_scores,
"total_time": self.total_time,
"avg_score_time": avg_time,
"avg_score_time_ms": avg_time * 1000, # 毫秒单位更直观
"vocabulary_size": (
self.vectorizer.get_vocabulary_size()
if self.vectorizer and self.is_loaded
else 0
),
"use_fast_scorer": self._use_fast_scorer,
"fast_scorer_enabled": self._fast_scorer is not None,
"meta": self.meta,
}
# 如果启用了 FastScorer添加其统计
if self._fast_scorer is not None:
stats["fast_scorer_stats"] = self._fast_scorer.get_statistics()
return stats
def __repr__(self) -> str:
mode = "fast" if self._fast_scorer else "sklearn"
return (
f"SemanticInterestScorer("
f"loaded={self.is_loaded}, "
f"mode={mode}, "
f"model={self.model_path.name})"
)
class ModelManager:
"""模型管理器
支持模型热更新、版本管理和人设感知的模型切换
"""
def __init__(self, model_dir: Path):
"""初始化管理器
Args:
model_dir: 模型目录
"""
self.model_dir = Path(model_dir)
self.model_dir.mkdir(parents=True, exist_ok=True)
self.current_scorer: SemanticInterestScorer | None = None
self.current_version: str | None = None
self.current_persona_info: dict[str, Any] | None = None
self._lock = asyncio.Lock()
# 自动训练器集成
self._auto_trainer = None
self._auto_training_started = False # 防止重复启动自动训练
async def load_model(self, version: str = "latest", persona_info: dict[str, Any] | None = None, use_async: bool = True) -> SemanticInterestScorer:
"""加载指定版本的模型,支持人设感知(使用单例)
Args:
version: 模型版本号或 "latest"、"auto"
persona_info: 人设信息,用于自动选择匹配的模型
use_async: 是否使用异步加载(推荐)
Returns:
评分器实例(单例)
"""
async with self._lock:
# 如果指定了人设信息,尝试使用自动训练器
if persona_info is not None and version == "auto":
model_path = await self._get_persona_model(persona_info)
elif version == "latest":
model_path = self._get_latest_model()
else:
model_path = self.model_dir / f"semantic_interest_{version}.pkl"
if not model_path or not model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {model_path}")
# 使用单例获取评分器
scorer = await get_semantic_scorer(model_path, force_reload=False, use_async=use_async)
self.current_scorer = scorer
self.current_version = version
self.current_persona_info = persona_info
logger.info(f"模型管理器已加载版本: {version}, 文件: {model_path.name}")
return scorer
async def reload_current_model(self):
"""重新加载当前模型"""
if not self.current_scorer:
raise ValueError("尚未加载任何模型")
async with self._lock:
await self.current_scorer.reload_async()
logger.info("模型已重新加载")
def _get_latest_model(self) -> Path:
"""获取最新的模型文件
Returns:
最新模型文件路径
"""
model_files = list(self.model_dir.glob("semantic_interest_*.pkl"))
if not model_files:
raise FileNotFoundError(f"{self.model_dir} 中未找到模型文件")
# 按修改时间排序
latest = max(model_files, key=lambda p: p.stat().st_mtime)
return latest
def get_scorer(self) -> SemanticInterestScorer:
"""获取当前评分器
Returns:
当前评分器实例
"""
if not self.current_scorer:
raise ValueError("尚未加载任何模型")
return self.current_scorer
async def _get_persona_model(self, persona_info: dict[str, Any]) -> Path | None:
"""根据人设信息获取或训练模型
Args:
persona_info: 人设信息
Returns:
模型文件路径
"""
try:
# 延迟导入避免循环依赖
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
# 检查是否需要训练
trained, model_path = await self._auto_trainer.auto_train_if_needed(
persona_info=persona_info,
days=7,
max_samples=1000, # 初始训练使用1000条消息
)
if trained and model_path:
logger.info(f"[模型管理器] 使用新训练的模型: {model_path.name}")
return model_path
# 获取现有的人设模型
model_path = self._auto_trainer.get_model_for_persona(persona_info)
if model_path:
return model_path
# 降级到 latest
logger.warning("[模型管理器] 未找到人设模型,使用 latest")
return self._get_latest_model()
except Exception as e:
logger.error(f"[模型管理器] 获取人设模型失败: {e}")
return self._get_latest_model()
async def check_and_reload_for_persona(self, persona_info: dict[str, Any]) -> bool:
"""检查人设变化并重新加载模型
Args:
persona_info: 当前人设信息
Returns:
True 如果重新加载了模型
"""
# 检查人设是否变化
if self.current_persona_info == persona_info:
return False
logger.info("[模型管理器] 检测到人设变化,重新加载模型...")
try:
await self.load_model(version="auto", persona_info=persona_info)
return True
except Exception as e:
logger.error(f"[模型管理器] 重新加载模型失败: {e}")
return False
async def start_auto_training(self, persona_info: dict[str, Any], interval_hours: int = 24):
"""启动自动训练任务
Args:
persona_info: 人设信息
interval_hours: 检查间隔(小时)
"""
# 使用锁防止并发启动
async with self._lock:
# 检查是否已经启动
if self._auto_training_started:
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
return
try:
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
# 标记为已启动
self._auto_training_started = True
# 在后台任务中运行
asyncio.create_task(
self._auto_trainer.scheduled_train(persona_info, interval_hours)
)
except Exception as e:
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
self._auto_training_started = False # 失败时重置标志
# 单例获取函数
async def get_semantic_scorer(
model_path: str | Path,
force_reload: bool = False,
use_async: bool = True
) -> SemanticInterestScorer:
"""获取语义兴趣度评分器实例(单例模式)
同一个模型路径只会创建一个评分器实例,避免重复加载模型。
Args:
model_path: 模型文件路径
force_reload: 是否强制重新加载模型
use_async: 是否使用异步加载(推荐)
Returns:
评分器实例(单例)
Example:
>>> scorer = await get_semantic_scorer("data/semantic_interest/models/model.pkl")
>>> score = await scorer.score_async("今天天气真好")
"""
model_path = Path(model_path)
path_key = str(model_path.resolve()) # 使用绝对路径作为键
async with _instance_lock:
# 检查是否已存在实例
if not force_reload and path_key in _scorer_instances:
scorer = _scorer_instances[path_key]
if scorer.is_loaded:
logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
return scorer
else:
logger.info(f"[单例] 评分器未加载,重新加载: {model_path.name}")
# 创建或重新加载实例
if path_key not in _scorer_instances:
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
scorer = SemanticInterestScorer(model_path)
_scorer_instances[path_key] = scorer
else:
scorer = _scorer_instances[path_key]
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
# 加载模型
if use_async:
await scorer.load_async()
else:
scorer.load()
return scorer
def get_semantic_scorer_sync(
model_path: str | Path,
force_reload: bool = False
) -> SemanticInterestScorer:
"""获取语义兴趣度评分器实例(同步版本,单例模式)
注意:这是同步版本,推荐使用异步版本 get_semantic_scorer()
Args:
model_path: 模型文件路径
force_reload: 是否强制重新加载模型
Returns:
评分器实例(单例)
"""
model_path = Path(model_path)
path_key = str(model_path.resolve())
# 检查是否已存在实例
if not force_reload and path_key in _scorer_instances:
scorer = _scorer_instances[path_key]
if scorer.is_loaded:
logger.debug(f"[单例] 复用已加载的评分器: {model_path.name}")
return scorer
# 创建或重新加载实例
if path_key not in _scorer_instances:
logger.info(f"[单例] 创建新的评分器实例: {model_path.name}")
scorer = SemanticInterestScorer(model_path)
_scorer_instances[path_key] = scorer
else:
scorer = _scorer_instances[path_key]
logger.info(f"[单例] 强制重新加载评分器: {model_path.name}")
# 加载模型
scorer.load()
return scorer
def clear_scorer_instances():
"""清空所有评分器实例(释放内存)"""
global _scorer_instances
count = len(_scorer_instances)
_scorer_instances.clear()
logger.info(f"[单例] 已清空 {count} 个评分器实例")
def get_all_scorer_instances() -> dict[str, SemanticInterestScorer]:
"""获取所有已创建的评分器实例
Returns:
{模型路径: 评分器实例} 的字典
"""
return _scorer_instances.copy()
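A sketch of persona-aware loading through ModelManager, assuming the model directory is already populated; the directory and the persona dict below are placeholders:
import asyncio
from pathlib import Path

async def _demo_manager():
    manager = ModelManager(Path("data/semantic_interest/models"))  # placeholder directory
    scorer = await manager.load_model(version="auto", persona_info={"personality": "placeholder"})
    print(await scorer.score_async("这个话题我很感兴趣"))

asyncio.run(_demo_manager())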

View File

@@ -0,0 +1,202 @@
"""训练器入口脚本
统一的训练流程入口,包含数据采样、标注、训练、评估
"""
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Any
import joblib
from src.common.logger import get_logger
from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
from src.chat.semantic_interest.model_lr import train_semantic_model
logger = get_logger("semantic_interest.trainer")
class SemanticInterestTrainer:
"""语义兴趣度训练器
统一管理训练流程
"""
def __init__(
self,
data_dir: Path | None = None,
model_dir: Path | None = None,
):
"""初始化训练器
Args:
data_dir: 数据集目录
model_dir: 模型保存目录
"""
self.data_dir = Path(data_dir or "data/semantic_interest/datasets")
self.model_dir = Path(model_dir or "data/semantic_interest/models")
self.data_dir.mkdir(parents=True, exist_ok=True)
self.model_dir.mkdir(parents=True, exist_ok=True)
async def prepare_dataset(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
model_name: str | None = None,
dataset_name: str | None = None,
generate_initial_keywords: bool = True,
keyword_temperature: float = 0.7,
keyword_iterations: int = 3,
) -> Path:
"""准备训练数据集
Args:
persona_info: 人格信息
days: 采样最近 N 天的消息
max_samples: 最大采样数
model_name: LLM 模型名称
dataset_name: 数据集名称(默认使用时间戳)
generate_initial_keywords: 是否生成初始关键词数据集
keyword_temperature: 关键词生成温度
keyword_iterations: 关键词生成迭代次数
Returns:
数据集文件路径
"""
if dataset_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dataset_name = f"dataset_{timestamp}"
output_path = self.data_dir / f"{dataset_name}.json"
logger.info(f"开始准备数据集: {dataset_name}")
await generate_training_dataset(
output_path=output_path,
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_name=model_name,
generate_initial_keywords=generate_initial_keywords,
keyword_temperature=keyword_temperature,
keyword_iterations=keyword_iterations,
)
return output_path
def train_model(
self,
dataset_path: Path,
model_version: str | None = None,
tfidf_config: dict | None = None,
model_config: dict | None = None,
test_size: float = 0.1,
) -> tuple[Path, dict]:
"""训练模型
Args:
dataset_path: 数据集文件路径
model_version: 模型版本号(默认使用时间戳)
tfidf_config: TF-IDF 配置
model_config: 模型配置
test_size: 验证集比例
Returns:
(模型文件路径, 训练指标)
"""
logger.info(f"开始训练模型,数据集: {dataset_path}")
# 加载数据集
from src.chat.semantic_interest.dataset import DatasetGenerator
texts, labels = DatasetGenerator.load_dataset(dataset_path)
# 训练模型
vectorizer, model, metrics = train_semantic_model(
texts=texts,
labels=labels,
test_size=test_size,
tfidf_config=tfidf_config,
model_config=model_config,
)
# 保存模型
if model_version is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_version = timestamp
model_path = self.model_dir / f"semantic_interest_{model_version}.pkl"
bundle = {
"vectorizer": vectorizer,
"model": model,
"meta": {
"version": model_version,
"trained_at": datetime.now().isoformat(),
"dataset": str(dataset_path),
"train_samples": len(texts),
"metrics": metrics,
"tfidf_config": vectorizer.get_config(),
"model_config": model.get_config(),
},
}
joblib.dump(bundle, model_path)
logger.info(f"模型已保存到: {model_path}")
return model_path, metrics
async def full_training_pipeline(
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 1000,
llm_model_name: str | None = None,
tfidf_config: dict | None = None,
model_config: dict | None = None,
dataset_name: str | None = None,
model_version: str | None = None,
) -> tuple[Path, Path, dict]:
"""完整训练流程
Args:
persona_info: 人格信息
days: 采样天数
max_samples: 最大采样数
llm_model_name: LLM 模型名称
tfidf_config: TF-IDF 配置
model_config: 模型配置
dataset_name: 数据集名称
model_version: 模型版本
Returns:
(数据集路径, 模型路径, 训练指标)
"""
logger.info("开始完整训练流程")
# 1. 准备数据集
dataset_path = await self.prepare_dataset(
persona_info=persona_info,
days=days,
max_samples=max_samples,
model_name=llm_model_name,
dataset_name=dataset_name,
)
# 2. 训练模型
model_path, metrics = self.train_model(
dataset_path=dataset_path,
model_version=model_version,
tfidf_config=tfidf_config,
model_config=model_config,
)
logger.info("完整训练流程完成")
logger.info(f"数据集: {dataset_path}")
logger.info(f"模型: {model_path}")
logger.info(f"指标: {metrics}")
return dataset_path, model_path, metrics
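A sketch of running the full pipeline, assuming chat history and an LLM backend are available; the persona_info content is a placeholder:
import asyncio

async def _demo_training():
    trainer = SemanticInterestTrainer()
    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
        persona_info={"personality": "placeholder"},
        days=7,
        max_samples=1000,
    )
    print(dataset_path, model_path, metrics.get("test_accuracy"))

asyncio.run(_demo_training())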

View File

@@ -96,7 +96,7 @@ class ChineseTypoGenerator:
# 🔧 内存优化:复用全局缓存的拼音字典和字频数据
if _shared_pinyin_dict is None:
_shared_pinyin_dict = self._create_pinyin_dict()
_shared_pinyin_dict = self._load_or_create_pinyin_dict()
logger.debug("拼音字典已创建并缓存")
self.pinyin_dict = _shared_pinyin_dict
@@ -141,6 +141,35 @@ class ChineseTypoGenerator:
return normalized_freq
def _load_or_create_pinyin_dict(self):
"""
加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动)
"""
cache_file = Path("depends-data/pinyin_dict.json")
if cache_file.exists():
try:
with open(cache_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
# 恢复为 defaultdict(list) 以兼容旧逻辑
restored = defaultdict(list)
for py, chars in data.items():
restored[py] = list(chars)
return restored
except Exception as e:
logger.warning(f"读取拼音缓存失败,将重新生成: {e}")
pinyin_dict = self._create_pinyin_dict()
try:
cache_file.parent.mkdir(parents=True, exist_ok=True)
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8"))
except Exception as e:
logger.warning(f"写入拼音缓存失败(不影响使用): {e}")
return pinyin_dict
@staticmethod
def _create_pinyin_dict():
"""

View File

@@ -10,11 +10,6 @@ CoreSink 统一管理器
3. 使用 MessageRuntime 进行消息路由和处理
4. 提供统一的消息发送接口
架构说明2025-11 重构):
- 集成 mofox_wire.MessageRuntime 作为消息路由中心
- 使用 @runtime.on_message() 装饰器注册消息处理器
- 利用 before_hook/after_hook/error_hook 处理前置/后置/错误逻辑
- 简化消息处理链条,提高可扩展性
"""
from __future__ import annotations

View File

@@ -215,26 +215,25 @@ class QueryBuilder(Generic[T]):
async with get_db_session() as session:
result = await session.execute(paginated_stmt)
# .all() 已经返回 list,无需再包装
instances = result.scalars().all()
if not instances:
# 没有更多数据
break
# 在 session 内部转换为字典列表
# 在 session 内部转换为字典列表,保证字段可用再释放连接
instances_dicts = [_model_to_dict(inst) for inst in instances]
if as_dict:
yield instances_dicts
else:
yield [_dict_to_model(self.model, row) for row in instances_dicts]
if as_dict:
yield instances_dicts
else:
yield [_dict_to_model(self.model, row) for row in instances_dicts]
# 如果返回的记录数小于 batch_size,说明已经是最后一批
if len(instances) < batch_size:
break
# 如果返回的记录数小于 batch_size,说明已经是最后一批
if len(instances) < batch_size:
break
offset += batch_size
offset += batch_size
async def iter_all(
self,

View File

@@ -20,6 +20,7 @@ from src.common.logger import get_logger
logger = get_logger("redis_cache")
import redis.asyncio as aioredis
from redis.asyncio.connection import Connection, SSLConnection
class RedisCache(CacheBackend):
@@ -98,7 +99,11 @@ class RedisCache(CacheBackend):
return self._client
try:
# 创建连接池 (使用 aioredis 模块确保类型安全)
# redis-py 7.x+ 使用 connection_class 来指定 SSL 连接
# 不再支持直接传递 ssl=True/False 给 ConnectionPool
connection_class = SSLConnection if self.ssl else Connection
# 创建连接池
self._pool = aioredis.ConnectionPool(
host=self.host,
port=self.port,
@@ -108,7 +113,7 @@ class RedisCache(CacheBackend):
socket_timeout=self.socket_timeout,
socket_connect_timeout=self.socket_timeout,
decode_responses=False, # 我们自己处理序列化
ssl=self.ssl,
connection_class=connection_class,
)
# 创建客户端

View File

@@ -0,0 +1,259 @@
"""
日志广播系统
用于实时推送日志到多个订阅者(如WebSocket客户端)
"""
import asyncio
import logging
from collections import deque
from collections.abc import Callable
from typing import Any
import orjson
class LogBroadcaster:
"""日志广播器,用于实时推送日志到订阅者"""
def __init__(self, max_buffer_size: int = 1000):
"""
初始化日志广播器
Args:
max_buffer_size: 缓冲区最大大小,超过后会丢弃旧日志
"""
self.subscribers: set[Callable[[dict[str, Any]], None]] = set()
self.buffer: deque[dict[str, Any]] = deque(maxlen=max_buffer_size)
self._lock = asyncio.Lock()
async def subscribe(self, callback: Callable[[dict[str, Any]], None]) -> None:
"""
订阅日志推送
Args:
callback: 接收日志的回调函数,参数为日志字典
"""
async with self._lock:
self.subscribers.add(callback)
async def unsubscribe(self, callback: Callable[[dict[str, Any]], None]) -> None:
"""
取消订阅
Args:
callback: 要移除的回调函数
"""
async with self._lock:
self.subscribers.discard(callback)
async def broadcast(self, log_record: dict[str, Any]) -> None:
"""
广播日志到所有订阅者
Args:
log_record: 日志记录字典
"""
# 添加到缓冲区
async with self._lock:
self.buffer.append(log_record)
# 创建订阅者列表的副本,避免在迭代时修改
subscribers = list(self.subscribers)
# 异步发送到所有订阅者
tasks = []
for callback in subscribers:
try:
if asyncio.iscoroutinefunction(callback):
tasks.append(asyncio.create_task(callback(log_record)))
else:
# 同步回调在线程池中执行
tasks.append(asyncio.create_task(asyncio.to_thread(callback, log_record)))
except Exception:
# 忽略单个订阅者的错误
pass
# 等待所有发送完成(但不阻塞太久)
if tasks:
await asyncio.wait(tasks, timeout=1.0)
def get_recent_logs(self, limit: int = 100) -> list[dict[str, Any]]:
"""
获取最近的日志记录
Args:
limit: 返回的最大日志数量
Returns:
日志记录列表
"""
return list(self.buffer)[-limit:]
def clear_buffer(self) -> None:
"""清空日志缓冲区"""
self.buffer.clear()
class BroadcastLogHandler(logging.Handler):
"""
日志处理器,将日志推送到广播器
"""
def __init__(self, broadcaster: LogBroadcaster):
"""
初始化处理器
Args:
broadcaster: 日志广播器实例
"""
super().__init__()
self.broadcaster = broadcaster
self.loop: asyncio.AbstractEventLoop | None = None
def _get_logger_metadata(self, logger_name: str) -> dict[str, str | None]:
"""
获取logger的元数据(别名和颜色)
Args:
logger_name: logger名称
Returns:
包含alias和color的字典
"""
try:
# 导入logger元数据获取函数
from src.common.logger import get_logger_meta
return get_logger_meta(logger_name)
except Exception:
# 如果获取失败,返回空元数据
return {"alias": None, "color": None}
def emit(self, record: logging.LogRecord) -> None:
"""
处理日志记录
Args:
record: 日志记录
"""
try:
# 获取logger元数据(别名和颜色)
logger_meta = self._get_logger_metadata(record.name)
# 转换日志记录为字典
log_dict = {
"timestamp": self.format_time(record),
"level": record.levelname, # 保持大写,与前端筛选器一致
"logger_name": record.name, # 原始logger名称
"event": record.getMessage(),
}
# 添加别名和颜色(如果存在)
if logger_meta["alias"]:
log_dict["alias"] = logger_meta["alias"]
if logger_meta["color"]:
log_dict["color"] = logger_meta["color"]
# 添加额外字段
if hasattr(record, "__dict__"):
for key, value in record.__dict__.items():
if key not in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"thread",
"threadName",
"exc_info",
"exc_text",
"stack_info",
):
try:
# 尝试序列化以确保可以转为JSON
orjson.dumps(value)
log_dict[key] = value
except (TypeError, ValueError):
log_dict[key] = str(value)
# 获取或创建事件循环
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# 没有运行的事件循环,创建新任务
if self.loop is None:
try:
self.loop = asyncio.new_event_loop()
except Exception:
return
loop = self.loop
# 在事件循环中异步广播
asyncio.run_coroutine_threadsafe(
self.broadcaster.broadcast(log_dict), loop
)
except Exception:
# 忽略广播错误,避免影响日志系统
pass
def format_time(self, record: logging.LogRecord) -> str:
"""
格式化时间戳
Args:
record: 日志记录
Returns:
格式化的时间字符串
"""
from datetime import datetime
dt = datetime.fromtimestamp(record.created)
return dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
# 全局广播器实例
_global_broadcaster: LogBroadcaster | None = None
def get_log_broadcaster() -> LogBroadcaster:
"""
获取全局日志广播器实例
Returns:
日志广播器实例
"""
global _global_broadcaster
if _global_broadcaster is None:
_global_broadcaster = LogBroadcaster()
return _global_broadcaster
def setup_log_broadcasting() -> LogBroadcaster:
"""
设置日志广播系统,将日志处理器添加到根日志记录器
Returns:
日志广播器实例
"""
broadcaster = get_log_broadcaster()
# 创建并添加广播处理器到根日志记录器
handler = BroadcastLogHandler(broadcaster)
handler.setLevel(logging.DEBUG)
# 添加到根日志记录器
root_logger = logging.getLogger()
root_logger.addHandler(handler)
return broadcaster
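A sketch of consuming the broadcaster from an async context; the print call stands in for whatever transport (for example a WebSocket send) a real subscriber would use:
import asyncio
import logging
import orjson

async def _demo_logs():
    broadcaster = setup_log_broadcasting()

    async def push(record: dict) -> None:
        # A real server would forward this, e.g. await websocket.send_text(...)
        print(orjson.dumps(record).decode())

    await broadcaster.subscribe(push)
    logging.getLogger("demo").warning("hello from the broadcaster")
    await asyncio.sleep(0.1)  # let emit() hand the record over to the running loop
    await broadcaster.unsubscribe(push)

asyncio.run(_demo_logs())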

View File

@@ -1,6 +1,7 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制
import logging
import os
import tarfile
import threading
import time
@@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler):
self.backup_count = backup_count
self.encoding = encoding
self._lock = threading.Lock()
self._current_size = 0
self._bytes_since_check = 0
self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8"))
self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024)
# 当前活跃的日志文件
self.current_file = None
@@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler):
# 极低概率碰撞,稍作等待
time.sleep(0.001)
self.current_stream = open(self.current_file, "a", encoding=self.encoding)
self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0
self._bytes_since_check = 0
def _should_rollover(self):
"""检查是否需要轮转"""
if self.current_file and self.current_file.exists():
return self.current_file.stat().st_size >= self.max_bytes
def _should_rollover(self, incoming_size: int = 0) -> bool:
"""检查是否需要轮转使用内存缓存的大小信息减少磁盘stat次数。"""
if not self.current_file:
return False
projected = self._current_size + incoming_size
if projected >= self.max_bytes:
return True
self._bytes_since_check += incoming_size
if self._bytes_since_check >= self._stat_refresh_threshold:
try:
if self.current_file.exists():
self._current_size = self.current_file.stat().st_size
else:
self._current_size = 0
except OSError:
self._current_size = 0
finally:
self._bytes_since_check = 0
return False
def _do_rollover(self):
@@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler):
def emit(self, record):
"""发出日志记录"""
try:
message = self.format(record)
encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes
with self._lock:
# 检查是否需要轮转
if self._should_rollover():
if self._should_rollover(encoded_len):
self._do_rollover()
# 写入日志
if self.current_stream:
msg = self.format(record)
self.current_stream.write(msg + "\n")
self.current_stream.write(message + "\n")
self.current_stream.flush()
self._current_size += encoded_len
except Exception:
self.handleError(record)
@@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = {
}
# 创建全局 Rich Console 实例用于颜色渲染
_rich_console = Console(force_terminal=True, color_system="truecolor")
class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色"""
@@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer:
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors
self._config = LOG_CONFIG
self._render_console = Console(force_terminal=True, color_system="truecolor", width=999)
# 日志级别颜色 (#RRGGBB 格式)
self._level_colors_hex = {
@@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer:
self._enable_level_colors = False
self._enable_full_content_colors = False
@staticmethod
def _looks_like_markup(content: str) -> bool:
"""快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。"""
if not content:
return False
return "[" in content and "]" in content
def _render_content_text(self, content: str, *, style: str | None = None) -> Text:
"""只在必要时解析 Rich 标记降低CPU占用。"""
if self._looks_like_markup(content):
try:
return Text.from_markup(content, style=style)
except Exception:
return Text(content, style=style)
return Text(content, style=style)
def __call__(self, logger, method_name, event_dict):
# sourcery skip: merge-duplicate-blocks
"""渲染日志消息"""
@@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer:
if prefix:
# 解析 prefix 中的 Rich 标记
if module_hex_color:
content_text.append(Text.from_markup(prefix, style=module_hex_color))
content_text.append(self._render_content_text(prefix, style=module_hex_color))
else:
content_text.append(Text.from_markup(prefix))
content_text.append(self._render_content_text(prefix))
# 与"内心思考"段落之间插入空行
if prefix:
@@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer:
else:
# 使用 Text.from_markup 解析 Rich 标记语言
if module_hex_color:
try:
parts.append(Text.from_markup(event_content, style=module_hex_color))
except Exception:
# 如果标记解析失败,回退到普通文本
parts.append(Text(event_content, style=module_hex_color))
parts.append(self._render_content_text(event_content, style=module_hex_color))
else:
try:
parts.append(Text.from_markup(event_content))
except Exception:
# 如果标记解析失败,回退到普通文本
parts.append(Text(event_content))
parts.append(self._render_content_text(event_content))
else:
# 即使在非 full 模式下,也尝试解析 Rich 标记(但不应用颜色)
try:
parts.append(Text.from_markup(event_content))
except Exception:
# 如果标记解析失败,使用普通文本
parts.append(Text(event_content))
parts.append(self._render_content_text(event_content))
# 处理其他字段
extras = []
@@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer:
# 使用 Rich 拼接并返回字符串
result = Text(" ").join(parts)
# 将 Rich Text 对象转换为带 ANSI 颜色码的字符串
from io import StringIO
string_io = StringIO()
temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999)
temp_console.print(result, end="")
return string_io.getvalue()
# 使用持久化 Console + capture 避免每条日志重复实例化
with self._render_console.capture() as capture:
self._render_console.print(result, end="")
return capture.get()
# 配置标准logging以支持文件输出和压缩

View File

@@ -506,24 +506,16 @@ def load_config(config_path: str) -> Config:
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
# 但在 Pydantic 严格模式下可能导致类型验证失败
config_dict = config_data.unwrap()
# 创建Config对象各个配置类会自动进行 Pydantic 验证)
try:
logger.debug("正在解析和验证配置文件...")
config = Config.from_dict(config_data)
config = Config.from_dict(config_dict)
logger.debug("配置文件解析和验证完成")
# 【临时修复】在验证后,手动从原始数据重新加载 master_users
try:
# 先将 tomlkit 对象转换为纯 Python 字典
config_dict = config_data.unwrap()
if "permission" in config_dict and "master_users" in config_dict["permission"]:
raw_master_users = config_dict["permission"]["master_users"]
# 现在 raw_master_users 就是一个标准的 Python 列表了
config.permission.master_users = raw_master_users
logger.debug(f"【临时修复】已手动将 master_users 设置为: {config.permission.master_users}")
except Exception as patch_exc:
logger.error(f"【临时修复】手动设置 master_users 失败: {patch_exc}")
return config
except Exception as e:
logger.critical(f"配置文件解析失败: {e}")
@@ -581,4 +573,4 @@ def initialize_configs_once() -> tuple[Config, APIAdapterConfig]:
# 同一进程只执行一次初始化,避免重复生成或覆盖配置
global_config, model_config = initialize_configs_once()
logger.debug("非常的新鲜,非常的美味!")
logger.debug("非常的新鲜,非常的美味!")

View File

@@ -213,6 +213,12 @@ class ExpressionConfig(ValidatedConfigBase):
default="classic",
description="表达方式选择模式: classic=经典LLM评估, exp_model=机器学习模型预测"
)
model_temperature: float = Field(
default=1.0,
ge=0.0,
le=5.0,
description="表达模型采样温度0为贪婪值越大越容易采样到低分表达"
)
expiration_days: int = Field(
default=90,
description="表达方式过期天数,超过此天数未激活的表达方式将被清理"
@@ -508,6 +514,7 @@ class MemoryConfig(ValidatedConfigBase):
short_term_decay_factor: float = Field(default=0.98, description="衰减因子")
# 长期记忆层配置
use_judge: bool = Field(default=True, description="使用评判模型决定是否检索长期记忆")
long_term_batch_size: int = Field(default=10, description="批量转移大小")
long_term_decay_factor: float = Field(default=0.95, description="衰减因子")
long_term_auto_transfer_interval: int = Field(default=60, description="自动转移间隔(秒)")
@@ -796,14 +803,6 @@ class AffinityFlowConfig(ValidatedConfigBase):
# 兴趣评分系统参数
reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
# 回复决策系统参数
no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值")
@@ -1009,4 +1008,3 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
default_factory=KokoroFlowChatterProactiveConfig,
description="私聊专属主动思考配置"
)

View File

@@ -79,9 +79,6 @@ class Individuality:
else:
logger.error("人设构建失败")
# 初始化智能兴趣系统
await self._initialize_smart_interest_system(personality_result, identity_result)
# 如果任何一个发生变化都需要清空数据库中的info_list因为这影响整体人设
if personality_changed or identity_changed:
logger.info("将清空数据库中原有的关键词缓存")
@@ -93,20 +90,6 @@ class Individuality:
}
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str):
"""初始化智能兴趣系统"""
# 组合完整的人设描述
full_personality = f"{personality_result}{identity_result}"
# 使用统一的评分API初始化智能兴趣系统
from src.plugin_system.apis import person_api
await person_api.initialize_smart_interests(
personality_description=full_personality, personality_id=self.bot_person_id
)
logger.info("智能兴趣系统初始化完成")
async def get_personality_block(self) -> str:
bot_name = global_config.bot.nickname
if global_config.bot.alias_names:

View File

@@ -33,7 +33,6 @@ from src.config.config import global_config
from src.individuality.individuality import Individuality, get_individuality
from src.manager.async_task_manager import async_task_manager
from src.mood.mood_manager import mood_manager
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.plugin_manager import plugin_manager
@@ -120,93 +119,6 @@ class MainSystem:
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async def _initialize_interest_calculator(self) -> None:
"""初始化兴趣值计算组件 - 通过插件系统自动发现和加载"""
try:
logger.debug("开始自动发现兴趣值计算组件...")
# 使用组件注册表自动发现兴趣计算器组件
interest_calculators = {}
try:
from src.plugin_system.apis.component_manage_api import get_components_info_by_type
from src.plugin_system.base.component_types import ComponentType
interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR)
logger.debug(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件")
except Exception as e:
logger.error(f"从组件注册表获取兴趣计算器失败: {e}")
if not interest_calculators:
logger.warning("未发现任何兴趣计算器组件")
return
# 初始化兴趣度管理器
from src.chat.interest_system.interest_manager import get_interest_manager
interest_manager = get_interest_manager()
await interest_manager.initialize()
# 尝试注册所有可用的计算器
registered_calculators = []
for calc_name, calc_info in interest_calculators.items():
enabled = getattr(calc_info, "enabled", True)
default_enabled = getattr(calc_info, "enabled_by_default", True)
if not enabled or not default_enabled:
logger.debug(f"兴趣计算器 {calc_name} 未启用,跳过")
continue
try:
from src.plugin_system.base.component_types import ComponentType as CT
from src.plugin_system.core.component_registry import component_registry
component_class = component_registry.get_component_class(
calc_name, CT.INTEREST_CALCULATOR
)
if not component_class:
logger.warning(f"无法找到 {calc_name} 的组件类")
continue
logger.debug(f"成功获取 {calc_name} 的组件类: {component_class.__name__}")
# 确保组件是 BaseInterestCalculator 的子类
if not issubclass(component_class, BaseInterestCalculator):
logger.warning(f"{calc_name} 不是 BaseInterestCalculator 的有效子类")
continue
# 显式转换类型以修复 Pyright 错误
component_class = cast(type[BaseInterestCalculator], component_class)
# 创建组件实例
calculator_instance = component_class()
# 初始化组件
if not await calculator_instance.initialize():
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
continue
# 注册到兴趣管理器
if await interest_manager.register_calculator(calculator_instance):
registered_calculators.append(calculator_instance)
logger.debug(f"成功注册兴趣计算器: {calc_name}")
else:
logger.error(f"兴趣计算器 {calc_name} 注册失败")
except Exception as e:
logger.error(f"处理兴趣计算器 {calc_name} 时出错: {e}")
if registered_calculators:
logger.debug(f"成功注册了 {len(registered_calculators)} 个兴趣计算器")
for calc in registered_calculators:
logger.debug(f" - {calc.component_name} v{calc.component_version}")
else:
logger.error("未能成功注册任何兴趣计算器")
except Exception as e:
logger.error(f"初始化兴趣度计算器失败: {e}")
async def _async_cleanup(self) -> None:
"""异步清理资源"""
if self._cleanup_started:
@@ -474,6 +386,14 @@ class MainSystem:
await mood_manager.start()
logger.debug("情绪管理器初始化成功")
# 初始化日志广播系统
try:
from src.common.log_broadcaster import setup_log_broadcasting
setup_log_broadcasting()
logger.debug("日志广播系统初始化成功")
except Exception as e:
logger.error(f"日志广播系统初始化失败: {e}")
# 启动聊天管理器的自动保存任务
from src.chat.message_receive.chat_stream import get_chat_manager
task = asyncio.create_task(get_chat_manager()._auto_save_task())
@@ -499,9 +419,6 @@ class MainSystem:
except Exception as e:
logger.error(f"三层记忆系统初始化失败: {e}")
# 初始化消息兴趣值计算组件
await self._initialize_interest_calculator()
# 初始化LPMM知识库
try:
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge

View File

@@ -1,3 +1,6 @@
# ruff: noqa: G004, BLE001
# pylint: disable=logging-fstring-interpolation,broad-except,unused-argument
# pyright: reportOptionalMemberAccess=false
"""
记忆管理器 - Phase 3
@@ -218,7 +221,7 @@ class MemoryManager:
subject: str,
memory_type: str,
topic: str,
object: str | None = None,
obj: str | None = None,
attributes: dict[str, str] | None = None,
importance: float = 0.5,
**kwargs,
@@ -230,7 +233,7 @@ class MemoryManager:
subject: 主体(谁)
memory_type: 记忆类型(事件/观点/事实/关系)
topic: 主题(做什么/想什么)
object: 客体(对谁/对什么)
obj: 客体(对谁/对什么)
attributes: 属性字典(时间、地点、原因等)
importance: 重要性 (0.0-1.0)
**kwargs: 其他参数
@@ -246,7 +249,7 @@ class MemoryManager:
subject=subject,
memory_type=memory_type,
topic=topic,
object=object,
object=obj,
attributes=attributes,
importance=importance,
**kwargs,
@@ -775,6 +778,8 @@ class MemoryManager:
logger.debug(f"传播激活到相关记忆 {related_id[:8]} 失败: {e}")
# 再次保存传播后的更新
assert self.persistence is not None
assert self.graph_store is not None
await self.persistence.save_graph_store(self.graph_store)
logger.debug(f"后台保存激活更新完成,处理了 {len(memories)} 条记忆")
@@ -811,7 +816,6 @@ class MemoryManager:
# 批量执行传播任务
if propagation_tasks:
import asyncio
try:
await asyncio.wait_for(
asyncio.gather(*propagation_tasks, return_exceptions=True),
@@ -837,6 +841,8 @@ class MemoryManager:
Returns:
相关记忆 ID 列表
"""
_ = max_depth # 保留参数以兼容旧调用
memory = self.graph_store.get_memory_by_id(memory_id)
if not memory:
return []
@@ -997,7 +1003,7 @@ class MemoryManager:
if memories_to_forget:
logger.info(f"开始批量遗忘 {len(memories_to_forget)} 条记忆...")
for memory_id, activation in memories_to_forget:
for memory_id, _ in memories_to_forget:
# cleanup_orphans=False,暂不清理孤立节点
success = await self.forget_memory(memory_id, cleanup_orphans=False)
if success:
@@ -1008,6 +1014,8 @@ class MemoryManager:
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
# 保存最终更新
assert self.persistence is not None
assert self.graph_store is not None
await self.persistence.save_graph_store(self.graph_store)
logger.info(
@@ -1059,7 +1067,7 @@ class MemoryManager:
# 2. 清理孤立边(指向已删除节点的边)
edges_to_remove = []
for source, target, edge_id in self.graph_store.graph.edges(data="edge_id"):
for source, target, _ in self.graph_store.graph.edges(data="edge_id"):
# 检查边的源节点和目标节点是否还存在于node_to_memories中
if source not in self.graph_store.node_to_memories or \
target not in self.graph_store.node_to_memories:
@@ -1096,7 +1104,7 @@ class MemoryManager:
if not self._initialized or not self.graph_store:
return {}
stats = self.graph_store.get_statistics()
stats: dict[str, Any] = self.graph_store.get_statistics()
# 添加激活度统计
all_memories = self.graph_store.get_all_memories()
@@ -1152,7 +1160,7 @@ class MemoryManager:
logger.info("开始记忆整理:检查遗忘 + 清理孤立节点...")
# 步骤1: 自动遗忘低激活度的记忆
forgotten_count = await self.auto_forget()
forgotten_count = await self.auto_forget_memories()
# 步骤2: 清理孤立节点和边(auto_forget内部已执行,这里再次确保)
orphan_nodes, orphan_edges = await self._cleanup_orphan_nodes_and_edges()
@@ -1292,6 +1300,8 @@ class MemoryManager:
result["orphan_edges_cleaned"] = consolidate_result.get("orphan_edges_cleaned", 0)
# 2. 保存数据
assert self.persistence is not None
assert self.graph_store is not None
await self.persistence.save_graph_store(self.graph_store)
result["saved"] = True

View File

@@ -11,6 +11,7 @@ import asyncio
import json
import re
import uuid
import json_repair
from pathlib import Path
from typing import Any
@@ -186,8 +187,8 @@ class ShortTermMemoryManager:
"importance": 0.7,
"attributes": {{
"time": "时间信息",
"attribute1": "其他属性1"
"attribute2": "其他属性2"
"attribute1": "其他属性1",
"attribute2": "其他属性2",
...
}}
}}
@@ -530,7 +531,7 @@ class ShortTermMemoryManager:
json_str = re.sub(r"//.*", "", json_str)
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
data = json.loads(json_str)
data = json_repair.loads(json_str)
return data
except json.JSONDecodeError as e:
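The hunk above swaps `json.loads` for `json_repair.loads` when parsing model output. A small usage sketch of why the repairing parser is more forgiving (illustrative input only; exact repair behaviour depends on the json_repair version):

```python
import json
import json_repair

# Typical LLM output: trailing commas that the strict parser rejects
raw = '{"subject": "用户", "importance": 0.7, "attributes": {"time": "昨天",},}'

try:
    json.loads(raw)
except json.JSONDecodeError as e:
    print(f"strict parser failed: {e}")

data = json_repair.loads(raw)  # best-effort repair, returns a Python object
print(data["attributes"]["time"])
```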

View File

@@ -12,7 +12,6 @@ from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.plugin_system.services.interest_service import interest_service
from src.plugin_system.services.relationship_service import relationship_service
logger = get_logger("person_api")
@@ -169,37 +168,6 @@ async def update_user_relationship(user_id: str, relationship_score: float, rela
await relationship_service.update_user_relationship(user_id, relationship_score, relationship_text, user_name)
# =============================================================================
# 兴趣系统API
# =============================================================================
async def initialize_smart_interests(personality_description: str, personality_id: str = "default"):
"""
初始化智能兴趣系统
Args:
personality_description: 机器人性格描述
personality_id: 性格ID
"""
await interest_service.initialize_smart_interests(personality_description, personality_id)
async def calculate_interest_match(
content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""计算消息兴趣匹配,返回匹配结果"""
if not content:
logger.warning("[PersonAPI] 请求兴趣匹配时 content 为空")
return None
try:
return await interest_service.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"[PersonAPI] 计算消息兴趣匹配失败: {e}")
return None
# =============================================================================
# 系统状态与缓存API
# =============================================================================
@@ -214,7 +182,6 @@ def get_system_stats() -> dict[str, Any]:
"""
return {
"relationship_service": relationship_service.get_cache_stats(),
"interest_service": interest_service.get_interest_stats(),
}

View File

@@ -11,7 +11,6 @@ if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, InterestCalculatorInfo
logger = get_logger("base_interest_calculator")
@@ -210,26 +209,6 @@ class BaseInterestCalculator(ABC):
return default
return current
@classmethod
def get_interest_calculator_info(cls) -> "InterestCalculatorInfo":
"""从类属性生成InterestCalculatorInfo
遵循BaseCommand和BaseAction的设计模式从类属性自动生成组件信息
Returns:
InterestCalculatorInfo: 生成的兴趣计算器信息对象
"""
name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))
if "." in name:
logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
return InterestCalculatorInfo(
name=name,
component_type=ComponentType.INTEREST_CALCULATOR,
description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"),
enabled_by_default=getattr(cls, "enabled_by_default", True),
)
def __repr__(self) -> str:
return (

View File

@@ -7,7 +7,6 @@ from src.plugin_system.base.component_types import (
CommandInfo,
ComponentType,
EventHandlerInfo,
InterestCalculatorInfo,
PlusCommandInfo,
PromptInfo,
ToolInfo,
@@ -17,7 +16,6 @@ from .base_action import BaseAction
from .base_adapter import BaseAdapter
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_interest_calculator import BaseInterestCalculator
from .base_prompt import BasePrompt
from .base_tool import BaseTool
from .plugin_base import PluginBase
@@ -59,15 +57,6 @@ class BasePlugin(PluginBase):
logger.warning(f"Action组件 {component_class.__name__} 缺少 get_action_info 方法")
return None
elif component_type == ComponentType.INTEREST_CALCULATOR:
if hasattr(component_class, "get_interest_calculator_info"):
return component_class.get_interest_calculator_info()
else:
logger.warning(
f"InterestCalculator组件 {component_class.__name__} 缺少 get_interest_calculator_info 方法"
)
return None
elif component_type == ComponentType.PLUS_COMMAND:
# PlusCommand组件的get_info方法尚未实现
logger.warning("PlusCommand组件的get_info方法尚未实现")
@@ -123,7 +112,6 @@ class BasePlugin(PluginBase):
| tuple[PlusCommandInfo, type[PlusCommand]]
| tuple[EventHandlerInfo, type[BaseEventHandler]]
| tuple[ToolInfo, type[BaseTool]]
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
| tuple[PromptInfo, type[BasePrompt]]
]:
"""获取插件包含的组件列表

View File

@@ -48,7 +48,6 @@ class ComponentType(Enum):
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件
CHATTER = "chatter" # 聊天处理器组件
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
PROMPT = "prompt" # Prompt组件
ROUTER = "router" # 路由组件
ADAPTER = "adapter" # 适配器组件
@@ -298,17 +297,6 @@ class ChatterInfo(ComponentInfo):
self.component_type = ComponentType.CHATTER
@dataclass
class InterestCalculatorInfo(ComponentInfo):
"""兴趣度计算组件信息(单例模式)"""
enabled_by_default: bool = True # 是否默认启用
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.INTEREST_CALCULATOR
@dataclass
class EventInfo(ComponentInfo):
"""事件组件信息"""

View File

@@ -17,7 +17,6 @@ from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_http_component import BaseRouterComponent
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import (
@@ -28,7 +27,6 @@ from src.plugin_system.base.component_types import (
ComponentInfo,
ComponentType,
EventHandlerInfo,
InterestCalculatorInfo,
PluginInfo,
PlusCommandInfo,
PromptInfo,
@@ -48,7 +46,6 @@ ComponentClassType = (
| type[BaseEventHandler]
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| type[BasePrompt]
| type[BaseRouterComponent]
| type[BaseAdapter]
@@ -144,10 +141,6 @@ class ComponentRegistry:
self._chatter_registry: dict[str, type[BaseChatter]] = {}
self._enabled_chatter_registry: dict[str, type[BaseChatter]] = {}
# InterestCalculator 相关
self._interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {}
self._enabled_interest_calculator_registry: dict[str, type[BaseInterestCalculator]] = {}
# Prompt 相关
self._prompt_registry: dict[str, type[BasePrompt]] = {}
self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {}
@@ -283,7 +276,6 @@ class ComponentRegistry:
ComponentType.TOOL: self._register_tool,
ComponentType.EVENT_HANDLER: self._register_event_handler,
ComponentType.CHATTER: self._register_chatter,
ComponentType.INTEREST_CALCULATOR: self._register_interest_calculator,
ComponentType.PROMPT: self._register_prompt,
ComponentType.ROUTER: self._register_router,
ComponentType.ADAPTER: self._register_adapter,
@@ -344,9 +336,6 @@ class ComponentRegistry:
case ComponentType.CHATTER:
self._chatter_registry.pop(component_name, None)
self._enabled_chatter_registry.pop(component_name, None)
case ComponentType.INTEREST_CALCULATOR:
self._interest_calculator_registry.pop(component_name, None)
self._enabled_interest_calculator_registry.pop(component_name, None)
case ComponentType.PROMPT:
self._prompt_registry.pop(component_name, None)
self._enabled_prompt_registry.pop(component_name, None)
@@ -497,25 +486,6 @@ class ComponentRegistry:
self._enabled_chatter_registry[info.name] = chatter_class
return True
def _register_interest_calculator(self, info: ComponentInfo, cls: ComponentClassType) -> bool:
"""
注册 InterestCalculator 组件到特定注册表。
Args:
info: InterestCalculator 组件的元数据信息
cls: InterestCalculator 组件的类定义
Returns:
注册成功返回 True
"""
calc_info = cast(InterestCalculatorInfo, info)
calc_class = cast(type[BaseInterestCalculator], cls)
_assign_plugin_attrs(calc_class, info.plugin_name, self.get_plugin_config(info.plugin_name) or {})
self._interest_calculator_registry[info.name] = calc_class
if calc_info.enabled:
self._enabled_interest_calculator_registry[info.name] = calc_class
return True
def _register_prompt(self, info: ComponentInfo, cls: ComponentClassType) -> bool:
"""
注册 Prompt 组件到 Prompt 特定注册表。
@@ -950,26 +920,6 @@ class ComponentRegistry:
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
return info if isinstance(info, ChatterInfo) else None
# --- InterestCalculator ---
def get_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]:
"""获取所有已注册的 InterestCalculator 类。"""
return self._interest_calculator_registry.copy()
def get_enabled_interest_calculator_registry(self) -> dict[str, type[BaseInterestCalculator]]:
"""
获取所有已启用的 InterestCalculator 类。
会检查组件的全局启用状态。
Returns:
可用的 InterestCalculator 名称到类的字典
"""
return {
name: cls
for name, cls in self._interest_calculator_registry.items()
if self.is_component_available(name, ComponentType.INTEREST_CALCULATOR)
}
# --- Prompt ---
def get_prompt_registry(self) -> dict[str, type[BasePrompt]]:
"""获取所有已注册的 Prompt 类。"""

View File

@@ -110,8 +110,6 @@ class ComponentStateManager:
)
case ComponentType.CHATTER:
self._registry._enabled_chatter_registry[component_name] = target_class # type: ignore
case ComponentType.INTEREST_CALCULATOR:
self._registry._enabled_interest_calculator_registry[component_name] = target_class # type: ignore
case ComponentType.PROMPT:
self._registry._enabled_prompt_registry[component_name] = target_class # type: ignore
case ComponentType.ADAPTER:
@@ -161,8 +159,6 @@ class ComponentStateManager:
event_manager.remove_event_handler(component_name)
case ComponentType.CHATTER:
self._registry._enabled_chatter_registry.pop(component_name, None)
case ComponentType.INTEREST_CALCULATOR:
self._registry._enabled_interest_calculator_registry.pop(component_name, None)
case ComponentType.PROMPT:
self._registry._enabled_prompt_registry.pop(component_name, None)
case ComponentType.ADAPTER:

View File

@@ -1,108 +0,0 @@
"""
兴趣系统服务
提供独立的兴趣管理功能,不依赖任何插件
"""
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
logger = get_logger("interest_service")
class InterestService:
"""兴趣系统服务 - 独立于插件的兴趣管理"""
def __init__(self):
self.is_initialized = bot_interest_manager.is_initialized
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
"""
初始化智能兴趣系统
Args:
personality_description: 机器人性格描述
personality_id: 性格ID
"""
try:
logger.info("开始初始化智能兴趣系统...")
await bot_interest_manager.initialize(personality_description, personality_id)
self.is_initialized = True
logger.info("智能兴趣系统初始化完成。")
# 显示初始化后的统计信息
stats = bot_interest_manager.get_interest_stats()
logger.debug(f"兴趣系统统计: {stats}")
except Exception as e:
logger.error(f"初始化智能兴趣系统失败: {e}")
self.is_initialized = False
async def calculate_interest_match(
self, content: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
):
"""
计算消息与兴趣的匹配度
Args:
content: 消息内容
keywords: 关键字列表
message_embedding: 已经生成的消息embedding(可选)
Returns:
匹配结果
"""
if not self.is_initialized:
logger.warning("兴趣系统未初始化,无法计算匹配度")
return None
try:
if not keywords:
# 如果没有关键字,则从内容中提取
keywords = self._extract_keywords_from_content(content)
return await bot_interest_manager.calculate_interest_match(content, keywords, message_embedding)
except Exception as e:
logger.error(f"计算兴趣匹配失败: {e}")
return None
def _extract_keywords_from_content(self, content: str) -> list[str]:
"""从内容中提取关键词"""
import re
# 清理文本
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
words = content.split()
# 过滤和关键词提取
keywords = []
for word in words:
word = word.strip()
if (
len(word) >= 2 # 至少2个字符
and word.isalnum() # 字母数字
and not word.isdigit()
): # 不是纯数字
keywords.append(word.lower())
# 去重并限制数量
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
def get_interest_stats(self) -> dict:
"""获取兴趣系统统计信息"""
if not self.is_initialized:
return {"initialized": False}
try:
return {
"initialized": True,
**bot_interest_manager.get_interest_stats()
}
except Exception as e:
logger.error(f"获取兴趣系统统计失败: {e}")
return {"initialized": True, "error": str(e)}
# 创建全局实例
interest_service = InterestService()

View File

@@ -1,15 +1,22 @@
"""AffinityFlow 风格兴趣值计算组件
基于原有的 AffinityFlow 兴趣度评分系统,提供标准化的兴趣值计算功能
集成了语义兴趣度计算TF-IDF + Logistic Regression
2024.12 优化更新:
- 使用 FastScorer 优化评分(绕过 sklearn纯 Python 字典计算)
- 支持批处理队列模式(高频群聊场景)
- 全局线程池避免重复创建 executor
- 更短的超时时间2秒
"""
import asyncio
import time
from typing import TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Any
import orjson
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator, InterestCalculationResult
@@ -36,18 +43,21 @@ class AffinityInterestCalculator(BaseInterestCalculator):
# 从配置加载评分权重
affinity_config = global_config.affinity_flow
self.score_weights = {
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
"semantic": 0.5, # 语义兴趣度权重(核心维度)
"relationship": affinity_config.relationship_weight, # 关系分权重
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
}
# 语义兴趣度评分器(替代原有的 embedding 兴趣匹配)
self.semantic_scorer = None
self.use_semantic_scoring = True # 必须启用
self._semantic_initialized = False # 防止重复初始化
self.model_manager = None
# 评分阈值
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
# 兴趣匹配系统配置
self.use_smart_matching = True
# 连续不回复概率提升
self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count
@@ -69,14 +79,17 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.post_reply_boost_max_count = affinity_config.post_reply_boost_max_count
self.post_reply_boost_decay_rate = affinity_config.post_reply_boost_decay_rate
logger.info("[Affinity兴趣计算器] 初始化完成:")
logger.info("[Affinity兴趣计算器] 初始化完成(基于语义兴趣度 TF-IDF+LR:")
logger.info(f" - 权重配置: {self.score_weights}")
logger.info(f" - 回复阈值: {self.reply_threshold}")
logger.info(f" - 智能匹配: {self.use_smart_matching}")
logger.info(f" - 语义评分: {self.use_semantic_scoring} (TF-IDF + Logistic Regression + FastScorer优化)")
logger.info(f" - 回复后连续对话: {self.enable_post_reply_boost}")
logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}")
logger.info(f" - 最大不回复计数: {self.max_no_reply_count}")
# 异步初始化语义评分器
asyncio.create_task(self._initialize_semantic_scorer())
async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
"""执行AffinityFlow风格的兴趣值计算"""
try:
@@ -93,10 +106,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
logger.debug(f"[Affinity兴趣计算] 用户ID: {user_id}")
# 1. 计算兴趣匹配
keywords = self._extract_keywords_from_database(message)
interest_match_score = await self._calculate_interest_match_score(message, content, keywords)
logger.debug(f"[Affinity兴趣计算] 兴趣匹配分: {interest_match_score}")
# 1. 计算语义兴趣度(核心维度,替代原 embedding 兴趣匹配)
semantic_score = await self._calculate_semantic_score(content)
logger.debug(f"[Affinity兴趣计算] 语义兴趣度TF-IDF+LR: {semantic_score}")
# 2. 计算关系分
relationship_score = await self._calculate_relationship_score(user_id)
@@ -108,12 +120,12 @@ class AffinityInterestCalculator(BaseInterestCalculator):
# 4. 综合评分
# 确保所有分数都是有效的 float 值
interest_match_score = float(interest_match_score) if interest_match_score is not None else 0.0
semantic_score = float(semantic_score) if semantic_score is not None else 0.0
relationship_score = float(relationship_score) if relationship_score is not None else 0.0
mentioned_score = float(mentioned_score) if mentioned_score is not None else 0.0
raw_total_score = (
interest_match_score * self.score_weights["interest_match"]
semantic_score * self.score_weights["semantic"]
+ relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"]
)
@@ -122,7 +134,8 @@ class AffinityInterestCalculator(BaseInterestCalculator):
total_score = min(raw_total_score, 1.0)
logger.debug(
f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
f"[Affinity兴趣计算] 综合得分计算: "
f"{semantic_score:.3f}*{self.score_weights['semantic']} + "
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
)
@@ -153,7 +166,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(
f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
f"(语义:{semantic_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
)
return InterestCalculationResult(
@@ -172,55 +185,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
)
async def _calculate_interest_match_score(
self, message: "DatabaseMessages", content: str, keywords: list[str] | None = None
) -> float:
"""计算兴趣匹配度(使用智能兴趣匹配系统,带超时保护)"""
# 调试日志:检查各个条件
if not content:
logger.debug("兴趣匹配返回0.0: 内容为空")
return 0.0
if not self.use_smart_matching:
logger.debug("兴趣匹配返回0.0: 智能匹配未启用")
return 0.0
if not bot_interest_manager.is_initialized:
logger.debug("兴趣匹配返回0.0: bot_interest_manager未初始化")
return 0.0
logger.debug(f"开始兴趣匹配计算,内容: {content[:50]}...")
try:
# 使用机器人的兴趣标签系统进行智能匹配5秒超时保护
match_result = await asyncio.wait_for(
bot_interest_manager.calculate_interest_match(
content, keywords or [], getattr(message, "semantic_embedding", None)
),
timeout=5.0
)
logger.debug(f"兴趣匹配结果: {match_result}")
if match_result:
# 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow
match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
# 移除兴趣匹配分数上限,允许超过1.0,最终分数会被整体限制
logger.debug(f"兴趣匹配最终得分: {final_score:.3f} (原始: {match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus:.3f})")
return final_score
else:
logger.debug("兴趣匹配返回0.0: match_result为None")
return 0.0
except asyncio.TimeoutError:
logger.warning("[超时] 兴趣匹配计算超时(>5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}")
return 0.0
async def _calculate_relationship_score(self, user_id: str) -> float:
"""计算用户关系分"""
if not user_id:
@@ -316,60 +280,204 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return adjusted_reply_threshold, adjusted_action_threshold
def _extract_keywords_from_database(self, message: "DatabaseMessages") -> list[str]:
"""从数据库消息中提取关键词"""
keywords = []
async def _initialize_semantic_scorer(self):
"""异步初始化语义兴趣度评分器(使用单例 + FastScorer优化"""
# 检查是否已初始化
if self._semantic_initialized:
logger.debug("[语义评分] 评分器已初始化,跳过")
return
if not self.use_semantic_scoring:
logger.debug("[语义评分] 未启用语义兴趣度评分")
return
# 防止并发初始化(使用锁)
if not hasattr(self, '_init_lock'):
self._init_lock = asyncio.Lock()
async with self._init_lock:
# 双重检查
if self._semantic_initialized:
logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过")
return
# 尝试从 key_words 字段提取(存储的是JSON字符串)
key_words = getattr(message, "key_words", "")
if key_words:
try:
extracted = orjson.loads(key_words)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
from src.chat.semantic_interest import get_semantic_scorer
from src.chat.semantic_interest.runtime_scorer import ModelManager
# 如果没有 keywords尝试从 key_words_lite 提取
if not keywords:
key_words_lite = getattr(message, "key_words_lite", "")
if key_words_lite:
# 查找最新的模型文件
model_dir = Path("data/semantic_interest/models")
if not model_dir.exists():
logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
model_dir.mkdir(parents=True, exist_ok=True)
# 使用模型管理器(支持人设感知)
if self.model_manager is None:
self.model_manager = ModelManager(model_dir)
logger.debug("[语义评分] 模型管理器已创建")
# 获取人设信息
persona_info = self._get_current_persona_info()
# 先检查是否已有可用模型
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
auto_trainer = get_auto_trainer()
existing_model = auto_trainer.get_model_for_persona(persona_info)
# 加载模型(自动选择合适的版本,使用单例 + FastScorer)
try:
extracted = orjson.loads(key_words_lite)
if isinstance(extracted, list):
keywords = extracted
except (orjson.JSONDecodeError, TypeError):
keywords = []
if existing_model and existing_model.exists():
# 直接加载已有模型
logger.info(f"[语义评分] 使用已有模型: {existing_model.name}")
scorer = await get_semantic_scorer(existing_model, use_async=True)
else:
# 使用 ModelManager 自动选择或训练
scorer = await self.model_manager.load_model(
version="auto", # 自动选择或训练
persona_info=persona_info
)
self.semantic_scorer = scorer
logger.info("[语义评分] 语义兴趣度评分器初始化成功FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
# 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动
if not existing_model or not existing_model.exists():
await self.model_manager.start_auto_training(
persona_info=persona_info,
interval_hours=24
)
else:
logger.debug("[语义评分] 已有模型,跳过自动训练启动")
except FileNotFoundError:
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
# 触发首次训练
trained, model_path = await auto_trainer.auto_train_if_needed(
persona_info=persona_info,
force=True # 强制训练
)
if trained and model_path:
# 使用单例获取评分器(默认启用 FastScorer)
self.semantic_scorer = await get_semantic_scorer(model_path)
logger.info("[语义评分] 首次训练完成模型已加载FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
else:
logger.error("[语义评分] 首次训练失败")
self.use_semantic_scoring = False
# 如果还是没有,从消息内容中提取(降级方案)
if not keywords:
content = getattr(message, "processed_plain_text", "") or ""
keywords = self._extract_keywords_from_content(content)
except ImportError:
logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
self.use_semantic_scoring = False
except Exception as e:
logger.error(f"[语义评分] 初始化失败: {e}")
self.use_semantic_scoring = False
return keywords[:15] # 返回前15个关键词
def _get_current_persona_info(self) -> dict[str, Any]:
"""获取当前人设信息
Returns:
人设信息字典
"""
# 默认信息(至少包含名字)
persona_info = {
"name": global_config.bot.nickname,
"interests": [],
"dislikes": [],
"personality": "",
}
def _extract_keywords_from_content(self, content: str) -> list[str]:
"""从内容中提取关键词(降级方案)"""
import re
# 优先从已生成的人设文件获取Individuality 初始化时会生成)
try:
persona_file = Path("data/personality/personality_data.json")
if persona_file.exists():
data = orjson.loads(persona_file.read_bytes())
personality_parts = [data.get("personality", ""), data.get("identity", "")]
persona_info["personality"] = "".join([p for p in personality_parts if p]).strip("")
if persona_info["personality"]:
return persona_info
except Exception as e:
logger.debug(f"[语义评分] 从文件获取人设信息失败: {e}")
# 清理文本
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
words = content.split()
# 退化为配置中的人设描述
try:
personality_parts = []
personality_core = getattr(global_config.personality, "personality_core", "")
personality_side = getattr(global_config.personality, "personality_side", "")
identity = getattr(global_config.personality, "identity", "")
# 过滤和关键词提取
keywords = []
for word in words:
word = word.strip()
if (
len(word) >= 2 # 至少2个字符
and word.isalnum() # 字母数字
and not word.isdigit()
): # 不是纯数字
keywords.append(word.lower())
if personality_core:
personality_parts.append(personality_core)
if personality_side:
personality_parts.append(personality_side)
if identity:
personality_parts.append(identity)
# 去重并限制数量
unique_keywords = list(set(keywords))
return unique_keywords[:10] # 返回前10个唯一关键词
persona_info["personality"] = "".join(personality_parts) or "默认人设"
except Exception as e:
logger.debug(f"[语义评分] 使用配置获取人设信息失败: {e}")
persona_info["personality"] = "默认人设"
return persona_info
async def _calculate_semantic_score(self, content: str) -> float:
"""计算语义兴趣度分数优化版FastScorer + 可选批处理 + 超时保护)
Args:
content: 消息文本
Returns:
语义兴趣度分数 [0.0, 1.0]
"""
# 检查是否启用
if not self.use_semantic_scoring:
return 0.0
# 检查评分器是否已加载
if not self.semantic_scorer:
return 0.0
# 检查内容是否为空
if not content or not content.strip():
return 0.0
try:
score = await self.semantic_scorer.score_async(content, timeout=2.0)
logger.debug(f"[语义评分] 内容: '{content[:50]}...' -> 分数: {score:.3f}")
return score
except Exception as e:
logger.warning(f"[语义评分] 计算失败: {e}")
return 0.0
async def reload_semantic_model(self):
"""重新加载语义兴趣度模型(支持热更新和人设检查)"""
if not self.use_semantic_scoring:
logger.info("[语义评分] 语义评分未启用,无需重载")
return
logger.info("[语义评分] 开始重新加载模型...")
# 检查人设是否变化
if hasattr(self, 'model_manager') and self.model_manager:
persona_info = self._get_current_persona_info()
reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
if reloaded:
self.semantic_scorer = self.model_manager.get_scorer()
logger.info("[语义评分] 模型重载完成(人设已更新)")
else:
logger.info("[语义评分] 人设未变化,无需重载")
else:
# 降级:简单重新初始化
self._semantic_initialized = False
await self._initialize_semantic_scorer()
logger.info("[语义评分] 模型重载完成")
def update_no_reply_count(self, replied: bool):
"""更新连续不回复计数"""
@@ -415,3 +523,5 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(
f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
)
afc_interest_calculator = AffinityInterestCalculator()
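The `FastScorer` named in the docstring and log messages is not shown in these hunks. A hedged sketch of the technique it describes, collapsing a fitted TF-IDF + LogisticRegression pipeline into one flat dict of per-n-gram weights so runtime scoring is pure Python with no sklearn import (all names hypothetical, per-message normalization ignored):

```python
import math

class FastCharNgramScorer:
    """Hypothetical sketch: score = sigmoid(bias + sum of precomputed n-gram weights).

    Offline, each character n-gram's weight would be derived from the fitted pipeline,
    e.g. something like idf[ngram] * lr_coef[ngram]; only the resulting dict ships to runtime.
    """

    def __init__(self, ngram_weights: dict[str, float], bias: float = 0.0, n: int = 2) -> None:
        self.ngram_weights = ngram_weights
        self.bias = bias
        self.n = n

    def score(self, text: str) -> float:
        ngrams = (text[i : i + self.n] for i in range(max(0, len(text) - self.n + 1)))
        z = self.bias + sum(self.ngram_weights.get(gram, 0.0) for gram in ngrams)
        return 1.0 / (1.0 + math.exp(-z))  # squash to [0, 1], like the semantic score above

# usage sketch with made-up weights
scorer = FastCharNgramScorer({"喜欢": 1.2, "打卡": -0.4}, bias=-0.5)
print(f"{scorer.score('我也很喜欢这个'):.3f}")
```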

View File

@@ -7,8 +7,6 @@ import asyncio
from dataclasses import asdict
from typing import TYPE_CHECKING, Any
from src.chat.interest_system import bot_interest_manager
from src.chat.interest_system.interest_manager import get_interest_manager
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.config.config import global_config
@@ -52,6 +50,8 @@ class ChatterActionPlanner:
self.action_manager = action_manager
self.generator = ChatterPlanGenerator(chat_id, action_manager)
self.executor = ChatterPlanExecutor(action_manager)
self._interest_calculator = None
self._interest_calculator_lock = asyncio.Lock()
# 使用新的统一兴趣度管理系统
@@ -130,60 +130,32 @@ class ChatterActionPlanner:
if not pending_messages:
return
calculator = await self._get_interest_calculator()
if not calculator:
logger.debug("未获取到兴趣计算器,跳过批量兴趣计算")
return
logger.debug(f"批量兴趣值计算:待处理 {len(pending_messages)} 条消息")
if not bot_interest_manager.is_initialized:
logger.debug("bot_interest_manager 未初始化,跳过批量兴趣计算")
return
try:
interest_manager = get_interest_manager()
except Exception as exc:
logger.warning(f"获取兴趣管理器失败: {exc}")
return
if not interest_manager or not interest_manager.has_calculator():
logger.debug("当前无可用兴趣计算器,跳过批量兴趣计算")
return
text_map: dict[str, str] = {}
for message in pending_messages:
text = getattr(message, "processed_plain_text", None) or getattr(message, "display_message", "") or ""
text_map[str(message.message_id)] = text
try:
embeddings = await bot_interest_manager.generate_embeddings_for_texts(text_map)
except Exception as exc:
logger.error(f"批量获取消息embedding失败: {exc}")
embeddings = {}
interest_updates: dict[str, float] = {}
reply_updates: dict[str, bool] = {}
for message in pending_messages:
message_id = str(message.message_id)
if message_id in embeddings:
message.semantic_embedding = embeddings[message_id]
try:
result = await interest_manager.calculate_interest(message)
result = await calculator._safe_execute(message) # 使用带统计的安全执行
except Exception as exc:
logger.error(f"批量计算消息兴趣失败: {exc}")
continue
if result.success:
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = True
message.interest_value = result.interest_value
message.should_reply = result.should_reply
message.should_act = result.should_act
message.interest_calculated = result.success
message_id = str(getattr(message, "message_id", ""))
if message_id:
interest_updates[message_id] = result.interest_value
reply_updates[message_id] = result.should_reply
# 批量处理后清理 embeddings 字典
embeddings.clear()
text_map.clear()
else:
message.interest_calculated = False
if interest_updates:
try:
@@ -191,6 +163,32 @@ class ChatterActionPlanner:
except Exception as exc:
logger.error(f"批量更新消息兴趣值失败: {exc}")
async def _get_interest_calculator(self):
"""懒加载兴趣计算器,直接使用计算器实例进行兴趣计算"""
if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False):
return self._interest_calculator
async with self._interest_calculator_lock:
if self._interest_calculator and getattr(self._interest_calculator, "is_enabled", False):
return self._interest_calculator
try:
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
afc_interest_calculator,
)
calculator = afc_interest_calculator
if not await calculator.initialize():
logger.warning("AffinityInterestCalculator 初始化失败")
return None
self._interest_calculator = calculator
logger.debug("AffinityInterestCalculator 已就绪")
return self._interest_calculator
except Exception as exc:
logger.warning(f"创建 AffinityInterestCalculator 失败: {exc}")
return None
async def _focus_mode_flow(self, context: "StreamContext | None") -> tuple[list[dict[str, Any]], Any | None]:
"""Focus模式下的完整plan流程
@@ -589,13 +587,11 @@ class ChatterActionPlanner:
replied: 是否回复了消息
"""
try:
from src.chat.interest_system.interest_manager import get_interest_manager
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
AffinityInterestCalculator,
)
interest_manager = get_interest_manager()
calculator = interest_manager.get_current_calculator()
calculator = await self._get_interest_calculator()
if calculator and isinstance(calculator, AffinityInterestCalculator):
calculator.on_message_processed(replied)

View File

@@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin):
except Exception as e:
logger.error(f"加载 AffinityChatter 时出错: {e}")
try:
# 延迟导入 AffinityInterestCalculator从 core 子模块)
from .core.affinity_interest_calculator import AffinityInterestCalculator
components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
except Exception as e:
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
try:
# 延迟导入 UserProfileTool从 tools 子模块)
from .tools.user_profile_tool import UserProfileTool

View File

@@ -158,6 +158,9 @@ class KokoroFlowChatterConfig:
# LLM 配置
llm: LLMConfig = field(default_factory=LLMConfig)
# 自定义决策提示词
custom_decision_prompt: str = ""
# 调试模式
debug: bool = False
@@ -256,6 +259,10 @@ def load_config() -> KokoroFlowChatterConfig:
timeout=getattr(llm_cfg, "timeout", 60.0),
)
# 自定义决策提示词配置
if hasattr(kfc_cfg, "custom_decision_prompt"):
config.custom_decision_prompt = str(kfc_cfg.custom_decision_prompt)
except Exception as e:
from src.common.logger import get_logger
logger = get_logger("kfc_config")

View File

@@ -235,7 +235,7 @@ class KFCContextBuilder:
search_result = await unified_manager.search_memories(
query_text=query_text,
use_judge=True,
use_judge=config.memory.use_judge,
recent_chat_history=chat_history,
)

View File

@@ -72,6 +72,9 @@ class PromptBuilder:
# 1.5. 构建安全互动准则块
safety_guidelines_block = self._build_safety_guidelines_block()
# 1.6. 构建自定义决策提示词块
custom_decision_block = self._build_custom_decision_block()
# 2. 使用 context_builder 获取关系、记忆、工具、表达习惯等
context_data = await self._build_context_data(user_name, chat_stream, user_id)
relation_block = context_data.get("relation_info", f"你与 {user_name} 还不太熟悉,这是早期的交流阶段。")
@@ -102,6 +105,7 @@ class PromptBuilder:
user_name=user_name,
persona_block=persona_block,
safety_guidelines_block=safety_guidelines_block,
custom_decision_block=custom_decision_block,
relation_block=relation_block,
memory_block=memory_block or "(暂无相关记忆)",
tool_info=tool_info or "(暂无工具信息)",
@@ -232,6 +236,23 @@ class PromptBuilder:
{guidelines_text}
如果遇到违反上述原则的请求,请在保持你核心人设的同时,以合适的方式进行回应。"""
def _build_custom_decision_block(self) -> str:
"""
构建自定义决策提示词块
从配置中读取 custom_decision_prompt,用于指导KFC的决策行为
类似于AFC的planner_custom_prompt_content
"""
from ..config import get_config
kfc_config = get_config()
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
if not custom_prompt or not custom_prompt.strip():
return ""
return custom_prompt.strip()
def _build_combined_expression_block(self, learned_habits: str) -> str:
"""
构建合并后的表达习惯块
@@ -693,8 +714,22 @@ class PromptBuilder:
# 添加真正追问次数警告(只有真正发了消息才算追问)
followup_count = extra_context.get("followup_count", 0)
if followup_count > 0:
timeout_context_parts.append(f"⚠️ 你已经连续追问了 {followup_count} 次,对方仍未回复。再追问可能会显得太急躁,请三思。")
if followup_count >= 2:
timeout_context_parts.append(
f"⚠️ **强烈建议**: 你已经连续追问了 {followup_count} 次,对方仍未回复。"
"**极度推荐选择 `do_nothing` 或主动结束话题**。"
"对方可能在忙或需要空间,不是所有人都能一直在线。给彼此一些空间会更好。"
)
elif followup_count == 1:
timeout_context_parts.append(
"📝 温馨提醒这是你第2次等待回复已追问1次"
"可以再试着追问一次,但如果对方还是没回复,**强烈建议**之后选择 `do_nothing` 或结束话题。"
)
elif followup_count == 0:
timeout_context_parts.append(
"💭 追问提示:如果对方一段时间未回复,可以适当追问一次。"
"但要记住对方可能在忙建议最多追问2次左右之后给对方一些空间。"
)
# 添加距离用户上次回复的时间
time_since_user_reply_str = extra_context.get("time_since_user_reply_str")

View File

@@ -24,6 +24,9 @@ kfc_MAIN_PROMPT = Prompt(
# 安全互动准则
{safety_guidelines_block}
# 决策指导
{custom_decision_block}
# 相关信息
{memory_block}

View File

@@ -100,6 +100,33 @@ def build_rules_module() -> str:
{kfc_rules}"""
# ============================================================
# 模块2.5: 自定义决策提示词 - 用户自定义的决策行为指导
# ============================================================
def build_custom_decision_module() -> str:
"""
构建自定义决策提示词模块
允许用户通过配置文件自定义KFC的决策行为指导
类似于AFC的planner_custom_prompt_content
"""
from .config import get_config
kfc_config = get_config()
custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
# 调试输出
logger.debug(f"[自定义决策提示词] 原始值: {repr(custom_prompt)}, 类型: {type(custom_prompt)}")
if not custom_prompt or not custom_prompt.strip():
logger.debug("[自定义决策提示词] 为空或仅含空白字符,跳过")
return ""
logger.info(f"[自定义决策提示词] 已加载: {custom_prompt.strip()}")
return custom_prompt.strip()
# ============================================================
# 模块3: 情境上下文 - 时间/场景/内在状态/关系/记忆
# ============================================================
@@ -354,6 +381,9 @@ def build_system_prompt(
# 角色框架:不提AI,直接进入角色
role_frame = "你是一个真实的人,有自己的生活、想法和情绪。像平时和朋友聊天一样自然地交流。"
# 获取自定义决策提示词
custom_decision_block = build_custom_decision_module()
# 组装各模块
modules = [
role_frame,
@@ -363,6 +393,17 @@ def build_system_prompt(
"",
"## 2. 规则",
build_rules_module(),
]
# 如果有自定义决策提示词,插入到规则后面
if custom_decision_block:
modules.extend([
"",
"## 2.5. 决策指导",
custom_decision_block,
])
modules.extend([
"",
"## 3. 现在的情况",
build_context_module(session, chat_stream, context_data),
@@ -372,7 +413,7 @@ def build_system_prompt(
"",
"## 5. 怎么回复",
build_output_module(context_data),
]
])
return "\n".join(modules)

View File

@@ -147,17 +147,35 @@ class UnifiedPromptGenerator:
# 生成连续追问警告(使用 followup_count 作为追问计数,只有真正发消息才算)
followup_count = session.waiting_config.followup_count
max_followups = 3 # 最多追问3次
max_followups = 2 # 建议最多追问2次
if followup_count >= max_followups:
followup_warning = f"""⚠️ **重要提醒**
followup_warning = f"""⚠️ **强烈建议**
你已经连续追问了 {followup_count} 次,对方都没有回复。
**强烈建议不要再发消息了**——继续追问会显得很缠人、很不尊重对方的空间
对方可能真的在忙,或者暂时不想回复,这都是正常的。
请选择 `do_nothing` 继续等待,或者直接结束对话(设置 `max_wait_seconds: 0`)。"""
elif followup_count > 0:
followup_warning = f"""📝 提示:这已经是你第 {followup_count + 1} 次等待对方回复了。
如果对方持续没有回应,可能真的在忙或不方便,不需要急着追问。"""
**极度推荐选择 `do_nothing` 或设置 `max_wait_seconds: 0` 结束这个话题**
对方很可能:
- 正在忙自己的事情,没有时间回复
- 需要一些个人空间和独处时间
- 暂时不方便或不想聊天
这些都是完全正常的。不是所有人都能一直在线,每个人都有自己的生活节奏。
继续追问很可能会让对方感到压力和不适,不如给彼此一些空间。
**最好的选择**
1. 选择 `do_nothing` 安静等待对方主动联系
2. 或者主动结束这个话题(`max_wait_seconds: 0`),让对方知道你理解他们可能在忙"""
elif followup_count == 1:
followup_warning = """📝 温馨提醒:
这是你第2次等待对方回复(已追问1次)。
可以再试着温柔地追问一次,但要做好对方可能真的在忙的心理准备。
如果这次之后对方还是没回复,**强烈建议**不要再继续追问了——
选择 `do_nothing` 给对方空间,或者主动结束话题,都是尊重对方的表现。"""
elif followup_count == 0:
followup_warning = """💭 追问提示:
如果对方一段时间没回复,可以适当追问一次,用轻松的语气提醒一下。
但要记住:不是所有人都能一直在线,对方可能在忙。
建议最多追问2次左右,之后就给对方一些空间吧。"""
else:
followup_warning = ""

View File

@@ -619,7 +619,6 @@ class SystemCommand(PlusCommand):
# 禁用保护
if not enabled:
protected_types = [
ComponentType.INTEREST_CALCULATOR,
ComponentType.PROMPT,
ComponentType.ROUTER,
]
@@ -736,7 +735,6 @@ class SystemCommand(PlusCommand):
if not enabled: # 如果是禁用操作
# 定义不可禁用的核心组件类型
protected_types = [
ComponentType.INTEREST_CALCULATOR,
ComponentType.PROMPT,
ComponentType.ROUTER,
]

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.9.8"
version = "8.0.0"
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -103,7 +103,7 @@ command_prefixes = ['/']
[personality]
# 建议50字以内描述人格的核心特质
personality_core = "是一个积极向上的女大学生"
personality_core = "是一个积极向上的女大学生"
# 人格的细节,描述人格的一些侧面
personality_side = "用一句话或几句话描述人格的侧面特质"
#アイデンティティがない 生まれないらららら
@@ -134,6 +134,8 @@ compress_identity = false # 是否压缩身份,压缩后会精简身份信息
# - "classic": 经典模式,随机抽样 + LLM选择
# - "exp_model": 表达模型模式,使用机器学习模型预测最合适的表达
mode = "classic"
# model_temperature: 机器预测模式下的“温度”:0 为贪婪,越大越爱探索(更容易选到低分表达)
model_temperature = 1.0
# expiration_days: 表达方式过期天数,超过此天数未激活的表达方式将被清理
expiration_days = 1
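For the new `model_temperature` option above, a minimal sketch of softmax-with-temperature sampling over candidate expression scores (illustrative only, not the project's selector code): temperature 0 degenerates to greedy argmax, and larger values flatten the distribution so low-scoring expressions are picked more often.

```python
import math
import random

def sample_expression(candidates: list[tuple[str, float]], temperature: float = 1.0) -> str:
    """Pick one expression id from (id, score) pairs using softmax-with-temperature."""
    if temperature <= 0:
        return max(candidates, key=lambda item: item[1])[0]            # greedy argmax
    max_score = max(score for _, score in candidates)                  # subtract max for stability
    weights = [math.exp((score - max_score) / temperature) for _, score in candidates]
    ids = [expr_id for expr_id, _ in candidates]
    return random.choices(ids, weights=weights, k=1)[0]

# usage sketch
print(sample_expression([("expr_a", 0.9), ("expr_b", 0.4)], temperature=1.0))
```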
@@ -311,6 +313,7 @@ short_term_search_top_k = 5 # 搜索时返回的最大数量
short_term_decay_factor = 0.98 # 衰减因子
# 长期记忆层配置
use_judge = true # 使用评判模型决定是否检索长期记忆
long_term_batch_size = 10 # 批量转移大小
long_term_decay_factor = 0.95 # 衰减因子
long_term_auto_transfer_interval = 180 # 自动转移间隔(秒)
@@ -425,7 +428,7 @@ auto_install = true #it can work now!
auto_install_timeout = 300
# 是否使用PyPI镜像源推荐可加速下载
use_mirror = true
mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL如: "https://pypi.tuna.tsinghua.edu.cn/simple"
mirror_url = "https://pypi.tuna.tsinghua.edu.cn/simple" # PyPI镜像源URL如: "https://pypi.tuna.tsinghua.edu.cn/simple"
# 依赖安装日志级别
install_log_level = "INFO"
@@ -536,14 +539,6 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在
# 兴趣评分系统参数
reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
high_match_interest_threshold = 0.6 # 高匹配兴趣阈值
medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值
low_match_interest_threshold = 0.2 # 低匹配兴趣阈值
high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率
medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率
low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率
match_count_bonus = 0.01 # 匹配数关键词加成值
max_match_bonus = 0.1 # 最大匹配数加成值
# 回复决策系统参数
no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值
@@ -636,6 +631,13 @@ mode = "split"
max_wait_seconds_default = 300 # 默认的最大等待秒数AI发送消息后愿意等待用户回复的时间
enable_continuous_thinking = true # 是否在等待期间启用心理活动更新
# --- 自定义决策提示词 ---
# 类似于AFC的planner_custom_prompt_content,允许用户自定义KFC的决策行为指导
# 在unified模式下,会插入到完整提示词中,影响整体思考和回复生成
# 在split模式下,只会插入到planner提示词中,影响决策规划,不影响replyer的回复生成
# 留空则不生效
custom_decision_prompt = ""
# --- 等待策略 ---
[kokoro_flow_chatter.waiting]
default_max_wait_seconds = 300 # LLM 未给出等待时间时的默认值