feat: improve semantic interest scoring and typo generation

- Added background prewarming for the Chinese typo generator to avoid blocking on first use.
- Updated MessageStorageBatcher with a configurable commit batch size and commit interval to optimize database writes.
- Enhanced the dataset generator with a hard cap on sample size and more efficient sampling.
- Raised the maximum sample count in AutoTrainer to 1000 to make better use of training data.
- Refactored the affinity interest calculator to avoid concurrent initialization and streamline model loading.
- Introduced batched semantic interest scoring for high-frequency chat scenarios.
- Updated the config template with the new scoring parameters and removed the deprecated interest thresholds.
.gitignore (vendored): 1 line changed
@@ -342,3 +342,4 @@ package.json
 /backup
 mofox_bot_statistics.html
 src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
+depends-data/pinyin_dict.json
Deleted file (semantic interest scorer benchmark script):

@@ -1,282 +0,0 @@
-"""语义兴趣度评分器性能测试
-
-对比测试:
-1. 原始 sklearn 路径 vs FastScorer
-2. 单条评分 vs 批处理
-3. 同步 vs 异步
-"""
-
-import asyncio
-import time
-from pathlib import Path
-
-# 测试样本
-SAMPLE_TEXTS = [
-    "今天天气真好",
-    "这个游戏太好玩了!",
-    "无聊死了",
-    "我对这个话题很感兴趣",
-    "能不能聊点别的",
-    "哇这个真的很厉害",
-    "你好",
-    "有人在吗",
-    "这个问题很有深度",
-    "随便说说",
-    "真是太棒了,我非常喜欢",
-    "算了算了不想说了",
-    "来聊聊最近的新闻吧",
-    "emmmm",
-    "哈哈哈哈",
-    "666",
-]
-
-
-def benchmark_sklearn_scorer(model_path: str, iterations: int = 100):
-    """测试原始 sklearn 评分器"""
-    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer
-
-    scorer = SemanticInterestScorer(model_path, use_fast_scorer=False)
-    scorer.load()
-
-    # 预热
-    for text in SAMPLE_TEXTS[:3]:
-        scorer.score(text)
-
-    # 单条评分测试
-    start = time.perf_counter()
-    for _ in range(iterations):
-        for text in SAMPLE_TEXTS:
-            scorer.score(text)
-    single_time = time.perf_counter() - start
-    total_single = iterations * len(SAMPLE_TEXTS)
-
-    # 批量评分测试
-    start = time.perf_counter()
-    for _ in range(iterations):
-        scorer.score_batch(SAMPLE_TEXTS)
-    batch_time = time.perf_counter() - start
-    total_batch = iterations * len(SAMPLE_TEXTS)
-
-    return {
-        "mode": "sklearn",
-        "single_total_time": single_time,
-        "single_avg_ms": single_time / total_single * 1000,
-        "single_qps": total_single / single_time,
-        "batch_total_time": batch_time,
-        "batch_avg_ms": batch_time / total_batch * 1000,
-        "batch_qps": total_batch / batch_time,
-    }
-
-
-def benchmark_fast_scorer(model_path: str, iterations: int = 100):
-    """测试 FastScorer"""
-    from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer
-
-    scorer = SemanticInterestScorer(model_path, use_fast_scorer=True)
-    scorer.load()
-
-    # 预热
-    for text in SAMPLE_TEXTS[:3]:
-        scorer.score(text)
-
-    # 单条评分测试
-    start = time.perf_counter()
-    for _ in range(iterations):
-        for text in SAMPLE_TEXTS:
-            scorer.score(text)
-    single_time = time.perf_counter() - start
-    total_single = iterations * len(SAMPLE_TEXTS)
-
-    # 批量评分测试
-    start = time.perf_counter()
-    for _ in range(iterations):
-        scorer.score_batch(SAMPLE_TEXTS)
-    batch_time = time.perf_counter() - start
-    total_batch = iterations * len(SAMPLE_TEXTS)
-
-    return {
-        "mode": "fast_scorer",
-        "single_total_time": single_time,
-        "single_avg_ms": single_time / total_single * 1000,
-        "single_qps": total_single / single_time,
-        "batch_total_time": batch_time,
-        "batch_avg_ms": batch_time / total_batch * 1000,
-        "batch_qps": total_batch / batch_time,
-    }
-
-
-async def benchmark_async_scoring(model_path: str, iterations: int = 100):
-    """测试异步评分"""
-    from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer
-
-    scorer = await get_semantic_scorer(model_path, use_async=True)
-
-    # 预热
-    for text in SAMPLE_TEXTS[:3]:
-        await scorer.score_async(text)
-
-    # 单条异步评分
-    start = time.perf_counter()
-    for _ in range(iterations):
-        for text in SAMPLE_TEXTS:
-            await scorer.score_async(text)
-    single_time = time.perf_counter() - start
-    total_single = iterations * len(SAMPLE_TEXTS)
-
-    # 并发评分(模拟高并发场景)
-    start = time.perf_counter()
-    for _ in range(iterations):
-        tasks = [scorer.score_async(text) for text in SAMPLE_TEXTS]
-        await asyncio.gather(*tasks)
-    concurrent_time = time.perf_counter() - start
-    total_concurrent = iterations * len(SAMPLE_TEXTS)
-
-    return {
-        "mode": "async",
-        "single_total_time": single_time,
-        "single_avg_ms": single_time / total_single * 1000,
-        "single_qps": total_single / single_time,
-        "concurrent_total_time": concurrent_time,
-        "concurrent_avg_ms": concurrent_time / total_concurrent * 1000,
-        "concurrent_qps": total_concurrent / concurrent_time,
-    }
-
-
-async def benchmark_batch_queue(model_path: str, iterations: int = 100):
-    """测试批处理队列"""
-    from src.chat.semantic_interest.optimized_scorer import get_fast_scorer
-
-    queue = await get_fast_scorer(
-        model_path,
-        use_batch_queue=True,
-        batch_size=8,
-        flush_interval_ms=20.0
-    )
-
-    # 预热
-    for text in SAMPLE_TEXTS[:3]:
-        await queue.score(text)
-
-    # 并发提交评分请求
-    start = time.perf_counter()
-    for _ in range(iterations):
-        tasks = [queue.score(text) for text in SAMPLE_TEXTS]
-        await asyncio.gather(*tasks)
-    total_time = time.perf_counter() - start
-    total_requests = iterations * len(SAMPLE_TEXTS)
-
-    stats = queue.get_statistics()
-
-    await queue.stop()
-
-    return {
-        "mode": "batch_queue",
-        "total_time": total_time,
-        "avg_ms": total_time / total_requests * 1000,
-        "qps": total_requests / total_time,
-        "total_batches": stats["total_batches"],
-        "avg_batch_size": stats["avg_batch_size"],
-    }
-
-
-def print_results(results: dict):
-    """打印测试结果"""
-    print(f"\n{'='*60}")
-    print(f"模式: {results['mode']}")
-    print(f"{'='*60}")
-
-    if "single_avg_ms" in results:
-        print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}")
-
-    if "batch_avg_ms" in results:
-        print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}")
-
-    if "concurrent_avg_ms" in results:
-        print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}")
-
-    if "total_batches" in results:
-        print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}")
-        print(f"  总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}")
-
-
-async def main():
-    """运行性能测试"""
-    import sys
-
-    # 检查模型路径
-    model_dir = Path("data/semantic_interest/models")
-    model_files = list(model_dir.glob("semantic_interest_*.pkl"))
-
-    if not model_files:
-        print("错误: 未找到模型文件,请先训练模型")
-        print(f"模型目录: {model_dir}")
-        sys.exit(1)
-
-    # 使用最新的模型
-    model_path = str(max(model_files, key=lambda p: p.stat().st_mtime))
-    print(f"使用模型: {model_path}")
-
-    iterations = 50  # 测试迭代次数
-
-    print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次")
-    print(f"总评分次数: {iterations * len(SAMPLE_TEXTS)} 条")
-
-    # 1. sklearn 原始路径
-    print("\n[1/4] 测试 sklearn 原始路径...")
-    try:
-        sklearn_results = benchmark_sklearn_scorer(model_path, iterations)
-        print_results(sklearn_results)
-    except Exception as e:
-        print(f"sklearn 测试失败: {e}")
-
-    # 2. FastScorer
-    print("\n[2/4] 测试 FastScorer...")
-    try:
-        fast_results = benchmark_fast_scorer(model_path, iterations)
-        print_results(fast_results)
-    except Exception as e:
-        print(f"FastScorer 测试失败: {e}")
-
-    # 3. 异步评分
-    print("\n[3/4] 测试异步评分...")
-    try:
-        async_results = await benchmark_async_scoring(model_path, iterations)
-        print_results(async_results)
-    except Exception as e:
-        print(f"异步测试失败: {e}")
-
-    # 4. 批处理队列
-    print("\n[4/4] 测试批处理队列...")
-    try:
-        queue_results = await benchmark_batch_queue(model_path, iterations)
-        print_results(queue_results)
-    except Exception as e:
-        print(f"批处理队列测试失败: {e}")
-
-    # 性能对比总结
-    print(f"\n{'='*60}")
-    print("性能对比总结")
-    print(f"{'='*60}")
-
-    try:
-        speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"]
-        print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速")
-
-        speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"]
-        print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速")
-    except Exception:
-        pass
-
-    print("\n清理资源...")
-    from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances
-    from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances
-
-    shutdown_global_executor()
-    clear_fast_scorer_instances()
-    clear_scorer_instances()
-
-    print("测试完成!")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
bot.py: 20 lines changed

@@ -567,6 +567,7 @@ class MaiBotMain:

     def __init__(self):
         self.main_system = None
+        self._typo_prewarm_task = None

     def setup_timezone(self):
         """设置时区"""

@@ -663,6 +664,25 @@ class MaiBotMain:
     async def run_async_init(self, main_system):
         """执行异步初始化步骤"""

+        # 后台预热中文错别字生成器,避免首次使用阻塞主流程
+        try:
+            from src.chat.utils.typo_generator import get_typo_generator
+
+            typo_cfg = getattr(global_config, "chinese_typo", None)
+            self._typo_prewarm_task = asyncio.create_task(
+                asyncio.to_thread(
+                    get_typo_generator,
+                    error_rate=getattr(typo_cfg, "error_rate", 0.3),
+                    min_freq=getattr(typo_cfg, "min_freq", 5),
+                    tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2),
+                    word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3),
+                    max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200),
+                )
+            )
+            logger.debug("已启动 ChineseTypoGenerator 后台预热任务")
+        except Exception as e:
+            logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}")
+
         # 初始化数据库表结构
         await self.initialize_database_async()
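The bot.py change above schedules the typo generator's dictionary build off the event loop. A minimal standalone sketch of the same pattern, with a hypothetical build_dictionaries() standing in for get_typo_generator() (illustrative only, not part of this commit):

import asyncio
import time


def build_dictionaries() -> dict:
    # Stand-in for the CPU/IO-heavy generator construction.
    time.sleep(1.0)
    return {"pinyin": {}, "freq": {}}


async def main() -> None:
    # Run the heavy build in a worker thread; the event loop stays responsive.
    prewarm_task = asyncio.create_task(asyncio.to_thread(build_dictionaries))

    # ... continue the rest of async initialization here ...
    await asyncio.sleep(0.1)

    # First real use awaits the task; if prewarming already finished, this returns immediately.
    dictionaries = await prewarm_task
    print(f"prewarmed keys: {list(dictionaries)}")


if __name__ == "__main__":
    asyncio.run(main())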
@@ -3,10 +3,10 @@ import re
 import time
 import traceback
 from collections import deque
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional, Any, cast

 import orjson
-from sqlalchemy import desc, select, update
+from sqlalchemy import desc, insert, select, update
 from sqlalchemy.engine import CursorResult

 from src.common.data_models.database_data_model import DatabaseMessages

@@ -25,29 +25,55 @@ class MessageStorageBatcher:
     消息存储批处理器

     优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
+    2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。
     """

-    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
+    def __init__(
+        self,
+        batch_size: int = 50,
+        flush_interval: float = 5.0,
+        *,
+        commit_batch_size: int | None = None,
+        commit_interval: float | None = None,
+        db_chunk_size: int = 200,
+    ):
         """
         初始化批处理器

         Args:
-            batch_size: 批量大小,达到此数量立即写入
-            flush_interval: 自动刷新间隔(秒)
+            batch_size: 写入队列中触发准备阶段的消息条数
+            flush_interval: 自动刷新/检查间隔(秒)
+            commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size,至少100)
+            commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s))
+            db_chunk_size: 单次SQL语句批量写入数量上限
         """
         self.batch_size = batch_size
         self.flush_interval = flush_interval
+        self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100)
+        self.commit_interval = commit_interval or max(flush_interval * 2, 10.0)
+        self.db_chunk_size = max(50, db_chunk_size)

         self.pending_messages: deque = deque()
+        self._prepared_buffer: list[dict[str, Any]] = []
         self._lock = asyncio.Lock()
+        self._flush_barrier = asyncio.Lock()
         self._flush_task = None
         self._running = False
+        self._last_commit_ts = time.monotonic()

     async def start(self):
         """启动自动刷新任务"""
         if self._flush_task is None and not self._running:
             self._running = True
+            self._last_commit_ts = time.monotonic()
             self._flush_task = asyncio.create_task(self._auto_flush_loop())
-            logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
+            logger.info(
+                "消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)",
+                self.batch_size,
+                self.flush_interval,
+                self.commit_batch_size,
+                self.commit_interval,
+            )

     async def stop(self):
         """停止批处理器"""

@@ -62,7 +88,7 @@ class MessageStorageBatcher:
         self._flush_task = None

         # 刷新剩余的消息
-        await self.flush()
+        await self.flush(force=True)
         logger.info("消息存储批处理器已停止")

     async def add_message(self, message_data: dict):

@@ -76,61 +102,82 @@ class MessageStorageBatcher:
                'chat_stream': ChatStream
            }
        """
+        should_force_flush = False
         async with self._lock:
             self.pending_messages.append(message_data)

-            # 如果达到批量大小,立即刷新
             if len(self.pending_messages) >= self.batch_size:
-                logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
-                await self.flush()
+                should_force_flush = True
+
+        if should_force_flush:
+            logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新")
+            await self.flush(force=True)

-    async def flush(self):
-        """执行批量写入"""
-        async with self._lock:
-            if not self.pending_messages:
-                return
-
-            messages_to_store = list(self.pending_messages)
-            self.pending_messages.clear()
-
-            if not messages_to_store:
-                return
-
-            start_time = time.time()
-            success_count = 0
-
-            try:
-                # 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
-                messages_dicts = []
-
-                for msg_data in messages_to_store:
-                    try:
-                        message_dict = await self._prepare_message_dict(
-                            msg_data["message"],
-                            msg_data["chat_stream"]
-                        )
-                        if message_dict:
-                            messages_dicts.append(message_dict)
-                    except Exception as e:
-                        logger.error(f"准备消息数据失败: {e}")
-                        continue
-
-                # 批量写入数据库 - 使用高效的批量INSERT
-                if messages_dicts:
-                    from sqlalchemy import insert
-                    async with get_db_session() as session:
-                        stmt = insert(Messages).values(messages_dicts)
-                        await session.execute(stmt)
-                        await session.commit()
-                        success_count = len(messages_dicts)
-
-                elapsed = time.time() - start_time
-                logger.info(
-                    f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
-                    f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
-                )
-
-            except Exception as e:
-                logger.error(f"批量存储消息失败: {e}")
+    async def flush(self, force: bool = False):
+        """执行批量写入, 支持强制落库和延迟提交策略。"""
+        async with self._flush_barrier:
+            async with self._lock:
+                messages_to_store = list(self.pending_messages)
+                self.pending_messages.clear()
+
+            if messages_to_store:
+                prepared_messages: list[dict[str, Any]] = []
+                for msg_data in messages_to_store:
+                    try:
+                        message_dict = await self._prepare_message_dict(
+                            msg_data["message"],
+                            msg_data["chat_stream"],
+                        )
+                        if message_dict:
+                            prepared_messages.append(message_dict)
+                    except Exception as e:
+                        logger.error(f"准备消息数据失败: {e}")
+
+                if prepared_messages:
+                    self._prepared_buffer.extend(prepared_messages)
+
+            await self._maybe_commit_buffer(force=force)
+
+    async def _maybe_commit_buffer(self, *, force: bool = False) -> None:
+        """根据阈值/时间窗口判断是否需要真正写库。"""
+        if not self._prepared_buffer:
+            return
+
+        now = time.monotonic()
+        enough_rows = len(self._prepared_buffer) >= self.commit_batch_size
+        waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval
+
+        if not (force or enough_rows or waited_long_enough):
+            return
+
+        await self._write_buffer_to_database()
+
+    async def _write_buffer_to_database(self) -> None:
+        payload = self._prepared_buffer
+        if not payload:
+            return
+
+        self._prepared_buffer = []
+        start_time = time.time()
+        total = len(payload)
+
+        try:
+            async with get_db_session() as session:
+                for start in range(0, total, self.db_chunk_size):
+                    chunk = payload[start : start + self.db_chunk_size]
+                    if chunk:
+                        await session.execute(insert(Messages), chunk)
+                        await session.commit()
+
+            elapsed = time.time() - start_time
+            self._last_commit_ts = time.monotonic()
+            per_item = (elapsed / total) * 1000 if total else 0
+            logger.info(
+                f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})"
+            )
+
+        except Exception as e:
+            # 回滚到缓冲区, 等待下一次尝试
+            self._prepared_buffer = payload + self._prepared_buffer
+            logger.error(f"批量存储消息失败: {e}")

     async def _prepare_message_dict(self, message, chat_stream):
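For reference, the commit decision added in _maybe_commit_buffer reduces to a small predicate: write only when forced, when enough prepared rows have accumulated, or when the commit interval has elapsed. A minimal sketch of that rule in isolation (illustrative only, not part of this commit):

import time


def should_commit(
    prepared_rows: int,
    last_commit_ts: float,
    *,
    commit_batch_size: int = 100,
    commit_interval: float = 10.0,
    force: bool = False,
    now: float | None = None,
) -> bool:
    """Mirror of the flush policy: rows threshold OR time window OR explicit force."""
    if prepared_rows == 0:
        return False
    now = time.monotonic() if now is None else now
    enough_rows = prepared_rows >= commit_batch_size
    waited_long_enough = (now - last_commit_ts) >= commit_interval
    return force or enough_rows or waited_long_enough


# 40 rows and 3 seconds since the last commit -> no write yet.
print(should_commit(40, last_commit_ts=0.0, now=3.0))   # False
# 40 rows but 12 seconds elapsed -> the time window triggers the write.
print(should_commit(40, last_commit_ts=0.0, now=12.0))  # True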
@@ -116,6 +116,10 @@ class AutoTrainer:
             "interests": sorted(persona_info.get("interests", [])),
             "dislikes": sorted(persona_info.get("dislikes", [])),
             "personality": persona_info.get("personality", ""),
+            # 可选的更完整人设字段(存在则纳入哈希)
+            "personality_core": persona_info.get("personality_core", ""),
+            "personality_side": persona_info.get("personality_side", ""),
+            "identity": persona_info.get("identity", ""),
         }

         # 转为JSON并计算哈希

@@ -178,7 +182,7 @@ class AutoTrainer:
         self,
         persona_info: dict[str, Any],
         days: int = 7,
-        max_samples: int = 500,
+        max_samples: int = 1000,
         force: bool = False,
     ) -> tuple[bool, Path | None]:
         """自动训练(如果需要)

@@ -186,7 +190,7 @@ class AutoTrainer:
         Args:
             persona_info: 人设信息
             days: 采样天数
-            max_samples: 最大采样数
+            max_samples: 最大采样数(默认1000条)
             force: 强制训练

         Returns:

@@ -279,11 +283,12 @@ class AutoTrainer:
         """
         # 检查是否已经有任务在运行
         if self._scheduled_task_running:
-            logger.debug(f"[自动训练器] 定时任务已在运行,跳过")
+            logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
             return

         self._scheduled_task_running = True
         logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
+        logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")

         while True:
             try:
@@ -22,6 +22,9 @@ class DatasetGenerator:
     从历史消息中采样并使用 LLM 进行标注
     """

+    # 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
+    HARD_MAX_SAMPLES = 2000
+
     # 标注提示词模板(单条)
     ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。

@@ -107,7 +110,7 @@ class DatasetGenerator:
         max_samples: int = 1000,
         priority_ranges: list[tuple[float, float]] | None = None,
     ) -> list[dict[str, Any]]:
-        """从数据库采样消息
+        """从数据库采样消息(优化版:减少查询量和内存使用)

         Args:
             days: 采样最近 N 天的消息

@@ -120,40 +123,75 @@ class DatasetGenerator:
         """
         from src.common.database.api.query import QueryBuilder
         from src.common.database.core.models import Messages
+        from sqlalchemy import func, or_

-        logger.info(f"开始采样消息,时间范围: 最近 {days} 天")
+        logger.info(f"开始采样消息,时间范围: 最近 {days} 天,目标数量: {max_samples}")

+        # 限制采样数量硬上限
+        requested_max_samples = max_samples
+        if max_samples is None:
+            max_samples = self.HARD_MAX_SAMPLES
+        else:
+            max_samples = int(max_samples)
+        if max_samples <= 0:
+            logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
+            return []
+        if max_samples > self.HARD_MAX_SAMPLES:
+            logger.warning(
+                f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES},"
+                f"已截断为 {self.HARD_MAX_SAMPLES}"
+            )
+            max_samples = self.HARD_MAX_SAMPLES
+
         # 查询条件
         cutoff_time = datetime.now() - timedelta(days=days)
         cutoff_ts = cutoff_time.timestamp()

+        # 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
+        # 这样可以在保证足够样本的同时减少查询量
+        prefetch_limit = int(max_samples * 1.5)
+
+        # 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
         query_builder = QueryBuilder(Messages)

-        # 获取所有符合条件的消息(使用 as_dict 方便访问字段)
+        # 过滤条件:时间范围 + 消息文本不为空
         messages = await query_builder.filter(
             time__gte=cutoff_ts,
+        ).order_by(
+            "-time"  # 按时间倒序,优先采样最新消息
+        ).limit(
+            prefetch_limit  # 限制预取数量
         ).all(as_dict=True)

-        logger.info(f"查询到 {len(messages)} 条消息")
+        logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})")

-        # 过滤消息长度
+        # 过滤消息长度和提取文本
         filtered = []
         for msg in messages:
             text = msg.get("processed_plain_text") or msg.get("display_message") or ""
-            if text and len(text.strip()) >= min_length:
+            text = text.strip()
+            if text and len(text) >= min_length:
                 filtered.append({**msg, "message_text": text})
+                # 达到目标数量即可停止
+                if len(filtered) >= max_samples:
+                    break

-        logger.info(f"过滤后剩余 {len(filtered)} 条消息")
+        logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})")

-        # 优先采样策略
-        if priority_ranges and len(filtered) > max_samples:
-            # 随机采样
-            samples = random.sample(filtered, max_samples)
-        else:
-            samples = filtered[:max_samples]
+        # 如果过滤后数量不足,记录警告
+        if len(filtered) < max_samples:
+            logger.warning(
+                f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples}),"
+                f"可能需要扩大采样范围(增加 days 参数或降低 min_length)"
+            )

-        # 转换为字典格式
+        # 随机打乱样本顺序(避免时间偏向)
+        if len(filtered) > 0:
+            random.shuffle(filtered)
+
+        # 转换为标准格式
         result = []
-        for msg in samples:
+        for msg in filtered:
             result.append({
                 "message_id": msg.get("message_id"),
                 "user_id": msg.get("user_id"),

@@ -335,19 +373,50 @@ class DatasetGenerator:
         Returns:
             格式化后的人格描述
         """
-        parts = []
-
-        if "name" in persona_info:
-            parts.append(f"角色名称: {persona_info['name']}")
-
-        if "interests" in persona_info:
-            parts.append(f"兴趣点: {', '.join(persona_info['interests'])}")
-
-        if "dislikes" in persona_info:
-            parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}")
-
-        if "personality" in persona_info:
-            parts.append(f"性格特点: {persona_info['personality']}")
-
+        def _stringify(value: Any) -> str:
+            if value is None:
+                return ""
+            if isinstance(value, (list, tuple, set)):
+                return "、".join([str(v) for v in value if v is not None and str(v).strip()])
+            if isinstance(value, dict):
+                try:
+                    return json.dumps(value, ensure_ascii=False, sort_keys=True)
+                except Exception:
+                    return str(value)
+            return str(value).strip()
+
+        parts: list[str] = []
+
+        name = _stringify(persona_info.get("name"))
+        if name:
+            parts.append(f"角色名称: {name}")
+
+        # 核心/侧面/身份等完整人设信息
+        personality_core = _stringify(persona_info.get("personality_core"))
+        if personality_core:
+            parts.append(f"核心人设: {personality_core}")
+
+        personality_side = _stringify(persona_info.get("personality_side"))
+        if personality_side:
+            parts.append(f"侧面特质: {personality_side}")
+
+        identity = _stringify(persona_info.get("identity"))
+        if identity:
+            parts.append(f"身份特征: {identity}")
+
+        # 追加其他未覆盖字段(保持信息完整)
+        known_keys = {
+            "name",
+            "personality_core",
+            "personality_side",
+            "identity",
+        }
+        for key, value in persona_info.items():
+            if key in known_keys:
+                continue
+            value_str = _stringify(value)
+            if value_str:
+                parts.append(f"{key}: {value_str}")
+
         return "\n".join(parts) if parts else "无特定人格设定"
@@ -26,7 +26,7 @@ class TfidfFeatureExtractor:
     def __init__(
         self,
         analyzer: str = "char",  # type: ignore
-        ngram_range: tuple[int, int] = (2, 3),  # 优化:缩小 n-gram 范围
+        ngram_range: tuple[int, int] = (2, 4),  # 优化:缩小 n-gram 范围
         max_features: int = 10000,  # 优化:减少特征数量,矩阵大小和 dot product 减半
         min_df: int = 3,  # 优化:过滤低频 n-gram
         max_df: float = 0.95,
@@ -44,7 +44,6 @@ class SemanticInterestModel:
             n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
         """
         self.clf = LogisticRegression(
-            multi_class="multinomial",
             solver=solver,
             max_iter=max_iter,
             class_weight=class_weight,

@@ -206,7 +205,6 @@ class SemanticInterestModel:
         """
         params = self.clf.get_params()
         return {
-            "multi_class": params["multi_class"],
             "solver": params["solver"],
             "max_iter": params["max_iter"],
             "class_weight": params["class_weight"],
@@ -558,7 +558,7 @@ class ModelManager:
         trained, model_path = await self._auto_trainer.auto_train_if_needed(
             persona_info=persona_info,
             days=7,
-            max_samples=500,
+            max_samples=1000,  # 初始训练使用1000条消息
         )

         if trained and model_path:

@@ -607,30 +607,32 @@ class ModelManager:
             persona_info: 人设信息
             interval_hours: 检查间隔(小时)
         """
-        # 检查是否已经启动
-        if self._auto_training_started:
-            logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
-            return
-
-        try:
-            from src.chat.semantic_interest.auto_trainer import get_auto_trainer
-
-            if self._auto_trainer is None:
-                self._auto_trainer = get_auto_trainer()
-
-            logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
-
-            # 标记为已启动
-            self._auto_training_started = True
-
-            # 在后台任务中运行
-            asyncio.create_task(
-                self._auto_trainer.scheduled_train(persona_info, interval_hours)
-            )
-
-        except Exception as e:
-            logger.error(f"[模型管理器] 启动自动训练失败: {e}")
-            self._auto_training_started = False  # 失败时重置标志
+        # 使用锁防止并发启动
+        async with self._lock:
+            # 检查是否已经启动
+            if self._auto_training_started:
+                logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
+                return
+
+            try:
+                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
+
+                if self._auto_trainer is None:
+                    self._auto_trainer = get_auto_trainer()
+
+                logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
+
+                # 标记为已启动
+                self._auto_training_started = True
+
+                # 在后台任务中运行
+                asyncio.create_task(
+                    self._auto_trainer.scheduled_train(persona_info, interval_hours)
+                )
+
+            except Exception as e:
+                logger.error(f"[模型管理器] 启动自动训练失败: {e}")
+                self._auto_training_started = False  # 失败时重置标志


 # 单例获取函数
@@ -191,44 +191,3 @@ class SemanticInterestTrainer:

         return dataset_path, model_path, metrics
-
-
-async def main():
-    """示例:训练一个语义兴趣度模型"""
-
-    # 示例人格信息
-    persona_info = {
-        "name": "小狐",
-        "interests": ["动漫", "游戏", "编程", "技术", "二次元"],
-        "dislikes": ["广告", "政治", "无聊闲聊"],
-        "personality": "活泼开朗,对新鲜事物充满好奇",
-    }
-
-    # 创建训练器
-    trainer = SemanticInterestTrainer()
-
-    # 执行完整训练流程
-    dataset_path, model_path, metrics = await trainer.full_training_pipeline(
-        persona_info=persona_info,
-        days=7,  # 使用最近 7 天的消息
-        max_samples=500,  # 采样 500 条消息
-        llm_model_name=None,  # 使用默认 LLM
-        tfidf_config={
-            "analyzer": "char",
-            "ngram_range": (2, 4),
-            "max_features": 15000,
-            "min_df": 3,
-        },
-        model_config={
-            "class_weight": "balanced",
-            "max_iter": 1000,
-        },
-    )
-
-    print(f"\n训练完成!")
-    print(f"数据集: {dataset_path}")
-    print(f"模型: {model_path}")
-    print(f"准确率: {metrics.get('test_accuracy', 0):.4f}")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
@@ -96,7 +96,7 @@ class ChineseTypoGenerator:

         # 🔧 内存优化:复用全局缓存的拼音字典和字频数据
         if _shared_pinyin_dict is None:
-            _shared_pinyin_dict = self._create_pinyin_dict()
+            _shared_pinyin_dict = self._load_or_create_pinyin_dict()
             logger.debug("拼音字典已创建并缓存")
         self.pinyin_dict = _shared_pinyin_dict

@@ -141,6 +141,35 @@ class ChineseTypoGenerator:

         return normalized_freq

+    def _load_or_create_pinyin_dict(self):
+        """
+        加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动)
+        """
+        cache_file = Path("depends-data/pinyin_dict.json")
+
+        if cache_file.exists():
+            try:
+                with open(cache_file, encoding="utf-8") as f:
+                    data = orjson.loads(f.read())
+                # 恢复为 defaultdict(list) 以兼容旧逻辑
+                restored = defaultdict(list)
+                for py, chars in data.items():
+                    restored[py] = list(chars)
+                return restored
+            except Exception as e:
+                logger.warning(f"读取拼音缓存失败,将重新生成: {e}")
+
+        pinyin_dict = self._create_pinyin_dict()
+
+        try:
+            cache_file.parent.mkdir(parents=True, exist_ok=True)
+            with open(cache_file, "w", encoding="utf-8") as f:
+                f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8"))
+        except Exception as e:
+            logger.warning(f"写入拼音缓存失败(不影响使用): {e}")
+
+        return pinyin_dict
+
     @staticmethod
     def _create_pinyin_dict():
         """
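The _load_or_create_pinyin_dict change above is a load-or-build JSON disk cache. A minimal generic sketch of the same idea with orjson, using a hypothetical cache path and builder (illustrative only, not part of this commit):

from collections import defaultdict
from pathlib import Path
from typing import Callable

import orjson


def load_or_build_cache(cache_file: Path, build: Callable[[], defaultdict]) -> defaultdict:
    """Return the cached mapping if readable, otherwise build it and try to persist it."""
    if cache_file.exists():
        try:
            data = orjson.loads(cache_file.read_bytes())
            restored: defaultdict = defaultdict(list)
            for key, values in data.items():
                restored[key] = list(values)
            return restored
        except Exception:
            pass  # unreadable cache: fall through and rebuild

    result = build()
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_bytes(orjson.dumps(dict(result), option=orjson.OPT_INDENT_2))
    except Exception:
        pass  # caching is best-effort; a write failure does not break the caller
    return result


# Example with a trivial builder:
cache = load_or_build_cache(Path("/tmp/pinyin_cache.json"), lambda: defaultdict(list, {"ma": ["妈", "马"]}))
print(cache["ma"])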
@@ -1,6 +1,7 @@
 # 使用基于时间戳的文件处理器,简单的轮转份数限制

 import logging
+import os
 import tarfile
 import threading
 import time

@@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler):
         self.backup_count = backup_count
         self.encoding = encoding
         self._lock = threading.Lock()
+        self._current_size = 0
+        self._bytes_since_check = 0
+        self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8"))
+        self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024)

         # 当前活跃的日志文件
         self.current_file = None

@@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler):
             # 极低概率碰撞,稍作等待
             time.sleep(0.001)
         self.current_stream = open(self.current_file, "a", encoding=self.encoding)
+        self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0
+        self._bytes_since_check = 0

-    def _should_rollover(self):
-        """检查是否需要轮转"""
-        if self.current_file and self.current_file.exists():
-            return self.current_file.stat().st_size >= self.max_bytes
+    def _should_rollover(self, incoming_size: int = 0) -> bool:
+        """检查是否需要轮转,使用内存缓存的大小信息减少磁盘stat次数。"""
+        if not self.current_file:
+            return False
+
+        projected = self._current_size + incoming_size
+        if projected >= self.max_bytes:
+            return True
+
+        self._bytes_since_check += incoming_size
+        if self._bytes_since_check >= self._stat_refresh_threshold:
+            try:
+                if self.current_file.exists():
+                    self._current_size = self.current_file.stat().st_size
+                else:
+                    self._current_size = 0
+            except OSError:
+                self._current_size = 0
+            finally:
+                self._bytes_since_check = 0
         return False

     def _do_rollover(self):

@@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler):
     def emit(self, record):
         """发出日志记录"""
         try:
+            message = self.format(record)
+            encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes
+
             with self._lock:
-                # 检查是否需要轮转
-                if self._should_rollover():
+                if self._should_rollover(encoded_len):
                     self._do_rollover()

-                # 写入日志
                 if self.current_stream:
-                    msg = self.format(record)
-                    self.current_stream.write(msg + "\n")
+                    self.current_stream.write(message + "\n")
                     self.current_stream.flush()
+                    self._current_size += encoded_len

         except Exception:
             self.handleError(record)

@@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = {
 }


-# 创建全局 Rich Console 实例用于颜色渲染
-_rich_console = Console(force_terminal=True, color_system="truecolor")
-
-
 class ModuleColoredConsoleRenderer:
     """自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色"""

@@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer:
         # sourcery skip: merge-duplicate-blocks, remove-redundant-if
         self._colors = colors
         self._config = LOG_CONFIG
+        self._render_console = Console(force_terminal=True, color_system="truecolor", width=999)

         # 日志级别颜色 (#RRGGBB 格式)
         self._level_colors_hex = {

@@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer:
         self._enable_level_colors = False
         self._enable_full_content_colors = False

+    @staticmethod
+    def _looks_like_markup(content: str) -> bool:
+        """快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。"""
+        if not content:
+            return False
+        return "[" in content and "]" in content
+
+    def _render_content_text(self, content: str, *, style: str | None = None) -> Text:
+        """只在必要时解析 Rich 标记,降低CPU占用。"""
+        if self._looks_like_markup(content):
+            try:
+                return Text.from_markup(content, style=style)
+            except Exception:
+                return Text(content, style=style)
+        return Text(content, style=style)
+
     def __call__(self, logger, method_name, event_dict):
         # sourcery skip: merge-duplicate-blocks
         """渲染日志消息"""

@@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer:
             if prefix:
                 # 解析 prefix 中的 Rich 标记
                 if module_hex_color:
-                    content_text.append(Text.from_markup(prefix, style=module_hex_color))
+                    content_text.append(self._render_content_text(prefix, style=module_hex_color))
                 else:
-                    content_text.append(Text.from_markup(prefix))
+                    content_text.append(self._render_content_text(prefix))

             # 与"内心思考"段落之间插入空行
             if prefix:

@@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer:
             else:
                 # 使用 Text.from_markup 解析 Rich 标记语言
                 if module_hex_color:
-                    try:
-                        parts.append(Text.from_markup(event_content, style=module_hex_color))
-                    except Exception:
-                        # 如果标记解析失败,回退到普通文本
-                        parts.append(Text(event_content, style=module_hex_color))
+                    parts.append(self._render_content_text(event_content, style=module_hex_color))
                 else:
-                    try:
-                        parts.append(Text.from_markup(event_content))
-                    except Exception:
-                        # 如果标记解析失败,回退到普通文本
-                        parts.append(Text(event_content))
+                    parts.append(self._render_content_text(event_content))
         else:
             # 即使在非 full 模式下,也尝试解析 Rich 标记(但不应用颜色)
-            try:
-                parts.append(Text.from_markup(event_content))
-            except Exception:
-                # 如果标记解析失败,使用普通文本
-                parts.append(Text(event_content))
+            parts.append(self._render_content_text(event_content))

         # 处理其他字段
         extras = []

@@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer:

         # 使用 Rich 拼接并返回字符串
         result = Text(" ").join(parts)
-        # 将 Rich Text 对象转换为带 ANSI 颜色码的字符串
-        from io import StringIO
-        string_io = StringIO()
-        temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999)
-        temp_console.print(result, end="")
-        return string_io.getvalue()
+        # 使用持久化 Console + capture 避免每条日志重复实例化
+        with self._render_console.capture() as capture:
+            self._render_console.print(result, end="")
+        return capture.get()


 # 配置标准logging以支持文件输出和压缩
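The renderer change above replaces a per-record Console(file=StringIO(), ...) with one long-lived console and capture(). A minimal sketch of that conversion (illustrative; assumes the rich package):

from rich.console import Console
from rich.text import Text

# One console reused for every record, instead of a new Console + StringIO per call.
_render_console = Console(force_terminal=True, color_system="truecolor", width=999)


def text_to_ansi(result: Text) -> str:
    """Render a Rich Text object to an ANSI-colored string."""
    with _render_console.capture() as capture:
        _render_console.print(result, end="")
    return capture.get()


print(repr(text_to_ansi(Text("hello", style="#ff8800"))))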
@@ -803,16 +803,8 @@ class AffinityFlowConfig(ValidatedConfigBase):
     # 兴趣评分系统参数
     reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
     non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
-    high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
-    medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
-    low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
-    high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
-    medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
-    low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
-    match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
-    max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")

-    # 语义兴趣度评分优化参数(2024.12 新增)
+    # 语义兴趣度评分优化参数
     use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景")
     batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理")
     batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理")
@@ -298,80 +298,105 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             logger.debug("[语义评分] 未启用语义兴趣度评分")
             return

-        try:
-            from src.chat.semantic_interest import get_semantic_scorer
-            from src.chat.semantic_interest.runtime_scorer import ModelManager
-
-            # 查找最新的模型文件
-            model_dir = Path("data/semantic_interest/models")
-            if not model_dir.exists():
-                logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
-                model_dir.mkdir(parents=True, exist_ok=True)
-
-            # 使用模型管理器(支持人设感知)
-            self.model_manager = ModelManager(model_dir)
-
-            # 获取人设信息
-            persona_info = self._get_current_persona_info()
-
-            # 加载模型(自动选择合适的版本,使用单例 + FastScorer)
-            try:
-                scorer = await self.model_manager.load_model(
-                    version="auto",  # 自动选择或训练
-                    persona_info=persona_info
-                )
-                self.semantic_scorer = scorer
-
-                # 如果启用批处理队列模式
-                if self._use_batch_queue:
-                    from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
-
-                    # 确保 scorer 有 FastScorer
-                    if scorer._fast_scorer is not None:
-                        self._batch_queue = BatchScoringQueue(
-                            scorer=scorer._fast_scorer,
-                            batch_size=self._batch_size,
-                            flush_interval_ms=self._batch_flush_interval_ms
-                        )
-                        await self._batch_queue.start()
-                        logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
-
-                logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")
-
-                # 设置初始化标志
-                self._semantic_initialized = True
-
-                # 启动自动训练任务(每24小时检查一次)
-                await self.model_manager.start_auto_training(
-                    persona_info=persona_info,
-                    interval_hours=24
-                )
-
-            except FileNotFoundError:
-                logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
-                # 触发首次训练
-                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
-                auto_trainer = get_auto_trainer()
-                trained, model_path = await auto_trainer.auto_train_if_needed(
-                    persona_info=persona_info,
-                    force=True  # 强制训练
-                )
-                if trained and model_path:
-                    # 使用单例获取评分器(默认启用 FastScorer)
-                    self.semantic_scorer = await get_semantic_scorer(model_path)
-                    logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
-                    # 设置初始化标志
-                    self._semantic_initialized = True
-                else:
-                    logger.error("[语义评分] 首次训练失败")
-                    self.use_semantic_scoring = False
-
-        except ImportError:
-            logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
-            self.use_semantic_scoring = False
-        except Exception as e:
-            logger.error(f"[语义评分] 初始化失败: {e}")
-            self.use_semantic_scoring = False
+        # 防止并发初始化(使用锁)
+        if not hasattr(self, '_init_lock'):
+            self._init_lock = asyncio.Lock()
+
+        async with self._init_lock:
+            # 双重检查
+            if self._semantic_initialized:
+                logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过")
+                return
+
+            try:
+                from src.chat.semantic_interest import get_semantic_scorer
+                from src.chat.semantic_interest.runtime_scorer import ModelManager
+
+                # 查找最新的模型文件
+                model_dir = Path("data/semantic_interest/models")
+                if not model_dir.exists():
+                    logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
+                    model_dir.mkdir(parents=True, exist_ok=True)
+
+                # 使用模型管理器(支持人设感知)
+                if self.model_manager is None:
+                    self.model_manager = ModelManager(model_dir)
+                    logger.debug("[语义评分] 模型管理器已创建")
+
+                # 获取人设信息
+                persona_info = self._get_current_persona_info()
+
+                # 先检查是否已有可用模型
+                from src.chat.semantic_interest.auto_trainer import get_auto_trainer
+                auto_trainer = get_auto_trainer()
+                existing_model = auto_trainer.get_model_for_persona(persona_info)
+
+                # 加载模型(自动选择合适的版本,使用单例 + FastScorer)
+                try:
+                    if existing_model and existing_model.exists():
+                        # 直接加载已有模型
+                        logger.info(f"[语义评分] 使用已有模型: {existing_model.name}")
+                        scorer = await get_semantic_scorer(existing_model, use_async=True)
+                    else:
+                        # 使用 ModelManager 自动选择或训练
+                        scorer = await self.model_manager.load_model(
+                            version="auto",  # 自动选择或训练
+                            persona_info=persona_info
+                        )
+
+                    self.semantic_scorer = scorer
+
+                    # 如果启用批处理队列模式
+                    if self._use_batch_queue:
+                        from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
+
+                        # 确保 scorer 有 FastScorer
+                        if scorer._fast_scorer is not None:
+                            self._batch_queue = BatchScoringQueue(
+                                scorer=scorer._fast_scorer,
+                                batch_size=self._batch_size,
+                                flush_interval_ms=self._batch_flush_interval_ms
+                            )
+                            await self._batch_queue.start()
+                            logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
+
+                    logger.info("[语义评分] 语义兴趣度评分器初始化成功(FastScorer优化 + 单例)")
+
+                    # 设置初始化标志
+                    self._semantic_initialized = True
+
+                    # 启动自动训练任务(每24小时检查一次)- 只在没有模型时或明确需要时启动
+                    if not existing_model or not existing_model.exists():
+                        await self.model_manager.start_auto_training(
+                            persona_info=persona_info,
+                            interval_hours=24
+                        )
+                    else:
+                        logger.debug("[语义评分] 已有模型,跳过自动训练启动")
+
+                except FileNotFoundError:
+                    logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
+                    # 触发首次训练
+                    trained, model_path = await auto_trainer.auto_train_if_needed(
+                        persona_info=persona_info,
+                        force=True  # 强制训练
+                    )
+                    if trained and model_path:
+                        # 使用单例获取评分器(默认启用 FastScorer)
+                        self.semantic_scorer = await get_semantic_scorer(model_path)
+                        logger.info("[语义评分] 首次训练完成,模型已加载(FastScorer优化 + 单例)")
+                        # 设置初始化标志
+                        self._semantic_initialized = True
+                    else:
+                        logger.error("[语义评分] 首次训练失败")
+                        self.use_semantic_scoring = False
+
+            except ImportError:
+                logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
+                self.use_semantic_scoring = False
+            except Exception as e:
+                logger.error(f"[语义评分] 初始化失败: {e}")
+                self.use_semantic_scoring = False

     def _get_current_persona_info(self) -> dict[str, Any]:
         """获取当前人设信息

@@ -539,3 +564,5 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         logger.debug(
             f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
         )
+
+afc_interest_calculator = AffinityInterestCalculator()
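The initializer above guards model loading with a lock plus an initialized flag so concurrent callers initialize only once. A minimal standalone sketch of that double-checked pattern (illustrative only, not part of this commit):

import asyncio


class LazyScorer:
    def __init__(self) -> None:
        self._init_lock = asyncio.Lock()
        self._initialized = False
        self.model = None

    async def ensure_initialized(self) -> None:
        if self._initialized:          # fast path, no lock
            return
        async with self._init_lock:
            if self._initialized:      # double check inside the lock
                return
            self.model = await self._load_model()
            self._initialized = True

    async def _load_model(self) -> str:
        await asyncio.sleep(0.2)       # stand-in for expensive model loading
        return "model"


async def main() -> None:
    scorer = LazyScorer()
    # Ten concurrent callers; _load_model runs exactly once.
    await asyncio.gather(*(scorer.ensure_initialized() for _ in range(10)))
    print(scorer.model)


asyncio.run(main())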
@@ -174,10 +174,10 @@ class ChatterActionPlanner:

         try:
             from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
-                AffinityInterestCalculator,
+                afc_interest_calculator,
             )

-            calculator = AffinityInterestCalculator()
+            calculator = afc_interest_calculator
             if not await calculator.initialize():
                 logger.warning("AffinityInterestCalculator 初始化失败")
                 return None
@@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin):
         except Exception as e:
             logger.error(f"加载 AffinityChatter 时出错: {e}")

-        try:
-            # 延迟导入 AffinityInterestCalculator(从 core 子模块)
-            from .core.affinity_interest_calculator import AffinityInterestCalculator
-
-            components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
-        except Exception as e:
-            logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
-
         try:
             # 延迟导入 UserProfileTool(从 tools 子模块)
             from .tools.user_profile_tool import UserProfileTool
@@ -539,14 +539,11 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在
# 兴趣评分系统参数
reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
-high_match_interest_threshold = 0.6 # 高匹配兴趣阈值
-medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值
-low_match_interest_threshold = 0.2 # 低匹配兴趣阈值
-high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率
-medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率
-low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率
-match_count_bonus = 0.01 # 匹配数关键词加成值
-max_match_bonus = 0.1 # 最大匹配数加成值
+
+# 语义兴趣度评分优化参数
+use_batch_scoring = true # 是否启用批处理评分模式,适合高频群聊场景
+batch_size = 3 # 批处理大小,达到后立即处理
+batch_flush_interval_ms = 30.0 # 批处理刷新间隔(毫秒),超过后强制处理

# 回复决策系统参数
no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值