feat: improve semantic interest scoring and typo generation

- Implemented background prewarming for the Chinese typo generator to speed up first use.
- Updated MessageStorageBatcher to support a configurable commit batch size and interval, improving database write performance.
- Enhanced the dataset generator with a hard cap on sample size and more efficient sampling.
- Raised the maximum sample count in AutoTrainer to 1000 to make better use of training data.
- Refactored the affinity interest calculator to avoid concurrent initialization and streamline model loading.
- Introduced a batching mechanism for semantic interest scoring to handle high-frequency chat scenarios.
- Updated the configuration template to reflect the new scoring parameters and removed deprecated interest thresholds.
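The batch scoring path added by this commit can be driven roughly as in the sketch below. It is based on the benchmark script that this commit removes, so the exact get_fast_scorer signature may differ in later revisions:

import asyncio
from src.chat.semantic_interest.optimized_scorer import get_fast_scorer

async def demo(model_path: str) -> None:
    # Requests are queued and flushed either when batch_size is reached
    # or after flush_interval_ms, whichever comes first.
    queue = await get_fast_scorer(model_path, use_batch_queue=True, batch_size=8, flush_interval_ms=20.0)
    scores = await asyncio.gather(*(queue.score(t) for t in ["你好", "这个游戏太好玩了!"]))
    print(scores, queue.get_statistics())
    await queue.stop()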
.gitignore
@@ -342,3 +342,4 @@ package.json
/backup
mofox_bot_statistics.html
src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
depends-data/pinyin_dict.json
@@ -1,282 +0,0 @@
"""语义兴趣度评分器性能测试

对比测试:
1. 原始 sklearn 路径 vs FastScorer
2. 单条评分 vs 批处理
3. 同步 vs 异步
"""

import asyncio
import time
from pathlib import Path

# 测试样本
SAMPLE_TEXTS = [
    "今天天气真好",
    "这个游戏太好玩了!",
    "无聊死了",
    "我对这个话题很感兴趣",
    "能不能聊点别的",
    "哇这个真的很厉害",
    "你好",
    "有人在吗",
    "这个问题很有深度",
    "随便说说",
    "真是太棒了,我非常喜欢",
    "算了算了不想说了",
    "来聊聊最近的新闻吧",
    "emmmm",
    "哈哈哈哈",
    "666",
]


def benchmark_sklearn_scorer(model_path: str, iterations: int = 100):
    """测试原始 sklearn 评分器"""
    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer

    scorer = SemanticInterestScorer(model_path, use_fast_scorer=False)
    scorer.load()

    # 预热
    for text in SAMPLE_TEXTS[:3]:
        scorer.score(text)

    # 单条评分测试
    start = time.perf_counter()
    for _ in range(iterations):
        for text in SAMPLE_TEXTS:
            scorer.score(text)
    single_time = time.perf_counter() - start
    total_single = iterations * len(SAMPLE_TEXTS)

    # 批量评分测试
    start = time.perf_counter()
    for _ in range(iterations):
        scorer.score_batch(SAMPLE_TEXTS)
    batch_time = time.perf_counter() - start
    total_batch = iterations * len(SAMPLE_TEXTS)

    return {
        "mode": "sklearn",
        "single_total_time": single_time,
        "single_avg_ms": single_time / total_single * 1000,
        "single_qps": total_single / single_time,
        "batch_total_time": batch_time,
        "batch_avg_ms": batch_time / total_batch * 1000,
        "batch_qps": total_batch / batch_time,
    }


def benchmark_fast_scorer(model_path: str, iterations: int = 100):
    """测试 FastScorer"""
    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer

    scorer = SemanticInterestScorer(model_path, use_fast_scorer=True)
    scorer.load()

    # 预热
    for text in SAMPLE_TEXTS[:3]:
        scorer.score(text)

    # 单条评分测试
    start = time.perf_counter()
    for _ in range(iterations):
        for text in SAMPLE_TEXTS:
            scorer.score(text)
    single_time = time.perf_counter() - start
    total_single = iterations * len(SAMPLE_TEXTS)

    # 批量评分测试
    start = time.perf_counter()
    for _ in range(iterations):
        scorer.score_batch(SAMPLE_TEXTS)
    batch_time = time.perf_counter() - start
    total_batch = iterations * len(SAMPLE_TEXTS)

    return {
        "mode": "fast_scorer",
        "single_total_time": single_time,
        "single_avg_ms": single_time / total_single * 1000,
        "single_qps": total_single / single_time,
        "batch_total_time": batch_time,
        "batch_avg_ms": batch_time / total_batch * 1000,
        "batch_qps": total_batch / batch_time,
    }


async def benchmark_async_scoring(model_path: str, iterations: int = 100):
    """测试异步评分"""
    from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer

    scorer = await get_semantic_scorer(model_path, use_async=True)

    # 预热
    for text in SAMPLE_TEXTS[:3]:
        await scorer.score_async(text)

    # 单条异步评分
    start = time.perf_counter()
    for _ in range(iterations):
        for text in SAMPLE_TEXTS:
            await scorer.score_async(text)
    single_time = time.perf_counter() - start
    total_single = iterations * len(SAMPLE_TEXTS)

    # 并发评分(模拟高并发场景)
    start = time.perf_counter()
    for _ in range(iterations):
        tasks = [scorer.score_async(text) for text in SAMPLE_TEXTS]
        await asyncio.gather(*tasks)
    concurrent_time = time.perf_counter() - start
    total_concurrent = iterations * len(SAMPLE_TEXTS)

    return {
        "mode": "async",
        "single_total_time": single_time,
        "single_avg_ms": single_time / total_single * 1000,
        "single_qps": total_single / single_time,
        "concurrent_total_time": concurrent_time,
        "concurrent_avg_ms": concurrent_time / total_concurrent * 1000,
        "concurrent_qps": total_concurrent / concurrent_time,
    }


async def benchmark_batch_queue(model_path: str, iterations: int = 100):
    """测试批处理队列"""
    from src.chat.semantic_interest.optimized_scorer import get_fast_scorer

    queue = await get_fast_scorer(
        model_path,
        use_batch_queue=True,
        batch_size=8,
        flush_interval_ms=20.0
    )

    # 预热
    for text in SAMPLE_TEXTS[:3]:
        await queue.score(text)

    # 并发提交评分请求
    start = time.perf_counter()
    for _ in range(iterations):
        tasks = [queue.score(text) for text in SAMPLE_TEXTS]
        await asyncio.gather(*tasks)
    total_time = time.perf_counter() - start
    total_requests = iterations * len(SAMPLE_TEXTS)

    stats = queue.get_statistics()

    await queue.stop()

    return {
        "mode": "batch_queue",
        "total_time": total_time,
        "avg_ms": total_time / total_requests * 1000,
        "qps": total_requests / total_time,
        "total_batches": stats["total_batches"],
        "avg_batch_size": stats["avg_batch_size"],
    }


def print_results(results: dict):
    """打印测试结果"""
    print(f"\n{'='*60}")
    print(f"模式: {results['mode']}")
    print(f"{'='*60}")

    if "single_avg_ms" in results:
        print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}")

    if "batch_avg_ms" in results:
        print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}")

    if "concurrent_avg_ms" in results:
        print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}")

    if "total_batches" in results:
        print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}")
        print(f"  总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}")


async def main():
    """运行性能测试"""
    import sys

    # 检查模型路径
    model_dir = Path("data/semantic_interest/models")
    model_files = list(model_dir.glob("semantic_interest_*.pkl"))

    if not model_files:
        print("错误: 未找到模型文件,请先训练模型")
        print(f"模型目录: {model_dir}")
        sys.exit(1)

    # 使用最新的模型
    model_path = str(max(model_files, key=lambda p: p.stat().st_mtime))
    print(f"使用模型: {model_path}")

    iterations = 50  # 测试迭代次数

    print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次")
    print(f"总评分次数: {iterations * len(SAMPLE_TEXTS)} 条")

    # 1. sklearn 原始路径
    print("\n[1/4] 测试 sklearn 原始路径...")
    try:
        sklearn_results = benchmark_sklearn_scorer(model_path, iterations)
        print_results(sklearn_results)
    except Exception as e:
        print(f"sklearn 测试失败: {e}")

    # 2. FastScorer
    print("\n[2/4] 测试 FastScorer...")
    try:
        fast_results = benchmark_fast_scorer(model_path, iterations)
        print_results(fast_results)
    except Exception as e:
        print(f"FastScorer 测试失败: {e}")

    # 3. 异步评分
    print("\n[3/4] 测试异步评分...")
    try:
        async_results = await benchmark_async_scoring(model_path, iterations)
        print_results(async_results)
    except Exception as e:
        print(f"异步测试失败: {e}")

    # 4. 批处理队列
    print("\n[4/4] 测试批处理队列...")
    try:
        queue_results = await benchmark_batch_queue(model_path, iterations)
        print_results(queue_results)
    except Exception as e:
        print(f"批处理队列测试失败: {e}")

    # 性能对比总结
    print(f"\n{'='*60}")
    print("性能对比总结")
    print(f"{'='*60}")

    try:
        speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"]
        print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速")

        speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"]
        print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速")
    except:
        pass

    print("\n清理资源...")
    from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances
    from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances

    shutdown_global_executor()
    clear_fast_scorer_instances()
    clear_scorer_instances()

    print("测试完成!")


if __name__ == "__main__":
    asyncio.run(main())
bot.py
@@ -567,6 +567,7 @@ class MaiBotMain:

    def __init__(self):
        self.main_system = None
        self._typo_prewarm_task = None

    def setup_timezone(self):
        """设置时区"""

@@ -663,6 +664,25 @@ class MaiBotMain:

    async def run_async_init(self, main_system):
        """执行异步初始化步骤"""

        # 后台预热中文错别字生成器,避免首次使用阻塞主流程
        try:
            from src.chat.utils.typo_generator import get_typo_generator

            typo_cfg = getattr(global_config, "chinese_typo", None)
            self._typo_prewarm_task = asyncio.create_task(
                asyncio.to_thread(
                    get_typo_generator,
                    error_rate=getattr(typo_cfg, "error_rate", 0.3),
                    min_freq=getattr(typo_cfg, "min_freq", 5),
                    tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2),
                    word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3),
                    max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200),
                )
            )
            logger.debug("已启动 ChineseTypoGenerator 后台预热任务")
        except Exception as e:
            logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}")

        # 初始化数据库表结构
        await self.initialize_database_async()
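The prewarm above simply schedules a blocking constructor on a worker thread so the event loop keeps serving during startup. A minimal standalone sketch of the same pattern, with a hypothetical slow_init standing in for get_typo_generator:

import asyncio

def slow_init() -> str:
    # placeholder for an expensive, blocking initialisation such as building a pinyin dict
    return "ready"

async def main() -> None:
    task = asyncio.create_task(asyncio.to_thread(slow_init))  # runs off the event loop
    # ... other async startup work continues here ...
    print(await task)  # the first consumer simply awaits the already-warm result

asyncio.run(main())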
@@ -3,10 +3,10 @@ import re
import time
import traceback
from collections import deque
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Any, cast

import orjson
from sqlalchemy import desc, select, update
from sqlalchemy import desc, insert, select, update
from sqlalchemy.engine import CursorResult

from src.common.data_models.database_data_model import DatabaseMessages

@@ -25,29 +25,55 @@ class MessageStorageBatcher:
    消息存储批处理器

    优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
    2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。
    """

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
    def __init__(
        self,
        batch_size: int = 50,
        flush_interval: float = 5.0,
        *,
        commit_batch_size: int | None = None,
        commit_interval: float | None = None,
        db_chunk_size: int = 200,
    ):
        """
        初始化批处理器

        Args:
            batch_size: 批量大小,达到此数量立即写入
            flush_interval: 自动刷新间隔(秒)
            batch_size: 写入队列中触发准备阶段的消息条数
            flush_interval: 自动刷新/检查间隔(秒)
            commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size,至少100)
            commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s))
            db_chunk_size: 单次SQL语句批量写入数量上限
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100)
        self.commit_interval = commit_interval or max(flush_interval * 2, 10.0)
        self.db_chunk_size = max(50, db_chunk_size)

        self.pending_messages: deque = deque()
        self._prepared_buffer: list[dict[str, Any]] = []
        self._lock = asyncio.Lock()
        self._flush_barrier = asyncio.Lock()
        self._flush_task = None
        self._running = False
        self._last_commit_ts = time.monotonic()

    async def start(self):
        """启动自动刷新任务"""
        if self._flush_task is None and not self._running:
            self._running = True
            self._last_commit_ts = time.monotonic()
            self._flush_task = asyncio.create_task(self._auto_flush_loop())
            logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
            logger.info(
                "消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)",
                self.batch_size,
                self.flush_interval,
                self.commit_batch_size,
                self.commit_interval,
            )

    async def stop(self):
        """停止批处理器"""
@@ -62,7 +88,7 @@ class MessageStorageBatcher:
            self._flush_task = None

        # 刷新剩余的消息
        await self.flush()
        await self.flush(force=True)
        logger.info("消息存储批处理器已停止")

    async def add_message(self, message_data: dict):
@@ -76,61 +102,82 @@ class MessageStorageBatcher:
                'chat_stream': ChatStream
            }
        """
        should_force_flush = False
        async with self._lock:
            self.pending_messages.append(message_data)

            # 如果达到批量大小,立即刷新
            if len(self.pending_messages) >= self.batch_size:
                logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
                await self.flush()
                should_force_flush = True

    async def flush(self):
        """执行批量写入"""
        async with self._lock:
            if not self.pending_messages:
                return
        if should_force_flush:
            logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新")
            await self.flush(force=True)

            messages_to_store = list(self.pending_messages)
            self.pending_messages.clear()
    async def flush(self, force: bool = False):
        """执行批量写入, 支持强制落库和延迟提交策略。"""
        async with self._flush_barrier:
            async with self._lock:
                messages_to_store = list(self.pending_messages)
                self.pending_messages.clear()

        if not messages_to_store:
            if messages_to_store:
                prepared_messages: list[dict[str, Any]] = []
                for msg_data in messages_to_store:
                    try:
                        message_dict = await self._prepare_message_dict(
                            msg_data["message"],
                            msg_data["chat_stream"],
                        )
                        if message_dict:
                            prepared_messages.append(message_dict)
                    except Exception as e:
                        logger.error(f"准备消息数据失败: {e}")

                if prepared_messages:
                    self._prepared_buffer.extend(prepared_messages)

            await self._maybe_commit_buffer(force=force)

    async def _maybe_commit_buffer(self, *, force: bool = False) -> None:
        """根据阈值/时间窗口判断是否需要真正写库。"""
        if not self._prepared_buffer:
            return

        now = time.monotonic()
        enough_rows = len(self._prepared_buffer) >= self.commit_batch_size
        waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval

        if not (force or enough_rows or waited_long_enough):
            return

        await self._write_buffer_to_database()

    async def _write_buffer_to_database(self) -> None:
        payload = self._prepared_buffer
        if not payload:
            return

        self._prepared_buffer = []
        start_time = time.time()
        success_count = 0
        total = len(payload)

        try:
            # 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
            messages_dicts = []

            for msg_data in messages_to_store:
                try:
                    message_dict = await self._prepare_message_dict(
                        msg_data["message"],
                        msg_data["chat_stream"]
                    )
                    if message_dict:
                        messages_dicts.append(message_dict)
                except Exception as e:
                    logger.error(f"准备消息数据失败: {e}")
                    continue

            # 批量写入数据库 - 使用高效的批量INSERT
            if messages_dicts:
                from sqlalchemy import insert
                async with get_db_session() as session:
                    stmt = insert(Messages).values(messages_dicts)
                    await session.execute(stmt)
                    await session.commit()
                    success_count = len(messages_dicts)
            async with get_db_session() as session:
                for start in range(0, total, self.db_chunk_size):
                    chunk = payload[start : start + self.db_chunk_size]
                    if chunk:
                        await session.execute(insert(Messages), chunk)
                await session.commit()

            elapsed = time.time() - start_time
            self._last_commit_ts = time.monotonic()
            per_item = (elapsed / total) * 1000 if total else 0
            logger.info(
                f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
                f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
                f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})"
            )

        except Exception as e:
            # 回滚到缓冲区, 等待下一次尝试
            self._prepared_buffer = payload + self._prepared_buffer
            logger.error(f"批量存储消息失败: {e}")

    async def _prepare_message_dict(self, message, chat_stream):
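The key change above is the second-level buffer: prepared rows are committed only when enough have accumulated, enough time has passed, or a forced flush is requested. A self-contained sketch of that decision rule (names are illustrative, not the project's API):

import time

def should_commit(buffered_rows: int, last_commit_ts: float, *, force: bool = False,
                  commit_batch_size: int = 100, commit_interval: float = 10.0) -> bool:
    # Mirrors _maybe_commit_buffer: commit when forced, when the buffer is large enough,
    # or when the last commit is older than commit_interval seconds.
    if buffered_rows == 0:
        return False
    enough_rows = buffered_rows >= commit_batch_size
    waited_long_enough = (time.monotonic() - last_commit_ts) >= commit_interval
    return force or enough_rows or waited_long_enough

print(should_commit(120, time.monotonic()))        # True: over the row threshold
print(should_commit(10, time.monotonic() - 30.0))  # True: waited past the interval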
@@ -116,6 +116,10 @@ class AutoTrainer:
            "interests": sorted(persona_info.get("interests", [])),
            "dislikes": sorted(persona_info.get("dislikes", [])),
            "personality": persona_info.get("personality", ""),
            # 可选的更完整人设字段(存在则纳入哈希)
            "personality_core": persona_info.get("personality_core", ""),
            "personality_side": persona_info.get("personality_side", ""),
            "identity": persona_info.get("identity", ""),
        }

        # 转为JSON并计算哈希
@@ -178,7 +182,7 @@ class AutoTrainer:
        self,
        persona_info: dict[str, Any],
        days: int = 7,
        max_samples: int = 500,
        max_samples: int = 1000,
        force: bool = False,
    ) -> tuple[bool, Path | None]:
        """自动训练(如果需要)
@@ -186,7 +190,7 @@ class AutoTrainer:
        Args:
            persona_info: 人设信息
            days: 采样天数
            max_samples: 最大采样数
            max_samples: 最大采样数(默认1000条)
            force: 强制训练

        Returns:
@@ -279,11 +283,12 @@ class AutoTrainer:
        """
        # 检查是否已经有任务在运行
        if self._scheduled_task_running:
            logger.debug(f"[自动训练器] 定时任务已在运行,跳过")
            logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
            return

        self._scheduled_task_running = True
        logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
        logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")

        while True:
            try:
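The persona hash now covers the richer persona fields shown above. A minimal sketch of how such a fingerprint might be computed, assuming SHA-256 over canonical JSON (the project's exact hash function is not shown in this diff):

import hashlib
import json

def persona_hash(persona_info: dict) -> str:
    # Canonicalise the fields that influence training, then hash them.
    key_data = {
        "interests": sorted(persona_info.get("interests", [])),
        "dislikes": sorted(persona_info.get("dislikes", [])),
        "personality": persona_info.get("personality", ""),
        "personality_core": persona_info.get("personality_core", ""),
        "personality_side": persona_info.get("personality_side", ""),
        "identity": persona_info.get("identity", ""),
    }
    payload = json.dumps(key_data, ensure_ascii=False, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

print(persona_hash({"interests": ["游戏", "编程"], "personality": "活泼开朗"})[:8])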
@@ -22,6 +22,9 @@ class DatasetGenerator:
    从历史消息中采样并使用 LLM 进行标注
    """

    # 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
    HARD_MAX_SAMPLES = 2000

    # 标注提示词模板(单条)
    ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。

@@ -107,7 +110,7 @@ class DatasetGenerator:
        max_samples: int = 1000,
        priority_ranges: list[tuple[float, float]] | None = None,
    ) -> list[dict[str, Any]]:
        """从数据库采样消息
        """从数据库采样消息(优化版:减少查询量和内存使用)

        Args:
            days: 采样最近 N 天的消息
@@ -120,40 +123,75 @@ class DatasetGenerator:
        """
        from src.common.database.api.query import QueryBuilder
        from src.common.database.core.models import Messages
        from sqlalchemy import func, or_

        logger.info(f"开始采样消息,时间范围: 最近 {days} 天")
        logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")

        # 限制采样数量硬上限
        requested_max_samples = max_samples
        if max_samples is None:
            max_samples = self.HARD_MAX_SAMPLES
        else:
            max_samples = int(max_samples)
            if max_samples <= 0:
                logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
                return []
            if max_samples > self.HARD_MAX_SAMPLES:
                logger.warning(
                    f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES},"
                    f"已截断为 {self.HARD_MAX_SAMPLES}"
                )
                max_samples = self.HARD_MAX_SAMPLES

        # 查询条件
        cutoff_time = datetime.now() - timedelta(days=days)
        cutoff_ts = cutoff_time.timestamp()

        # 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
        # 这样可以在保证足够样本的同时减少查询量
        prefetch_limit = int(max_samples * 1.5)

        # 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
        query_builder = QueryBuilder(Messages)

        # 获取所有符合条件的消息(使用 as_dict 方便访问字段)

        # 过滤条件:时间范围 + 消息文本不为空
        messages = await query_builder.filter(
            time__gte=cutoff_ts,
        ).order_by(
            "-time"  # 按时间倒序,优先采样最新消息
        ).limit(
            prefetch_limit  # 限制预取数量
        ).all(as_dict=True)

        logger.info(f"查询到 {len(messages)} 条消息")
        logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})")

        # 过滤消息长度
        # 过滤消息长度和提取文本
        filtered = []
        for msg in messages:
            text = msg.get("processed_plain_text") or msg.get("display_message") or ""
            if text and len(text.strip()) >= min_length:
            text = text.strip()
            if text and len(text) >= min_length:
                filtered.append({**msg, "message_text": text})
                # 达到目标数量即可停止
                if len(filtered) >= max_samples:
                    break

        logger.info(f"过滤后剩余 {len(filtered)} 条消息")
        logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})")

        # 优先采样策略
        if priority_ranges and len(filtered) > max_samples:
            # 随机采样
            samples = random.sample(filtered, max_samples)
        else:
            samples = filtered[:max_samples]
        # 如果过滤后数量不足,记录警告
        if len(filtered) < max_samples:
            logger.warning(
                f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples}),"
                f"可能需要扩大采样范围(增加 days 参数或降低 min_length)"
            )

        # 转换为字典格式
        # 随机打乱样本顺序(避免时间偏向)
        if len(filtered) > 0:
            random.shuffle(filtered)

        # 转换为标准格式
        result = []
        for msg in samples:
        for msg in filtered:
            result.append({
                "message_id": msg.get("message_id"),
                "user_id": msg.get("user_id"),
@@ -335,19 +373,50 @@ class DatasetGenerator:
        Returns:
            格式化后的人格描述
        """
        parts = []
        def _stringify(value: Any) -> str:
            if value is None:
                return ""
            if isinstance(value, (list, tuple, set)):
                return "、".join([str(v) for v in value if v is not None and str(v).strip()])
            if isinstance(value, dict):
                try:
                    return json.dumps(value, ensure_ascii=False, sort_keys=True)
                except Exception:
                    return str(value)
            return str(value).strip()

        if "name" in persona_info:
            parts.append(f"角色名称: {persona_info['name']}")
        parts: list[str] = []

        if "interests" in persona_info:
            parts.append(f"兴趣点: {', '.join(persona_info['interests'])}")
        name = _stringify(persona_info.get("name"))
        if name:
            parts.append(f"角色名称: {name}")

        if "dislikes" in persona_info:
            parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}")
        # 核心/侧面/身份等完整人设信息
        personality_core = _stringify(persona_info.get("personality_core"))
        if personality_core:
            parts.append(f"核心人设: {personality_core}")

        if "personality" in persona_info:
            parts.append(f"性格特点: {persona_info['personality']}")
        personality_side = _stringify(persona_info.get("personality_side"))
        if personality_side:
            parts.append(f"侧面特质: {personality_side}")

        identity = _stringify(persona_info.get("identity"))
        if identity:
            parts.append(f"身份特征: {identity}")

        # 追加其他未覆盖字段(保持信息完整)
        known_keys = {
            "name",
            "personality_core",
            "personality_side",
            "identity",
        }
        for key, value in persona_info.items():
            if key in known_keys:
                continue
            value_str = _stringify(value)
            if value_str:
                parts.append(f"{key}: {value_str}")

        return "\n".join(parts) if parts else "无特定人格设定"
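The sampling change prefetches roughly 1.5x the requested number of rows, filters short texts, and stops as soon as the target is reached. A self-contained sketch of that flow over an in-memory list (names here are illustrative, not the project's API):

import random

HARD_MAX_SAMPLES = 2000

def sample_texts(rows: list[str], max_samples: int = 1000, min_length: int = 2) -> list[str]:
    max_samples = min(int(max_samples), HARD_MAX_SAMPLES)   # hard cap
    prefetch = rows[: int(max_samples * 1.5)]                # prefetch 1.5x the target
    filtered = []
    for text in prefetch:
        text = text.strip()
        if text and len(text) >= min_length:
            filtered.append(text)
            if len(filtered) >= max_samples:                 # stop once the target is met
                break
    random.shuffle(filtered)                                 # avoid recency bias
    return filtered

print(sample_texts(["你好", " ", "这个游戏太好玩了!", "666"], max_samples=2))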
@@ -26,7 +26,7 @@ class TfidfFeatureExtractor:
    def __init__(
        self,
        analyzer: str = "char",  # type: ignore
        ngram_range: tuple[int, int] = (2, 3),  # 优化:缩小 n-gram 范围
        ngram_range: tuple[int, int] = (2, 4),  # 优化:缩小 n-gram 范围
        max_features: int = 10000,  # 优化:减少特征数量,矩阵大小和 dot product 减半
        min_df: int = 3,  # 优化:过滤低频 n-gram
        max_df: float = 0.95,

@@ -44,7 +44,6 @@ class SemanticInterestModel:
            n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
        """
        self.clf = LogisticRegression(
            multi_class="multinomial",
            solver=solver,
            max_iter=max_iter,
            class_weight=class_weight,
@@ -206,7 +205,6 @@ class SemanticInterestModel:
        """
        params = self.clf.get_params()
        return {
            "multi_class": params["multi_class"],
            "solver": params["solver"],
            "max_iter": params["max_iter"],
            "class_weight": params["class_weight"],
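A minimal standalone sketch of the feature/classifier configuration these hunks converge on, using plain scikit-learn (illustrative values; the project wraps this in TfidfFeatureExtractor and SemanticInterestModel):

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# Character 2-4 grams work well for Chinese chat text, where word segmentation is unreliable.
vectorizer = TfidfVectorizer(analyzer="char", ngram_range=(2, 4), max_features=10000, min_df=3, max_df=0.95)
# multi_class is no longer passed: newer scikit-learn versions deprecate it and pick
# multinomial behaviour automatically for lbfgs-style solvers.
clf = LogisticRegression(solver="lbfgs", max_iter=1000, class_weight="balanced", n_jobs=-1)

docs = ["这个游戏太好玩了", "这个游戏好无聊", "这个游戏真不错", "今天天气真好"]
X = vectorizer.fit_transform(docs)
clf.fit(X, [1, 0, 1, 0])
print(clf.predict_proba(vectorizer.transform(["来聊聊这个游戏吧"])))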
@@ -558,7 +558,7 @@ class ModelManager:
            trained, model_path = await self._auto_trainer.auto_train_if_needed(
                persona_info=persona_info,
                days=7,
                max_samples=500,
                max_samples=1000,  # 初始训练使用1000条消息
            )

            if trained and model_path:
@@ -607,30 +607,32 @@ class ModelManager:
            persona_info: 人设信息
            interval_hours: 检查间隔(小时)
        """
        # 检查是否已经启动
        if self._auto_training_started:
            logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
            return

        try:
            from src.chat.semantic_interest.auto_trainer import get_auto_trainer
        # 使用锁防止并发启动
        async with self._lock:
            # 检查是否已经启动
            if self._auto_training_started:
                logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
                return

            if self._auto_trainer is None:
                self._auto_trainer = get_auto_trainer()

            logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")

            # 标记为已启动
            self._auto_training_started = True

            # 在后台任务中运行
            asyncio.create_task(
                self._auto_trainer.scheduled_train(persona_info, interval_hours)
            )

        except Exception as e:
            logger.error(f"[模型管理器] 启动自动训练失败: {e}")
            self._auto_training_started = False  # 失败时重置标志
            try:
                from src.chat.semantic_interest.auto_trainer import get_auto_trainer

                if self._auto_trainer is None:
                    self._auto_trainer = get_auto_trainer()

                logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")

                # 标记为已启动
                self._auto_training_started = True

                # 在后台任务中运行
                asyncio.create_task(
                    self._auto_trainer.scheduled_train(persona_info, interval_hours)
                )

            except Exception as e:
                logger.error(f"[模型管理器] 启动自动训练失败: {e}")
                self._auto_training_started = False  # 失败时重置标志


# 单例获取函数
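The start logic now takes an asyncio.Lock before checking the started flag, so two concurrent callers cannot both spawn the background task. A self-contained sketch of that start-once pattern (illustrative names):

import asyncio

class OnceStarter:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._started = False

    async def start(self, name: str) -> None:
        async with self._lock:          # serialise concurrent start() calls
            if self._started:
                return                  # the second caller sees the flag and bails out
            self._started = True
            asyncio.create_task(self._run(name))

    async def _run(self, name: str) -> None:
        print(f"{name} background task running")
        await asyncio.sleep(0)

async def main() -> None:
    s = OnceStarter()
    await asyncio.gather(s.start("a"), s.start("b"))  # only one task is created
    await asyncio.sleep(0.01)

asyncio.run(main())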
@@ -191,44 +191,3 @@ class SemanticInterestTrainer:

        return dataset_path, model_path, metrics


async def main():
    """示例:训练一个语义兴趣度模型"""

    # 示例人格信息
    persona_info = {
        "name": "小狐",
        "interests": ["动漫", "游戏", "编程", "技术", "二次元"],
        "dislikes": ["广告", "政治", "无聊闲聊"],
        "personality": "活泼开朗,对新鲜事物充满好奇",
    }

    # 创建训练器
    trainer = SemanticInterestTrainer()

    # 执行完整训练流程
    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
        persona_info=persona_info,
        days=7,  # 使用最近 7 天的消息
        max_samples=500,  # 采样 500 条消息
        llm_model_name=None,  # 使用默认 LLM
        tfidf_config={
            "analyzer": "char",
            "ngram_range": (2, 4),
            "max_features": 15000,
            "min_df": 3,
        },
        model_config={
            "class_weight": "balanced",
            "max_iter": 1000,
        },
    )

    print(f"\n训练完成!")
    print(f"数据集: {dataset_path}")
    print(f"模型: {model_path}")
    print(f"准确率: {metrics.get('test_accuracy', 0):.4f}")


if __name__ == "__main__":
    asyncio.run(main())
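With this example entry point removed, the full pipeline can still be driven manually. A rough sketch based on the deleted main() above, bumped to the new 1000-sample default (the import path is assumed from context, and this only runs inside the project with its database and LLM configured):

import asyncio
from src.chat.semantic_interest.trainer import SemanticInterestTrainer  # module path assumed, not confirmed by this diff

async def train_once() -> None:
    trainer = SemanticInterestTrainer()
    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
        persona_info={"name": "小狐", "interests": ["游戏", "编程"], "dislikes": ["广告"]},
        days=7,
        max_samples=1000,  # matches the new AutoTrainer default
        llm_model_name=None,
        tfidf_config={"analyzer": "char", "ngram_range": (2, 4), "max_features": 15000, "min_df": 3},
        model_config={"class_weight": "balanced", "max_iter": 1000},
    )
    print(model_path, metrics.get("test_accuracy", 0))

asyncio.run(train_once())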
@@ -96,7 +96,7 @@ class ChineseTypoGenerator:

        # 🔧 内存优化:复用全局缓存的拼音字典和字频数据
        if _shared_pinyin_dict is None:
            _shared_pinyin_dict = self._create_pinyin_dict()
            _shared_pinyin_dict = self._load_or_create_pinyin_dict()
            logger.debug("拼音字典已创建并缓存")
        self.pinyin_dict = _shared_pinyin_dict

@@ -141,6 +141,35 @@ class ChineseTypoGenerator:

        return normalized_freq

    def _load_or_create_pinyin_dict(self):
        """
        加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动)
        """
        cache_file = Path("depends-data/pinyin_dict.json")

        if cache_file.exists():
            try:
                with open(cache_file, encoding="utf-8") as f:
                    data = orjson.loads(f.read())
                # 恢复为 defaultdict(list) 以兼容旧逻辑
                restored = defaultdict(list)
                for py, chars in data.items():
                    restored[py] = list(chars)
                return restored
            except Exception as e:
                logger.warning(f"读取拼音缓存失败,将重新生成: {e}")

        pinyin_dict = self._create_pinyin_dict()

        try:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_file, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8"))
        except Exception as e:
            logger.warning(f"写入拼音缓存失败(不影响使用): {e}")

        return pinyin_dict

    @staticmethod
    def _create_pinyin_dict():
        """
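This is a plain load-or-rebuild disk cache. The same pattern in miniature, using the standard json module so it runs anywhere (the project uses orjson and a defaultdict, as shown above):

import json
from pathlib import Path

def load_or_create(cache_file: Path, build) -> dict:
    if cache_file.exists():
        try:
            return json.loads(cache_file.read_text(encoding="utf-8"))  # fast path: reuse the cache
        except Exception:
            pass  # fall through and rebuild if the cache is corrupt
    data = build()
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return data

print(load_or_create(Path("/tmp/pinyin_demo.json"), lambda: {"ni": ["你", "尼"]}))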
@@ -1,6 +1,7 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制

import logging
import os
import tarfile
import threading
import time
@@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler):
        self.backup_count = backup_count
        self.encoding = encoding
        self._lock = threading.Lock()
        self._current_size = 0
        self._bytes_since_check = 0
        self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8"))
        self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024)

        # 当前活跃的日志文件
        self.current_file = None
@@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler):
                # 极低概率碰撞,稍作等待
                time.sleep(0.001)
        self.current_stream = open(self.current_file, "a", encoding=self.encoding)
        self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0
        self._bytes_since_check = 0

    def _should_rollover(self):
        """检查是否需要轮转"""
        if self.current_file and self.current_file.exists():
            return self.current_file.stat().st_size >= self.max_bytes
    def _should_rollover(self, incoming_size: int = 0) -> bool:
        """检查是否需要轮转,使用内存缓存的大小信息减少磁盘stat次数。"""
        if not self.current_file:
            return False

        projected = self._current_size + incoming_size
        if projected >= self.max_bytes:
            return True

        self._bytes_since_check += incoming_size
        if self._bytes_since_check >= self._stat_refresh_threshold:
            try:
                if self.current_file.exists():
                    self._current_size = self.current_file.stat().st_size
                else:
                    self._current_size = 0
            except OSError:
                self._current_size = 0
            finally:
                self._bytes_since_check = 0
        return False

    def _do_rollover(self):
@@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler):
    def emit(self, record):
        """发出日志记录"""
        try:
            message = self.format(record)
            encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes

            with self._lock:
                # 检查是否需要轮转
                if self._should_rollover():
                if self._should_rollover(encoded_len):
                    self._do_rollover()

                # 写入日志
                if self.current_stream:
                    msg = self.format(record)
                    self.current_stream.write(msg + "\n")
                    self.current_stream.write(message + "\n")
                    self.current_stream.flush()
                    self._current_size += encoded_len

        except Exception:
            self.handleError(record)
@@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = {
}


# 创建全局 Rich Console 实例用于颜色渲染
_rich_console = Console(force_terminal=True, color_system="truecolor")


class ModuleColoredConsoleRenderer:
    """自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色"""

@@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer:
        # sourcery skip: merge-duplicate-blocks, remove-redundant-if
        self._colors = colors
        self._config = LOG_CONFIG
        self._render_console = Console(force_terminal=True, color_system="truecolor", width=999)

        # 日志级别颜色 (#RRGGBB 格式)
        self._level_colors_hex = {
@@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer:
            self._enable_level_colors = False
            self._enable_full_content_colors = False

    @staticmethod
    def _looks_like_markup(content: str) -> bool:
        """快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。"""
        if not content:
            return False
        return "[" in content and "]" in content

    def _render_content_text(self, content: str, *, style: str | None = None) -> Text:
        """只在必要时解析 Rich 标记,降低CPU占用。"""
        if self._looks_like_markup(content):
            try:
                return Text.from_markup(content, style=style)
            except Exception:
                return Text(content, style=style)
        return Text(content, style=style)

    def __call__(self, logger, method_name, event_dict):
        # sourcery skip: merge-duplicate-blocks
        """渲染日志消息"""
@@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer:
                if prefix:
                    # 解析 prefix 中的 Rich 标记
                    if module_hex_color:
                        content_text.append(Text.from_markup(prefix, style=module_hex_color))
                        content_text.append(self._render_content_text(prefix, style=module_hex_color))
                    else:
                        content_text.append(Text.from_markup(prefix))
                        content_text.append(self._render_content_text(prefix))

                # 与"内心思考"段落之间插入空行
                if prefix:
@@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer:
            else:
                # 使用 Text.from_markup 解析 Rich 标记语言
                if module_hex_color:
                    try:
                        parts.append(Text.from_markup(event_content, style=module_hex_color))
                    except Exception:
                        # 如果标记解析失败,回退到普通文本
                        parts.append(Text(event_content, style=module_hex_color))
                    parts.append(self._render_content_text(event_content, style=module_hex_color))
                else:
                    try:
                        parts.append(Text.from_markup(event_content))
                    except Exception:
                        # 如果标记解析失败,回退到普通文本
                        parts.append(Text(event_content))
                    parts.append(self._render_content_text(event_content))
        else:
            # 即使在非 full 模式下,也尝试解析 Rich 标记(但不应用颜色)
            try:
                parts.append(Text.from_markup(event_content))
            except Exception:
                # 如果标记解析失败,使用普通文本
                parts.append(Text(event_content))
            parts.append(self._render_content_text(event_content))

        # 处理其他字段
        extras = []
@@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer:

        # 使用 Rich 拼接并返回字符串
        result = Text(" ").join(parts)
        # 将 Rich Text 对象转换为带 ANSI 颜色码的字符串
        from io import StringIO
        string_io = StringIO()
        temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999)
        temp_console.print(result, end="")
        return string_io.getvalue()
        # 使用持久化 Console + capture 避免每条日志重复实例化
        with self._render_console.capture() as capture:
            self._render_console.print(result, end="")
        return capture.get()


# 配置标准logging以支持文件输出和压缩
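The renderer change reuses one Console and captures its output instead of building a new Console and StringIO per log line. A minimal self-contained sketch of that pattern using Rich's public API:

from rich.console import Console
from rich.text import Text

# One persistent console, created once and reused for every record.
render_console = Console(force_terminal=True, color_system="truecolor", width=999)

def render_to_ansi(message: str, style: str = "#00afff") -> str:
    text = Text(message, style=style)
    with render_console.capture() as capture:   # capture() collects what print() would emit
        render_console.print(text, end="")
    return capture.get()

print(repr(render_to_ansi("hello log line")))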
@@ -803,16 +803,8 @@ class AffinityFlowConfig(ValidatedConfigBase):
    # 兴趣评分系统参数
    reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
    non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
    high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
    medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
    low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
    high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
    medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
    low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
    match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
    max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")

    # 语义兴趣度评分优化参数(2024.12 新增)

    # 语义兴趣度评分优化参数
    use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景")
    batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理")
    batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理")
@@ -298,80 +298,105 @@ class AffinityInterestCalculator(BaseInterestCalculator):
            logger.debug("[语义评分] 未启用语义兴趣度评分")
            return

        try:
            from src.chat.semantic_interest import get_semantic_scorer
            from src.chat.semantic_interest.runtime_scorer import ModelManager
        # 防止并发初始化(使用锁)
        if not hasattr(self, '_init_lock'):
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            # 双重检查
            if self._semantic_initialized:
                logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过")
                return

            # 查找最新的模型文件
            model_dir = Path("data/semantic_interest/models")
            if not model_dir.exists():
                logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
                model_dir.mkdir(parents=True, exist_ok=True)

            # 使用模型管理器(支持人设感知)
            self.model_manager = ModelManager(model_dir)

            # 获取人设信息
            persona_info = self._get_current_persona_info()

            # 加载模型(自动选择合适的版本,使用单例 + FastScorer)
            try:
                scorer = await self.model_manager.load_model(
                    version="auto",  # 自动选择或训练
                    persona_info=persona_info
                )
                self.semantic_scorer = scorer
            from src.chat.semantic_interest import get_semantic_scorer
            from src.chat.semantic_interest.runtime_scorer import ModelManager

            # 查找最新的模型文件
            model_dir = Path("data/semantic_interest/models")
            if not model_dir.exists():
                logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
                model_dir.mkdir(parents=True, exist_ok=True)

            # 使用模型管理器(支持人设感知)
            if self.model_manager is None:
                self.model_manager = ModelManager(model_dir)
                logger.debug("[语义评分] 模型管理器已创建")

            # 获取人设信息
            persona_info = self._get_current_persona_info()

                # 如果启用批处理队列模式
                if self._use_batch_queue:
                    from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue

                    # 确保 scorer 有 FastScorer
                    if scorer._fast_scorer is not None:
                        self._batch_queue = BatchScoringQueue(
                            scorer=scorer._fast_scorer,
                            batch_size=self._batch_size,
                            flush_interval_ms=self._batch_flush_interval_ms
                        )
                        await self._batch_queue.start()
                        logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")

                logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")

                # 设置初始化标志
                self._semantic_initialized = True

                # 启动自动训练任务(每24小时检查一次)
                await self.model_manager.start_auto_training(
                    persona_info=persona_info,
                    interval_hours=24
                )

            except FileNotFoundError:
                logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
                # 触发首次训练
            # 先检查是否已有可用模型
            from src.chat.semantic_interest.auto_trainer import get_auto_trainer
            auto_trainer = get_auto_trainer()
                trained, model_path = await auto_trainer.auto_train_if_needed(
                    persona_info=persona_info,
                    force=True  # 强制训练
                )
                if trained and model_path:
                    # 使用单例获取评分器(默认启用 FastScorer)
                    self.semantic_scorer = await get_semantic_scorer(model_path)
                    logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
            existing_model = auto_trainer.get_model_for_persona(persona_info)

            # 加载模型(自动选择合适的版本,使用单例 + FastScorer)
            try:
                if existing_model and existing_model.exists():
                    # 直接加载已有模型
                    logger.info(f"[语义评分] 使用已有模型: {existing_model.name}")
                    scorer = await get_semantic_scorer(existing_model, use_async=True)
                else:
                    # 使用 ModelManager 自动选择或训练
                    scorer = await self.model_manager.load_model(
                        version="auto",  # 自动选择或训练
                        persona_info=persona_info
                    )

                self.semantic_scorer = scorer

                # 如果启用批处理队列模式
                if self._use_batch_queue:
                    from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue

                    # 确保 scorer 有 FastScorer
                    if scorer._fast_scorer is not None:
                        self._batch_queue = BatchScoringQueue(
                            scorer=scorer._fast_scorer,
                            batch_size=self._batch_size,
                            flush_interval_ms=self._batch_flush_interval_ms
                        )
                        await self._batch_queue.start()
                        logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")

                logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")

                # 设置初始化标志
                self._semantic_initialized = True
                else:
                    logger.error("[语义评分] 首次训练失败")
                    self.use_semantic_scoring = False

                # 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动
                if not existing_model or not existing_model.exists():
                    await self.model_manager.start_auto_training(
                        persona_info=persona_info,
                        interval_hours=24
                    )
                else:
                    logger.debug("[语义评分] 已有模型,跳过自动训练启动")

            except FileNotFoundError:
                logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
                # 触发首次训练
                trained, model_path = await auto_trainer.auto_train_if_needed(
                    persona_info=persona_info,
                    force=True  # 强制训练
                )
                if trained and model_path:
                    # 使用单例获取评分器(默认启用 FastScorer)
                    self.semantic_scorer = await get_semantic_scorer(model_path)
                    logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
                    # 设置初始化标志
                    self._semantic_initialized = True
                else:
                    logger.error("[语义评分] 首次训练失败")
                    self.use_semantic_scoring = False

        except ImportError:
            logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
            self.use_semantic_scoring = False
        except Exception as e:
            logger.error(f"[语义评分] 初始化失败: {e}")
            self.use_semantic_scoring = False
        except ImportError:
            logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
            self.use_semantic_scoring = False
        except Exception as e:
            logger.error(f"[语义评分] 初始化失败: {e}")
            self.use_semantic_scoring = False

    def _get_current_persona_info(self) -> dict[str, Any]:
        """获取当前人设信息
@@ -539,3 +564,5 @@ class AffinityInterestCalculator(BaseInterestCalculator):
            logger.debug(
                f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
            )

afc_interest_calculator = AffinityInterestCalculator()
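The calculator is now shared as the module-level afc_interest_calculator instance instead of being re-created by each planner. A tiny sketch of why that matters (hypothetical names, mirroring the singleton introduced above):

class Calculator:
    _instances = 0
    def __init__(self) -> None:
        Calculator._instances += 1  # a heavy init (model load) would run once per instance

shared_calculator = Calculator()   # module level: created exactly once at import time

def planner_a():
    return shared_calculator       # both call sites reuse the same object

def planner_b():
    return shared_calculator

assert planner_a() is planner_b() and Calculator._instances == 1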
@@ -174,10 +174,10 @@ class ChatterActionPlanner:

        try:
            from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
                AffinityInterestCalculator,
                afc_interest_calculator,
            )

            calculator = AffinityInterestCalculator()
            calculator = afc_interest_calculator
            if not await calculator.initialize():
                logger.warning("AffinityInterestCalculator 初始化失败")
                return None
@@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin):
        except Exception as e:
            logger.error(f"加载 AffinityChatter 时出错: {e}")

        try:
            # 延迟导入 AffinityInterestCalculator(从 core 子模块)
            from .core.affinity_interest_calculator import AffinityInterestCalculator

            components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
        except Exception as e:
            logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")

        try:
            # 延迟导入 UserProfileTool(从 tools 子模块)
            from .tools.user_profile_tool import UserProfileTool
@@ -539,14 +539,11 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在
# 兴趣评分系统参数
reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
high_match_interest_threshold = 0.6 # 高匹配兴趣阈值
medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值
low_match_interest_threshold = 0.2 # 低匹配兴趣阈值
high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率
medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率
low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率
match_count_bonus = 0.01 # 匹配数关键词加成值
max_match_bonus = 0.1 # 最大匹配数加成值

# 语义兴趣度评分优化参数
use_batch_scoring = true # 是否启用批处理评分模式,适合高频群聊场景
batch_size = 3 # 批处理大小,达到后立即处理
batch_flush_interval_ms = 30.0 # 批处理刷新间隔(毫秒),超过后强制处理

# 回复决策系统参数
no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值