feat: improve semantic interest scoring and typo generation

- Implemented background prewarming for the Chinese typo generator to improve first-use performance.
- Updated MessageStorageBatcher to support a configurable commit batch size and commit interval, reducing database write pressure (a simplified sketch of the commit policy follows this list).
- Hardened the dataset generator with a hard cap on sample size and more efficient sampling.
- Raised the maximum sample count in AutoTrainer to 1000 to make better use of training data.
- Refactored the affinity interest calculator to avoid concurrent initialization and streamline model loading.
- Introduced a batching mechanism for semantic interest scoring to handle high-frequency chat scenarios.
- Updated the configuration template to reflect the new scoring parameters and removed deprecated interest thresholds.
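
A minimal sketch of the deferred-commit policy referenced in the MessageStorageBatcher bullet above, assuming only the behavior visible in this commit (buffer prepared rows, flush when a row threshold or a time window is exceeded, or when forced); the class and method names here are illustrative, not the project's actual API:

```python
import time


class DeferredCommitPolicy:
    """Illustrative only: commit when enough rows accumulate or enough time has passed."""

    def __init__(self, commit_batch_size: int = 100, commit_interval: float = 10.0):
        self.commit_batch_size = commit_batch_size   # rows to accumulate before a real commit
        self.commit_interval = commit_interval       # max seconds to sit on buffered rows
        self._last_commit_ts = time.monotonic()

    def should_commit(self, buffered_rows: int, force: bool = False) -> bool:
        if buffered_rows == 0:
            return False
        enough_rows = buffered_rows >= self.commit_batch_size
        waited_long_enough = (time.monotonic() - self._last_commit_ts) >= self.commit_interval
        return force or enough_rows or waited_long_enough

    def mark_committed(self) -> None:
        self._last_commit_ts = time.monotonic()
```

The batcher in this diff applies the same rule inside `_maybe_commit_buffer` before performing chunked bulk inserts.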
Author: Windpicker-owo
Date: 2025-12-12 14:11:36 +08:00
parent 9d01b81cef
commit e6a4f855a2
17 changed files with 433 additions and 554 deletions

.gitignore vendored

@@ -342,3 +342,4 @@ package.json
/backup
mofox_bot_statistics.html
src/plugins/built_in/napcat_adapter/src/handlers/napcat_cache.json
depends-data/pinyin_dict.json


@@ -1,282 +0,0 @@
"""语义兴趣度评分器性能测试
对比测试:
1. 原始 sklearn 路径 vs FastScorer
2. 单条评分 vs 批处理
3. 同步 vs 异步
"""
import asyncio
import time
from pathlib import Path
# 测试样本
SAMPLE_TEXTS = [
"今天天气真好",
"这个游戏太好玩了!",
"无聊死了",
"我对这个话题很感兴趣",
"能不能聊点别的",
"哇这个真的很厉害",
"你好",
"有人在吗",
"这个问题很有深度",
"随便说说",
"真是太棒了,我非常喜欢",
"算了算了不想说了",
"来聊聊最近的新闻吧",
"emmmm",
"哈哈哈哈",
"666",
]
def benchmark_sklearn_scorer(model_path: str, iterations: int = 100):
"""测试原始 sklearn 评分器"""
from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer
scorer = SemanticInterestScorer(model_path, use_fast_scorer=False)
scorer.load()
# 预热
for text in SAMPLE_TEXTS[:3]:
scorer.score(text)
# 单条评分测试
start = time.perf_counter()
for _ in range(iterations):
for text in SAMPLE_TEXTS:
scorer.score(text)
single_time = time.perf_counter() - start
total_single = iterations * len(SAMPLE_TEXTS)
# 批量评分测试
start = time.perf_counter()
for _ in range(iterations):
scorer.score_batch(SAMPLE_TEXTS)
batch_time = time.perf_counter() - start
total_batch = iterations * len(SAMPLE_TEXTS)
return {
"mode": "sklearn",
"single_total_time": single_time,
"single_avg_ms": single_time / total_single * 1000,
"single_qps": total_single / single_time,
"batch_total_time": batch_time,
"batch_avg_ms": batch_time / total_batch * 1000,
"batch_qps": total_batch / batch_time,
}
def benchmark_fast_scorer(model_path: str, iterations: int = 100):
"""测试 FastScorer"""
from src.chat.semantic_interest.runtime_scorer import SemanticInterestScorer
scorer = SemanticInterestScorer(model_path, use_fast_scorer=True)
scorer.load()
# 预热
for text in SAMPLE_TEXTS[:3]:
scorer.score(text)
# 单条评分测试
start = time.perf_counter()
for _ in range(iterations):
for text in SAMPLE_TEXTS:
scorer.score(text)
single_time = time.perf_counter() - start
total_single = iterations * len(SAMPLE_TEXTS)
# 批量评分测试
start = time.perf_counter()
for _ in range(iterations):
scorer.score_batch(SAMPLE_TEXTS)
batch_time = time.perf_counter() - start
total_batch = iterations * len(SAMPLE_TEXTS)
return {
"mode": "fast_scorer",
"single_total_time": single_time,
"single_avg_ms": single_time / total_single * 1000,
"single_qps": total_single / single_time,
"batch_total_time": batch_time,
"batch_avg_ms": batch_time / total_batch * 1000,
"batch_qps": total_batch / batch_time,
}
async def benchmark_async_scoring(model_path: str, iterations: int = 100):
"""测试异步评分"""
from src.chat.semantic_interest.runtime_scorer import get_semantic_scorer
scorer = await get_semantic_scorer(model_path, use_async=True)
# 预热
for text in SAMPLE_TEXTS[:3]:
await scorer.score_async(text)
# 单条异步评分
start = time.perf_counter()
for _ in range(iterations):
for text in SAMPLE_TEXTS:
await scorer.score_async(text)
single_time = time.perf_counter() - start
total_single = iterations * len(SAMPLE_TEXTS)
# 并发评分(模拟高并发场景)
start = time.perf_counter()
for _ in range(iterations):
tasks = [scorer.score_async(text) for text in SAMPLE_TEXTS]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter() - start
total_concurrent = iterations * len(SAMPLE_TEXTS)
return {
"mode": "async",
"single_total_time": single_time,
"single_avg_ms": single_time / total_single * 1000,
"single_qps": total_single / single_time,
"concurrent_total_time": concurrent_time,
"concurrent_avg_ms": concurrent_time / total_concurrent * 1000,
"concurrent_qps": total_concurrent / concurrent_time,
}
async def benchmark_batch_queue(model_path: str, iterations: int = 100):
"""测试批处理队列"""
from src.chat.semantic_interest.optimized_scorer import get_fast_scorer
queue = await get_fast_scorer(
model_path,
use_batch_queue=True,
batch_size=8,
flush_interval_ms=20.0
)
# 预热
for text in SAMPLE_TEXTS[:3]:
await queue.score(text)
# 并发提交评分请求
start = time.perf_counter()
for _ in range(iterations):
tasks = [queue.score(text) for text in SAMPLE_TEXTS]
await asyncio.gather(*tasks)
total_time = time.perf_counter() - start
total_requests = iterations * len(SAMPLE_TEXTS)
stats = queue.get_statistics()
await queue.stop()
return {
"mode": "batch_queue",
"total_time": total_time,
"avg_ms": total_time / total_requests * 1000,
"qps": total_requests / total_time,
"total_batches": stats["total_batches"],
"avg_batch_size": stats["avg_batch_size"],
}
def print_results(results: dict):
"""打印测试结果"""
print(f"\n{'='*60}")
print(f"模式: {results['mode']}")
print(f"{'='*60}")
if "single_avg_ms" in results:
print(f"单条评分: {results['single_avg_ms']:.3f} ms/条, QPS: {results['single_qps']:.1f}")
if "batch_avg_ms" in results:
print(f"批量评分: {results['batch_avg_ms']:.3f} ms/条, QPS: {results['batch_qps']:.1f}")
if "concurrent_avg_ms" in results:
print(f"并发评分: {results['concurrent_avg_ms']:.3f} ms/条, QPS: {results['concurrent_qps']:.1f}")
if "total_batches" in results:
print(f"批处理队列: {results['avg_ms']:.3f} ms/条, QPS: {results['qps']:.1f}")
print(f" 总批次: {results['total_batches']}, 平均批大小: {results['avg_batch_size']:.1f}")
async def main():
"""运行性能测试"""
import sys
# 检查模型路径
model_dir = Path("data/semantic_interest/models")
model_files = list(model_dir.glob("semantic_interest_*.pkl"))
if not model_files:
print("错误: 未找到模型文件,请先训练模型")
print(f"模型目录: {model_dir}")
sys.exit(1)
# 使用最新的模型
model_path = str(max(model_files, key=lambda p: p.stat().st_mtime))
print(f"使用模型: {model_path}")
iterations = 50 # 测试迭代次数
print(f"\n测试配置: {iterations} 次迭代, {len(SAMPLE_TEXTS)} 条样本/次")
print(f"总评分次数: {iterations * len(SAMPLE_TEXTS)}")
# 1. sklearn 原始路径
print("\n[1/4] 测试 sklearn 原始路径...")
try:
sklearn_results = benchmark_sklearn_scorer(model_path, iterations)
print_results(sklearn_results)
except Exception as e:
print(f"sklearn 测试失败: {e}")
# 2. FastScorer
print("\n[2/4] 测试 FastScorer...")
try:
fast_results = benchmark_fast_scorer(model_path, iterations)
print_results(fast_results)
except Exception as e:
print(f"FastScorer 测试失败: {e}")
# 3. 异步评分
print("\n[3/4] 测试异步评分...")
try:
async_results = await benchmark_async_scoring(model_path, iterations)
print_results(async_results)
except Exception as e:
print(f"异步测试失败: {e}")
# 4. 批处理队列
print("\n[4/4] 测试批处理队列...")
try:
queue_results = await benchmark_batch_queue(model_path, iterations)
print_results(queue_results)
except Exception as e:
print(f"批处理队列测试失败: {e}")
# 性能对比总结
print(f"\n{'='*60}")
print("性能对比总结")
print(f"{'='*60}")
try:
speedup = sklearn_results["single_avg_ms"] / fast_results["single_avg_ms"]
print(f"FastScorer vs sklearn 单条: {speedup:.2f}x 加速")
speedup = sklearn_results["batch_avg_ms"] / fast_results["batch_avg_ms"]
print(f"FastScorer vs sklearn 批量: {speedup:.2f}x 加速")
except:
pass
print("\n清理资源...")
from src.chat.semantic_interest.optimized_scorer import shutdown_global_executor, clear_fast_scorer_instances
from src.chat.semantic_interest.runtime_scorer import clear_scorer_instances
shutdown_global_executor()
clear_fast_scorer_instances()
clear_scorer_instances()
print("测试完成!")
if __name__ == "__main__":
asyncio.run(main())

bot.py

@@ -567,6 +567,7 @@ class MaiBotMain:
def __init__(self):
self.main_system = None
self._typo_prewarm_task = None
def setup_timezone(self):
"""设置时区"""
@@ -663,6 +664,25 @@ class MaiBotMain:
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 后台预热中文错别字生成器,避免首次使用阻塞主流程
try:
from src.chat.utils.typo_generator import get_typo_generator
typo_cfg = getattr(global_config, "chinese_typo", None)
self._typo_prewarm_task = asyncio.create_task(
asyncio.to_thread(
get_typo_generator,
error_rate=getattr(typo_cfg, "error_rate", 0.3),
min_freq=getattr(typo_cfg, "min_freq", 5),
tone_error_rate=getattr(typo_cfg, "tone_error_rate", 0.2),
word_replace_rate=getattr(typo_cfg, "word_replace_rate", 0.3),
max_freq_diff=getattr(typo_cfg, "max_freq_diff", 200),
)
)
logger.debug("已启动 ChineseTypoGenerator 后台预热任务")
except Exception as e:
logger.debug(f"启动 ChineseTypoGenerator 预热失败(可忽略): {e}")
# 初始化数据库表结构
await self.initialize_database_async()
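
The prewarm task is stored on `self._typo_prewarm_task`, but this hunk does not show how it is disposed of; a hedged sketch of how a shutdown path could drain or cancel it (hypothetical helper, not part of this commit):

```python
import asyncio
import contextlib


async def drain_typo_prewarm_task(task: asyncio.Task | None, timeout: float = 5.0) -> None:
    """Hypothetical: give the prewarm task a short grace period, then cancel it."""
    if task is None or task.done():
        return
    try:
        await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
    except asyncio.TimeoutError:
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError, Exception):
            await task
    except Exception:
        # Prewarming is best-effort; failures here must not block shutdown.
        pass
```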


@@ -3,10 +3,10 @@ import re
import time
import traceback
from collections import deque
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Any, cast
import orjson
from sqlalchemy import desc, select, update
from sqlalchemy import desc, insert, select, update
from sqlalchemy.engine import CursorResult
from src.common.data_models.database_data_model import DatabaseMessages
@@ -25,29 +25,55 @@ class MessageStorageBatcher:
消息存储批处理器
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
2025-12: 增加二级缓冲区,降低 commit 频率并使用 Core 批量插入。
"""
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
def __init__(
self,
batch_size: int = 50,
flush_interval: float = 5.0,
*,
commit_batch_size: int | None = None,
commit_interval: float | None = None,
db_chunk_size: int = 200,
):
"""
初始化批处理器
Args:
batch_size: 批量大小,达到此数量立即写入
flush_interval: 自动刷新间隔(秒)
batch_size: 写入队列中触发准备阶段的消息条数
flush_interval: 自动刷新/检查间隔(秒)
commit_batch_size: 实际落库前需要累积的条数(默认=2x batch_size,至少100)
commit_interval: 降低刷盘频率的最大等待时长(默认=max(flush_interval*2, 10s))
db_chunk_size: 单次SQL语句批量写入数量上限
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.commit_batch_size = commit_batch_size or max(batch_size * 2, 100)
self.commit_interval = commit_interval or max(flush_interval * 2, 10.0)
self.db_chunk_size = max(50, db_chunk_size)
self.pending_messages: deque = deque()
self._prepared_buffer: list[dict[str, Any]] = []
self._lock = asyncio.Lock()
self._flush_barrier = asyncio.Lock()
self._flush_task = None
self._running = False
self._last_commit_ts = time.monotonic()
async def start(self):
"""启动自动刷新任务"""
if self._flush_task is None and not self._running:
self._running = True
self._last_commit_ts = time.monotonic()
self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
logger.info(
"消息存储批处理器已启动 (批量大小: %s, 刷新间隔: %ss, commit批量: %s, commit间隔: %ss)",
self.batch_size,
self.flush_interval,
self.commit_batch_size,
self.commit_interval,
)
async def stop(self):
"""停止批处理器"""
@@ -62,7 +88,7 @@ class MessageStorageBatcher:
self._flush_task = None
# 刷新剩余的消息
await self.flush()
await self.flush(force=True)
logger.info("消息存储批处理器已停止")
async def add_message(self, message_data: dict):
@@ -76,61 +102,82 @@ class MessageStorageBatcher:
'chat_stream': ChatStream
}
"""
should_force_flush = False
async with self._lock:
self.pending_messages.append(message_data)
# 如果达到批量大小,立即刷新
if len(self.pending_messages) >= self.batch_size:
logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
await self.flush()
should_force_flush = True
async def flush(self):
"""执行批量写入"""
async with self._lock:
if not self.pending_messages:
return
if should_force_flush:
logger.debug(f"达到批量大小 {self.batch_size},立即触发数据库刷新")
await self.flush(force=True)
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
async def flush(self, force: bool = False):
"""执行批量写入, 支持强制落库和延迟提交策略。"""
async with self._flush_barrier:
async with self._lock:
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
if not messages_to_store:
if messages_to_store:
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
await self._maybe_commit_buffer(force=force)
async def _maybe_commit_buffer(self, *, force: bool = False) -> None:
"""根据阈值/时间窗口判断是否需要真正写库。"""
if not self._prepared_buffer:
return
now = time.monotonic()
enough_rows = len(self._prepared_buffer) >= self.commit_batch_size
waited_long_enough = (now - self._last_commit_ts) >= self.commit_interval
if not (force or enough_rows or waited_long_enough):
return
await self._write_buffer_to_database()
async def _write_buffer_to_database(self) -> None:
payload = self._prepared_buffer
if not payload:
return
self._prepared_buffer = []
start_time = time.time()
success_count = 0
total = len(payload)
try:
# 🔧 优化准备字典数据而不是ORM对象使用批量INSERT
messages_dicts = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"]
)
if message_dict:
messages_dicts.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
continue
# 批量写入数据库 - 使用高效的批量INSERT
if messages_dicts:
from sqlalchemy import insert
async with get_db_session() as session:
stmt = insert(Messages).values(messages_dicts)
await session.execute(stmt)
await session.commit()
success_count = len(messages_dicts)
async with get_db_session() as session:
for start in range(0, total, self.db_chunk_size):
chunk = payload[start : start + self.db_chunk_size]
if chunk:
await session.execute(insert(Messages), chunk)
await session.commit()
elapsed = time.time() - start_time
self._last_commit_ts = time.monotonic()
per_item = (elapsed / total) * 1000 if total else 0
logger.info(
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
f"批量存储了 {total} 条消息 (耗时 {elapsed:.3f} 秒, 平均 {per_item:.2f} ms/条, chunk={self.db_chunk_size})"
)
except Exception as e:
# 回滚到缓冲区, 等待下一次尝试
self._prepared_buffer = payload + self._prepared_buffer
logger.error(f"批量存储消息失败: {e}")
async def _prepare_message_dict(self, message, chat_stream):
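
Based only on the constructor and methods visible in this hunk, a usage sketch of the reworked batcher; the `message` and `chat_stream` payloads are placeholders for the project's real objects:

```python
async def run_batcher_example(message, chat_stream):
    # MessageStorageBatcher is the class shown in the hunk above. Per its docstring,
    # commit_batch_size defaults to max(batch_size * 2, 100) and commit_interval to
    # max(flush_interval * 2, 10.0) when not supplied.
    batcher = MessageStorageBatcher(
        batch_size=50,
        flush_interval=5.0,
        commit_batch_size=200,
        commit_interval=15.0,
        db_chunk_size=200,
    )
    await batcher.start()
    try:
        await batcher.add_message({"message": message, "chat_stream": chat_stream})
    finally:
        # stop() ends with flush(force=True), so buffered rows still reach the database.
        await batcher.stop()
```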


@@ -116,6 +116,10 @@ class AutoTrainer:
"interests": sorted(persona_info.get("interests", [])),
"dislikes": sorted(persona_info.get("dislikes", [])),
"personality": persona_info.get("personality", ""),
# 可选的更完整人设字段(存在则纳入哈希)
"personality_core": persona_info.get("personality_core", ""),
"personality_side": persona_info.get("personality_side", ""),
"identity": persona_info.get("identity", ""),
}
# 转为JSON并计算哈希
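
The comment above only says the field dict is serialized to JSON and hashed; the exact call sits outside this hunk, so the following is a hedged sketch of one plausible implementation consistent with the fields listed:

```python
import hashlib
import json
from typing import Any


def persona_hash(persona_info: dict[str, Any]) -> str:
    """Illustrative only: stable digest over the persona fields shown above."""
    fields = {
        "interests": sorted(persona_info.get("interests", [])),
        "dislikes": sorted(persona_info.get("dislikes", [])),
        "personality": persona_info.get("personality", ""),
        "personality_core": persona_info.get("personality_core", ""),
        "personality_side": persona_info.get("personality_side", ""),
        "identity": persona_info.get("identity", ""),
    }
    payload = json.dumps(fields, ensure_ascii=False, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```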
@@ -178,7 +182,7 @@ class AutoTrainer:
self,
persona_info: dict[str, Any],
days: int = 7,
max_samples: int = 500,
max_samples: int = 1000,
force: bool = False,
) -> tuple[bool, Path | None]:
"""自动训练(如果需要)
@@ -186,7 +190,7 @@ class AutoTrainer:
Args:
persona_info: 人设信息
days: 采样天数
max_samples: 最大采样数
max_samples: 最大采样数(默认1000条)
force: 强制训练
Returns:
@@ -279,11 +283,12 @@ class AutoTrainer:
"""
# 检查是否已经有任务在运行
if self._scheduled_task_running:
logger.debug(f"[自动训练器] 定时任务已在运行,跳过")
logger.info(f"[自动训练器] 定时任务已在运行,跳过重复启动")
return
self._scheduled_task_running = True
logger.info(f"[自动训练器] 启动定时训练任务,间隔: {interval_hours}小时")
logger.info(f"[自动训练器] 当前人设哈希: {self._calculate_persona_hash(persona_info)[:8]}")
while True:
try:


@@ -22,6 +22,9 @@ class DatasetGenerator:
从历史消息中采样并使用 LLM 进行标注
"""
# 采样消息时的硬上限,避免一次采样过大导致内存/耗时问题
HARD_MAX_SAMPLES = 2000
# 标注提示词模板(单条)
ANNOTATION_PROMPT = """你是一个帮助标注消息兴趣度的专家。你需要根据人格设定判断该消息是否会引起角色的兴趣。
@@ -107,7 +110,7 @@ class DatasetGenerator:
max_samples: int = 1000,
priority_ranges: list[tuple[float, float]] | None = None,
) -> list[dict[str, Any]]:
"""从数据库采样消息
"""从数据库采样消息(优化版:减少查询量和内存使用)
Args:
days: 采样最近 N 天的消息
@@ -120,40 +123,75 @@ class DatasetGenerator:
"""
from src.common.database.api.query import QueryBuilder
from src.common.database.core.models import Messages
from sqlalchemy import func, or_
logger.info(f"开始采样消息,时间范围: 最近 {days}")
logger.info(f"开始采样消息,时间范围: 最近 {days},目标数量: {max_samples}")
# 限制采样数量硬上限
requested_max_samples = max_samples
if max_samples is None:
max_samples = self.HARD_MAX_SAMPLES
else:
max_samples = int(max_samples)
if max_samples <= 0:
logger.warning(f"max_samples={requested_max_samples} 非法,返回空样本")
return []
if max_samples > self.HARD_MAX_SAMPLES:
logger.warning(
f"max_samples={requested_max_samples} 超过硬上限 {self.HARD_MAX_SAMPLES}"
f"已截断为 {self.HARD_MAX_SAMPLES}"
)
max_samples = self.HARD_MAX_SAMPLES
# 查询条件
cutoff_time = datetime.now() - timedelta(days=days)
cutoff_ts = cutoff_time.timestamp()
# 优化策略:为了过滤掉长度不足的消息,预取 max_samples * 1.5 条
# 这样可以在保证足够样本的同时减少查询量
prefetch_limit = int(max_samples * 1.5)
# 构建优化查询:在数据库层面限制数量并按时间倒序(最新消息优先)
query_builder = QueryBuilder(Messages)
# 获取所有符合条件的消息(使用 as_dict 方便访问字段)
# 过滤条件:时间范围 + 消息文本不为空
messages = await query_builder.filter(
time__gte=cutoff_ts,
).order_by(
"-time" # 按时间倒序,优先采样最新消息
).limit(
prefetch_limit # 限制预取数量
).all(as_dict=True)
logger.info(f"查询到 {len(messages)} 条消息")
logger.info(f"预取 {len(messages)} 条消息(限制: {prefetch_limit})")
# 过滤消息长度
# 过滤消息长度和提取文本
filtered = []
for msg in messages:
text = msg.get("processed_plain_text") or msg.get("display_message") or ""
if text and len(text.strip()) >= min_length:
text = text.strip()
if text and len(text) >= min_length:
filtered.append({**msg, "message_text": text})
# 达到目标数量即可停止
if len(filtered) >= max_samples:
break
logger.info(f"过滤后剩余 {len(filtered)}消息")
logger.info(f"过滤后得到 {len(filtered)} 条有效消息(目标: {max_samples})")
# 优先采样策略
if priority_ranges and len(filtered) > max_samples:
# 随机采样
samples = random.sample(filtered, max_samples)
else:
samples = filtered[:max_samples]
# 如果过滤后数量不足,记录警告
if len(filtered) < max_samples:
logger.warning(
f"过滤后消息数量 ({len(filtered)}) 少于目标 ({max_samples})"
f"可能需要扩大采样范围(增加 days 参数或降低 min_length)"
)
# 转换为字典格式
# 随机打乱样本顺序(避免时间偏向)
if len(filtered) > 0:
random.shuffle(filtered)
# 转换为标准格式
result = []
for msg in samples:
for msg in filtered:
result.append({
"message_id": msg.get("message_id"),
"user_id": msg.get("user_id"),
@@ -335,19 +373,50 @@ class DatasetGenerator:
Returns:
格式化后的人格描述
"""
parts = []
def _stringify(value: Any) -> str:
if value is None:
return ""
if isinstance(value, (list, tuple, set)):
return "、".join([str(v) for v in value if v is not None and str(v).strip()])
if isinstance(value, dict):
try:
return json.dumps(value, ensure_ascii=False, sort_keys=True)
except Exception:
return str(value)
return str(value).strip()
if "name" in persona_info:
parts.append(f"角色名称: {persona_info['name']}")
parts: list[str] = []
if "interests" in persona_info:
parts.append(f"兴趣点: {', '.join(persona_info['interests'])}")
name = _stringify(persona_info.get("name"))
if name:
parts.append(f"角色名称: {name}")
if "dislikes" in persona_info:
parts.append(f"厌恶点: {', '.join(persona_info['dislikes'])}")
# 核心/侧面/身份等完整人设信息
personality_core = _stringify(persona_info.get("personality_core"))
if personality_core:
parts.append(f"核心人设: {personality_core}")
if "personality" in persona_info:
parts.append(f"性格特点: {persona_info['personality']}")
personality_side = _stringify(persona_info.get("personality_side"))
if personality_side:
parts.append(f"侧面特质: {personality_side}")
identity = _stringify(persona_info.get("identity"))
if identity:
parts.append(f"身份特征: {identity}")
# 追加其他未覆盖字段(保持信息完整)
known_keys = {
"name",
"personality_core",
"personality_side",
"identity",
}
for key, value in persona_info.items():
if key in known_keys:
continue
value_str = _stringify(value)
if value_str:
parts.append(f"{key}: {value_str}")
return "\n".join(parts) if parts else "无特定人格设定"
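
For reference, the sampling strategy introduced above, restated as a standalone function over plain dicts (the real code goes through QueryBuilder and applies the prefetch limit in SQL; the defaults here are illustrative):

```python
import random
from typing import Any

HARD_MAX_SAMPLES = 2000  # mirrors the class constant above


def sample_recent_messages(
    rows: list[dict[str, Any]],  # assumed already ordered newest-first
    max_samples: int,
    min_length: int = 2,
) -> list[dict[str, Any]]:
    max_samples = min(max(int(max_samples), 0), HARD_MAX_SAMPLES)
    if max_samples == 0:
        return []
    prefetch = rows[: int(max_samples * 1.5)]  # over-fetch to survive length filtering
    filtered: list[dict[str, Any]] = []
    for msg in prefetch:
        text = (msg.get("processed_plain_text") or msg.get("display_message") or "").strip()
        if text and len(text) >= min_length:
            filtered.append({**msg, "message_text": text})
            if len(filtered) >= max_samples:
                break
    random.shuffle(filtered)  # avoid a pure newest-messages bias
    return filtered
```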


@@ -26,7 +26,7 @@ class TfidfFeatureExtractor:
def __init__(
self,
analyzer: str = "char", # type: ignore
ngram_range: tuple[int, int] = (2, 3), # 优化:缩小 n-gram 范围
ngram_range: tuple[int, int] = (2, 4), # 优化:n-gram 范围调整为 (2, 4)
max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半
min_df: int = 3, # 优化:过滤低频 n-gram
max_df: float = 0.95,
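
For comparison, the same feature settings expressed directly with scikit-learn's TfidfVectorizer (assuming the extractor above wraps it, which this hunk does not show):

```python
from sklearn.feature_extraction.text import TfidfVectorizer

# min_df / max_df only make sense on a real corpus; fitting on a couple of toy
# strings would prune every n-gram, so only the construction is shown here.
vectorizer = TfidfVectorizer(
    analyzer="char",     # character n-grams suit Chinese chat text
    ngram_range=(2, 4),  # the range this commit switches to
    max_features=10000,
    min_df=3,
    max_df=0.95,
)
```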


@@ -44,7 +44,6 @@ class SemanticInterestModel:
n_jobs: 并行任务数,-1 表示使用所有 CPU 核心
"""
self.clf = LogisticRegression(
multi_class="multinomial",
solver=solver,
max_iter=max_iter,
class_weight=class_weight,
@@ -206,7 +205,6 @@ class SemanticInterestModel:
"""
params = self.clf.get_params()
return {
"multi_class": params["multi_class"],
"solver": params["solver"],
"max_iter": params["max_iter"],
"class_weight": params["class_weight"],

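This hunk drops the explicit `multi_class="multinomial"` argument; recent scikit-learn releases deprecate that parameter, and solvers such as lbfgs already fit multi-class targets as a multinomial model, so the equivalent construction is simply (values illustrative):

```python
from sklearn.linear_model import LogisticRegression

# Equivalent to the removed multi_class="multinomial" on modern scikit-learn.
clf = LogisticRegression(
    solver="lbfgs",
    max_iter=1000,
    class_weight="balanced",
    n_jobs=-1,
)
```
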

@@ -558,7 +558,7 @@ class ModelManager:
trained, model_path = await self._auto_trainer.auto_train_if_needed(
persona_info=persona_info,
days=7,
max_samples=500,
max_samples=1000, # 初始训练使用1000条消息
)
if trained and model_path:
@@ -607,30 +607,32 @@ class ModelManager:
persona_info: 人设信息
interval_hours: 检查间隔(小时)
"""
# 检查是否已经启动
if self._auto_training_started:
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
return
try:
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
# 使用锁防止并发启动
async with self._lock:
# 检查是否已经启动
if self._auto_training_started:
logger.debug(f"[模型管理器] 自动训练任务已启动,跳过")
return
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
# 标记为已启动
self._auto_training_started = True
# 在后台任务中运行
asyncio.create_task(
self._auto_trainer.scheduled_train(persona_info, interval_hours)
)
except Exception as e:
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
self._auto_training_started = False # 失败时重置标志
try:
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
if self._auto_trainer is None:
self._auto_trainer = get_auto_trainer()
logger.info(f"[模型管理器] 启动自动训练任务,间隔: {interval_hours}小时")
# 标记为已启动
self._auto_training_started = True
# 在后台任务中运行
asyncio.create_task(
self._auto_trainer.scheduled_train(persona_info, interval_hours)
)
except Exception as e:
logger.error(f"[模型管理器] 启动自动训练失败: {e}")
self._auto_training_started = False # 失败时重置标志
# 单例获取函数


@@ -191,44 +191,3 @@ class SemanticInterestTrainer:
return dataset_path, model_path, metrics
async def main():
"""示例:训练一个语义兴趣度模型"""
# 示例人格信息
persona_info = {
"name": "小狐",
"interests": ["动漫", "游戏", "编程", "技术", "二次元"],
"dislikes": ["广告", "政治", "无聊闲聊"],
"personality": "活泼开朗,对新鲜事物充满好奇",
}
# 创建训练器
trainer = SemanticInterestTrainer()
# 执行完整训练流程
dataset_path, model_path, metrics = await trainer.full_training_pipeline(
persona_info=persona_info,
days=7, # 使用最近 7 天的消息
max_samples=500, # 采样 500 条消息
llm_model_name=None, # 使用默认 LLM
tfidf_config={
"analyzer": "char",
"ngram_range": (2, 4),
"max_features": 15000,
"min_df": 3,
},
model_config={
"class_weight": "balanced",
"max_iter": 1000,
},
)
print(f"\n训练完成!")
print(f"数据集: {dataset_path}")
print(f"模型: {model_path}")
print(f"准确率: {metrics.get('test_accuracy', 0):.4f}")
if __name__ == "__main__":
asyncio.run(main())


@@ -96,7 +96,7 @@ class ChineseTypoGenerator:
# 🔧 内存优化:复用全局缓存的拼音字典和字频数据
if _shared_pinyin_dict is None:
_shared_pinyin_dict = self._create_pinyin_dict()
_shared_pinyin_dict = self._load_or_create_pinyin_dict()
logger.debug("拼音字典已创建并缓存")
self.pinyin_dict = _shared_pinyin_dict
@@ -141,6 +141,35 @@ class ChineseTypoGenerator:
return normalized_freq
def _load_or_create_pinyin_dict(self):
"""
加载或创建拼音到汉字映射字典(磁盘缓存加速冷启动)
"""
cache_file = Path("depends-data/pinyin_dict.json")
if cache_file.exists():
try:
with open(cache_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
# 恢复为 defaultdict(list) 以兼容旧逻辑
restored = defaultdict(list)
for py, chars in data.items():
restored[py] = list(chars)
return restored
except Exception as e:
logger.warning(f"读取拼音缓存失败,将重新生成: {e}")
pinyin_dict = self._create_pinyin_dict()
try:
cache_file.parent.mkdir(parents=True, exist_ok=True)
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(dict(pinyin_dict), option=orjson.OPT_INDENT_2).decode("utf-8"))
except Exception as e:
logger.warning(f"写入拼音缓存失败(不影响使用): {e}")
return pinyin_dict
@staticmethod
def _create_pinyin_dict():
"""


@@ -1,6 +1,7 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制
import logging
import os
import tarfile
import threading
import time
@@ -189,6 +190,10 @@ class TimestampedFileHandler(logging.Handler):
self.backup_count = backup_count
self.encoding = encoding
self._lock = threading.Lock()
self._current_size = 0
self._bytes_since_check = 0
self._newline_bytes = len(os.linesep.encode(self.encoding or "utf-8"))
self._stat_refresh_threshold = max(self.max_bytes // 8, 256 * 1024)
# 当前活跃的日志文件
self.current_file = None
@@ -207,11 +212,29 @@ class TimestampedFileHandler(logging.Handler):
# 极低概率碰撞,稍作等待
time.sleep(0.001)
self.current_stream = open(self.current_file, "a", encoding=self.encoding)
self._current_size = self.current_file.stat().st_size if self.current_file.exists() else 0
self._bytes_since_check = 0
def _should_rollover(self):
"""检查是否需要轮转"""
if self.current_file and self.current_file.exists():
return self.current_file.stat().st_size >= self.max_bytes
def _should_rollover(self, incoming_size: int = 0) -> bool:
"""检查是否需要轮转使用内存缓存的大小信息减少磁盘stat次数。"""
if not self.current_file:
return False
projected = self._current_size + incoming_size
if projected >= self.max_bytes:
return True
self._bytes_since_check += incoming_size
if self._bytes_since_check >= self._stat_refresh_threshold:
try:
if self.current_file.exists():
self._current_size = self.current_file.stat().st_size
else:
self._current_size = 0
except OSError:
self._current_size = 0
finally:
self._bytes_since_check = 0
return False
def _do_rollover(self):
@@ -270,16 +293,17 @@ class TimestampedFileHandler(logging.Handler):
def emit(self, record):
"""发出日志记录"""
try:
message = self.format(record)
encoded_len = len(message.encode(self.encoding or "utf-8")) + self._newline_bytes
with self._lock:
# 检查是否需要轮转
if self._should_rollover():
if self._should_rollover(encoded_len):
self._do_rollover()
# 写入日志
if self.current_stream:
msg = self.format(record)
self.current_stream.write(msg + "\n")
self.current_stream.write(message + "\n")
self.current_stream.flush()
self._current_size += encoded_len
except Exception:
self.handleError(record)
@@ -837,10 +861,6 @@ DEFAULT_MODULE_ALIASES = {
}
# 创建全局 Rich Console 实例用于颜色渲染
_rich_console = Console(force_terminal=True, color_system="truecolor")
class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,使用 Rich 库原生支持 hex 颜色"""
@@ -848,6 +868,7 @@ class ModuleColoredConsoleRenderer:
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors
self._config = LOG_CONFIG
self._render_console = Console(force_terminal=True, color_system="truecolor", width=999)
# 日志级别颜色 (#RRGGBB 格式)
self._level_colors_hex = {
@@ -876,6 +897,22 @@ class ModuleColoredConsoleRenderer:
self._enable_level_colors = False
self._enable_full_content_colors = False
@staticmethod
def _looks_like_markup(content: str) -> bool:
"""快速判断内容里是否包含 Rich 标记,避免不必要的解析开销。"""
if not content:
return False
return "[" in content and "]" in content
def _render_content_text(self, content: str, *, style: str | None = None) -> Text:
"""只在必要时解析 Rich 标记降低CPU占用。"""
if self._looks_like_markup(content):
try:
return Text.from_markup(content, style=style)
except Exception:
return Text(content, style=style)
return Text(content, style=style)
def __call__(self, logger, method_name, event_dict):
# sourcery skip: merge-duplicate-blocks
"""渲染日志消息"""
@@ -966,9 +1003,9 @@ class ModuleColoredConsoleRenderer:
if prefix:
# 解析 prefix 中的 Rich 标记
if module_hex_color:
content_text.append(Text.from_markup(prefix, style=module_hex_color))
content_text.append(self._render_content_text(prefix, style=module_hex_color))
else:
content_text.append(Text.from_markup(prefix))
content_text.append(self._render_content_text(prefix))
# 与"内心思考"段落之间插入空行
if prefix:
@@ -983,24 +1020,12 @@ class ModuleColoredConsoleRenderer:
else:
# 使用 Text.from_markup 解析 Rich 标记语言
if module_hex_color:
try:
parts.append(Text.from_markup(event_content, style=module_hex_color))
except Exception:
# 如果标记解析失败,回退到普通文本
parts.append(Text(event_content, style=module_hex_color))
parts.append(self._render_content_text(event_content, style=module_hex_color))
else:
try:
parts.append(Text.from_markup(event_content))
except Exception:
# 如果标记解析失败,回退到普通文本
parts.append(Text(event_content))
parts.append(self._render_content_text(event_content))
else:
# 即使在非 full 模式下,也尝试解析 Rich 标记(但不应用颜色)
try:
parts.append(Text.from_markup(event_content))
except Exception:
# 如果标记解析失败,使用普通文本
parts.append(Text(event_content))
parts.append(self._render_content_text(event_content))
# 处理其他字段
extras = []
@@ -1029,12 +1054,10 @@ class ModuleColoredConsoleRenderer:
# 使用 Rich 拼接并返回字符串
result = Text(" ").join(parts)
# 将 Rich Text 对象转换为带 ANSI 颜色码的字符串
from io import StringIO
string_io = StringIO()
temp_console = Console(file=string_io, force_terminal=True, color_system="truecolor", width=999)
temp_console.print(result, end="")
return string_io.getvalue()
# 使用持久化 Console + capture 避免每条日志重复实例化
with self._render_console.capture() as capture:
self._render_console.print(result, end="")
return capture.get()
# 配置标准logging以支持文件输出和压缩


@@ -803,16 +803,8 @@ class AffinityFlowConfig(ValidatedConfigBase):
# 兴趣评分系统参数
reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
# 语义兴趣度评分优化参数2024.12 新增)
# 语义兴趣度评分优化参数
use_batch_scoring: bool = Field(default=False, description="是否启用批处理评分模式,适合高频群聊场景")
batch_size: int = Field(default=8, ge=1, le=64, description="批处理大小,达到后立即处理")
batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0, description="批处理刷新间隔(毫秒),超过后强制处理")
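
These fields lean on pydantic's range validation; a small sketch of how an out-of-range value would be rejected at config-load time (plain BaseModel is used here instead of the project's ValidatedConfigBase):

```python
from pydantic import BaseModel, Field, ValidationError


class BatchScoringConfig(BaseModel):
    use_batch_scoring: bool = False
    batch_size: int = Field(default=8, ge=1, le=64)
    batch_flush_interval_ms: float = Field(default=30.0, ge=10.0, le=200.0)


try:
    BatchScoringConfig(batch_size=128)  # rejected: above the le=64 bound
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```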


@@ -298,80 +298,105 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug("[语义评分] 未启用语义兴趣度评分")
return
try:
from src.chat.semantic_interest import get_semantic_scorer
from src.chat.semantic_interest.runtime_scorer import ModelManager
# 防止并发初始化(使用锁)
if not hasattr(self, '_init_lock'):
self._init_lock = asyncio.Lock()
async with self._init_lock:
# 双重检查
if self._semantic_initialized:
logger.debug("[语义评分] 评分器已在其他任务中初始化,跳过")
return
# 查找最新的模型文件
model_dir = Path("data/semantic_interest/models")
if not model_dir.exists():
logger.warning(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
model_dir.mkdir(parents=True, exist_ok=True)
# 使用模型管理器(支持人设感知)
self.model_manager = ModelManager(model_dir)
# 获取人设信息
persona_info = self._get_current_persona_info()
# 加载模型(自动选择合适的版本,使用单例 + FastScorer
try:
scorer = await self.model_manager.load_model(
version="auto", # 自动选择或训练
persona_info=persona_info
)
self.semantic_scorer = scorer
from src.chat.semantic_interest import get_semantic_scorer
from src.chat.semantic_interest.runtime_scorer import ModelManager
# 查找最新的模型文件
model_dir = Path("data/semantic_interest/models")
if not model_dir.exists():
logger.info(f"[语义评分] 模型目录不存在,已创建: {model_dir}")
model_dir.mkdir(parents=True, exist_ok=True)
# 使用模型管理器(支持人设感知)
if self.model_manager is None:
self.model_manager = ModelManager(model_dir)
logger.debug("[语义评分] 模型管理器已创建")
# 获取人设信息
persona_info = self._get_current_persona_info()
# 如果启用批处理队列模式
if self._use_batch_queue:
from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
# 确保 scorer 有 FastScorer
if scorer._fast_scorer is not None:
self._batch_queue = BatchScoringQueue(
scorer=scorer._fast_scorer,
batch_size=self._batch_size,
flush_interval_ms=self._batch_flush_interval_ms
)
await self._batch_queue.start()
logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
logger.info("[语义评分] 语义兴趣度评分器初始化成功FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
# 启动自动训练任务每24小时检查一次
await self.model_manager.start_auto_training(
persona_info=persona_info,
interval_hours=24
)
except FileNotFoundError:
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
# 触发首次训练
# 先检查是否已有可用模型
from src.chat.semantic_interest.auto_trainer import get_auto_trainer
auto_trainer = get_auto_trainer()
trained, model_path = await auto_trainer.auto_train_if_needed(
persona_info=persona_info,
force=True # 强制训练
)
if trained and model_path:
# 使用单例获取评分器(默认启用 FastScorer
self.semantic_scorer = await get_semantic_scorer(model_path)
logger.info("[语义评分] 首次训练完成模型已加载FastScorer优化 + 单例)")
existing_model = auto_trainer.get_model_for_persona(persona_info)
# 加载模型(自动选择合适的版本,使用单例 + FastScorer
try:
if existing_model and existing_model.exists():
# 直接加载已有模型
logger.info(f"[语义评分] 使用已有模型: {existing_model.name}")
scorer = await get_semantic_scorer(existing_model, use_async=True)
else:
# 使用 ModelManager 自动选择或训练
scorer = await self.model_manager.load_model(
version="auto", # 自动选择或训练
persona_info=persona_info
)
self.semantic_scorer = scorer
# 如果启用批处理队列模式
if self._use_batch_queue:
from src.chat.semantic_interest.optimized_scorer import BatchScoringQueue
# 确保 scorer 有 FastScorer
if scorer._fast_scorer is not None:
self._batch_queue = BatchScoringQueue(
scorer=scorer._fast_scorer,
batch_size=self._batch_size,
flush_interval_ms=self._batch_flush_interval_ms
)
await self._batch_queue.start()
logger.info(f"[语义评分] 批处理队列已启动 (batch_size={self._batch_size}, interval={self._batch_flush_interval_ms}ms)")
logger.info("[语义评分] 语义兴趣度评分器初始化成功FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
else:
logger.error("[语义评分] 首次训练失败")
self.use_semantic_scoring = False
# 启动自动训练任务每24小时检查一次- 只在没有模型时或明确需要时启动
if not existing_model or not existing_model.exists():
await self.model_manager.start_auto_training(
persona_info=persona_info,
interval_hours=24
)
else:
logger.debug("[语义评分] 已有模型,跳过自动训练启动")
except FileNotFoundError:
logger.warning(f"[语义评分] 未找到训练模型,将自动训练...")
# 触发首次训练
trained, model_path = await auto_trainer.auto_train_if_needed(
persona_info=persona_info,
force=True # 强制训练
)
if trained and model_path:
# 使用单例获取评分器(默认启用 FastScorer
self.semantic_scorer = await get_semantic_scorer(model_path)
logger.info("[语义评分] 首次训练完成模型已加载FastScorer优化 + 单例)")
# 设置初始化标志
self._semantic_initialized = True
else:
logger.error("[语义评分] 首次训练失败")
self.use_semantic_scoring = False
except ImportError:
logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
self.use_semantic_scoring = False
except Exception as e:
logger.error(f"[语义评分] 初始化失败: {e}")
self.use_semantic_scoring = False
except ImportError:
logger.warning("[语义评分] 无法导入语义兴趣度模块,将禁用语义评分")
self.use_semantic_scoring = False
except Exception as e:
logger.error(f"[语义评分] 初始化失败: {e}")
self.use_semantic_scoring = False
def _get_current_persona_info(self) -> dict[str, Any]:
"""获取当前人设信息
@@ -539,3 +564,5 @@ class AffinityInterestCalculator(BaseInterestCalculator):
logger.debug(
f"[回复后机制] 未回复消息,剩余降低次数: {self.post_reply_boost_remaining}"
)
afc_interest_calculator = AffinityInterestCalculator()
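
Putting the pieces together, a hedged usage sketch of the batch scoring path; the API names come from the removed benchmark script earlier in this diff and the queue wiring above, and the model path is hypothetical:

```python
import asyncio


async def score_burst(texts: list[str]) -> list[float]:
    from src.chat.semantic_interest.optimized_scorer import get_fast_scorer

    # Many concurrent score() calls are coalesced into a few batched model runs.
    queue = await get_fast_scorer(
        "data/semantic_interest/models/semantic_interest_latest.pkl",  # hypothetical path
        use_batch_queue=True,
        batch_size=8,
        flush_interval_ms=30.0,
    )
    try:
        return await asyncio.gather(*(queue.score(text) for text in texts))
    finally:
        await queue.stop()
```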


@@ -174,10 +174,10 @@ class ChatterActionPlanner:
try:
from src.plugins.built_in.affinity_flow_chatter.core.affinity_interest_calculator import (
AffinityInterestCalculator,
afc_interest_calculator,
)
calculator = AffinityInterestCalculator()
calculator = afc_interest_calculator
if not await calculator.initialize():
logger.warning("AffinityInterestCalculator 初始化失败")
return None


@@ -46,14 +46,6 @@ class AffinityChatterPlugin(BasePlugin):
except Exception as e:
logger.error(f"加载 AffinityChatter 时出错: {e}")
try:
# 延迟导入 AffinityInterestCalculator从 core 子模块)
from .core.affinity_interest_calculator import AffinityInterestCalculator
components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
except Exception as e:
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
try:
# 延迟导入 UserProfileTool从 tools 子模块)
from .tools.user_profile_tool import UserProfileTool


@@ -539,14 +539,11 @@ enable_normal_mode = true # 是否启用 Normal 聊天模式。启用后,在
# 兴趣评分系统参数
reply_action_interest_threshold = 0.75 # 回复动作兴趣阈值
non_reply_action_interest_threshold = 0.65 # 非回复动作兴趣阈值
high_match_interest_threshold = 0.6 # 高匹配兴趣阈值
medium_match_interest_threshold = 0.4 # 中匹配兴趣阈值
low_match_interest_threshold = 0.2 # 低匹配兴趣阈值
high_match_keyword_multiplier = 4 # 高匹配关键词兴趣倍率
medium_match_keyword_multiplier = 2.5 # 中匹配关键词兴趣倍率
low_match_keyword_multiplier = 1.15 # 低匹配关键词兴趣倍率
match_count_bonus = 0.01 # 匹配数关键词加成值
max_match_bonus = 0.1 # 最大匹配数加成值
# 语义兴趣度评分优化参数
use_batch_scoring = true # 是否启用批处理评分模式,适合高频群聊场景
batch_size = 3 # 批处理大小,达到后立即处理
batch_flush_interval_ms = 30.0 # 批处理刷新间隔(毫秒),超过后强制处理
# 回复决策系统参数
no_reply_threshold_adjustment = 0.02 # 不回复兴趣阈值调整值