Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
@@ -4,7 +4,6 @@ import binascii
 import hashlib
 import io
 import json
-import json_repair
 import os
 import random
 import re
@@ -12,6 +11,7 @@ import time
 import traceback
 from typing import Any, Optional, cast
 
+import json_repair
 from PIL import Image
 from rich.traceback import install
 from sqlalchemy import select
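Note: the two hunks above are the two halves of a single isort-style fix: json_repair is a third-party package, so its import moves out of the standard-library block and into the third-party block. A minimal sketch of the grouping convention being enforced (the specific imports are illustrative):

    # Standard-library imports come first...
    import json
    import os

    # ...then a blank line, then third-party packages.
    import json_repair
    from PIL import Image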
@@ -3,7 +3,7 @@ import re
 import time
 import traceback
 from collections import deque
-from typing import TYPE_CHECKING, Optional, Any, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 
 import orjson
 from sqlalchemy import desc, insert, select, update
@@ -1799,7 +1799,7 @@ class DefaultReplyer:
         )
 
         if content:
-            if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != 'llm':
+            if not global_config.response_splitter.enable or global_config.response_splitter.split_mode != "llm":
                 # Strip the [SPLIT] marker so the message is not split
                 content = content.replace("[SPLIT]", "")
 
@@ -10,9 +10,8 @@ from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Any
 
-from src.common.logger import get_logger
-from src.config.config import global_config
 from src.chat.semantic_interest.trainer import SemanticInterestTrainer
+from src.common.logger import get_logger
 
 logger = get_logger("semantic_interest.auto_trainer")
 
@@ -78,7 +77,7 @@ class AutoTrainer:
         """Load the cached persona state."""
         if self.persona_cache_file.exists():
             try:
-                with open(self.persona_cache_file, "r", encoding="utf-8") as f:
+                with open(self.persona_cache_file, encoding="utf-8") as f:
                     cache = json.load(f)
                 self.last_persona_hash = cache.get("persona_hash")
                 last_train_str = cache.get("last_train_time")
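Note: "r" is already the default mode of open(), so ruff (rule UP015) drops it; an explicit encoding is all the call needs. A minimal sketch of the same cleanup (the file name is illustrative):

    import json

    # Before: open(path, "r", encoding="utf-8") - the "r" is redundant.
    with open("persona_cache.json", encoding="utf-8") as f:
        cache = json.load(f)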
@@ -142,7 +141,7 @@ class AutoTrainer:
             return True
 
         if current_hash != self.last_persona_hash:
-            logger.info(f"[AutoTrainer] persona change detected")
+            logger.info("[AutoTrainer] persona change detected")
            logger.info(f"  - old hash: {self.last_persona_hash[:8]}")
            logger.info(f"  - new hash: {current_hash[:8]}")
            return True
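Note: many hunks in this commit make the same mechanical fix: dropping the f prefix from f-strings that contain no placeholders (ruff rule F541), since such strings are just plain literals. A minimal sketch of the distinction:

    name = "latest"
    static_msg = "training finished"       # no interpolation: plain string
    dynamic_msg = f"updated {name} model"  # interpolation: f-string needed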
@@ -236,7 +235,7 @@ class AutoTrainer:
         # Create the "latest" symlink
         self._create_latest_link(model_path)
 
-        logger.info(f"[AutoTrainer] training finished!")
+        logger.info("[AutoTrainer] training finished!")
         logger.info(f"  - model: {model_path.name}")
         logger.info(f"  - accuracy: {metrics.get('test_accuracy', 0):.4f}")
 
@@ -265,7 +264,7 @@ class AutoTrainer:
             import shutil
             shutil.copy2(model_path, latest_path)
 
-            logger.info(f"[AutoTrainer] latest model updated")
+            logger.info("[AutoTrainer] latest model updated")
 
         except Exception as e:
             logger.warning(f"[AutoTrainer] failed to create latest link: {e}")
@@ -283,7 +282,7 @@ class AutoTrainer:
         """
         # Check whether a task is already running
         if self._scheduled_task_running:
-            logger.info(f"[AutoTrainer] scheduled task already running, skipping duplicate start")
+            logger.info("[AutoTrainer] scheduled task already running, skipping duplicate start")
             return
 
         self._scheduled_task_running = True
@@ -330,10 +329,10 @@ class AutoTrainer:
         # Nothing matched; fall back to the latest model
         latest_path = self.model_dir / "semantic_interest_latest.pkl"
         if latest_path.exists():
-            logger.debug(f"[AutoTrainer] using latest model")
+            logger.debug("[AutoTrainer] using latest model")
             return latest_path
 
-        logger.warning(f"[AutoTrainer] no usable model found")
+        logger.warning("[AutoTrainer] no usable model found")
         return None
 
     def cleanup_old_models(self, keep_count: int = 5):
@@ -3,7 +3,6 @@
 Sample messages from the database and annotate interest scores with an LLM
 """
 
-import asyncio
 import json
 import random
 from datetime import datetime, timedelta
@@ -11,7 +10,6 @@ from pathlib import Path
 from typing import Any
 
 from src.common.logger import get_logger
-from src.config.config import global_config
 
 logger = get_logger("semantic_interest.dataset")
 
@@ -111,16 +109,16 @@ class DatasetGenerator:
     async def initialize(self):
         """Initialize the LLM client."""
         try:
-            from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest
 
             # Use the utilities model config (annotation is more of a tool-style task)
-            if hasattr(model_config.model_task_config, 'utils'):
+            if hasattr(model_config.model_task_config, "utils"):
                 self.model_client = LLMRequest(
                     model_set=model_config.model_task_config.utils,
                     request_type="semantic_annotation"
                 )
-                logger.info(f"Dataset generator initialized, using the utils model")
+                logger.info("Dataset generator initialized, using the utils model")
             else:
                 logger.error("utils model config not found")
                 self.model_client = None
@@ -149,9 +147,9 @@ class DatasetGenerator:
         Returns:
             List of message samples
         """
 
         from src.common.database.api.query import QueryBuilder
         from src.common.database.core.models import Messages
-        from sqlalchemy import func, or_
 
         logger.info(f"Sampling messages; window: last {days} days, target count: {max_samples}")
 
@@ -632,7 +630,7 @@ class DatasetGenerator:
 
         # Extract the JSON content
         import re
-        json_match = re.search(r'```json\s*({.*?})\s*```', response, re.DOTALL)
+        json_match = re.search(r"```json\s*({.*?})\s*```", response, re.DOTALL)
         if json_match:
             json_str = json_match.group(1)
         else:
@@ -703,7 +701,7 @@ class DatasetGenerator:
         Returns:
             (list of texts, list of labels)
         """
-        with open(path, "r", encoding="utf-8") as f:
+        with open(path, encoding="utf-8") as f:
             data = json.load(f)
 
         texts = [item["message_text"] for item in data]
@@ -3,7 +3,6 @@
 Extract TF-IDF features from Chinese messages using character-level n-grams
 """
 
-from pathlib import Path
 
 from sklearn.feature_extraction.text import TfidfVectorizer
 
@@ -4,17 +4,15 @@
 """
 
 import time
-from pathlib import Path
 from typing import Any
 
-import joblib
 import numpy as np
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.model_selection import train_test_split
 
-from src.common.logger import get_logger
 from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
+from src.common.logger import get_logger
 
 logger = get_logger("semantic_interest.model")
 
@@ -16,7 +16,7 @@ from collections import Counter
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import numpy as np
 
@@ -110,7 +110,7 @@ class FastScorer:
         self.total_time = 0.0
 
         # n-gram regex (precompiled)
-        self._tokenize_pattern = re.compile(r'\s+')
+        self._tokenize_pattern = re.compile(r"\s+")
 
     @classmethod
     def from_sklearn_model(
@@ -139,13 +139,13 @@ class FastScorer:
         Merge the TF-IDF idf values and the LR weights into a single token→weight dict
         """
         # Grab the underlying sklearn objects
-        if hasattr(vectorizer, 'vectorizer'):
+        if hasattr(vectorizer, "vectorizer"):
             # TfidfFeatureExtractor wrapper class
             tfidf = vectorizer.vectorizer
         else:
             tfidf = vectorizer
 
-        if hasattr(model, 'clf'):
+        if hasattr(model, "clf"):
             # SemanticInterestModel wrapper class
             clf = model.clf
         else:
@@ -611,7 +611,7 @@ def convert_sklearn_to_fast(
 
     # Infer the n-gram range from the vectorizer config
     if config is None:
-        vconfig = vectorizer.get_config() if hasattr(vectorizer, 'get_config') else {}
+        vconfig = vectorizer.get_config() if hasattr(vectorizer, "get_config") else {}
        config = FastScorerConfig(
            ngram_range=vconfig.get("ngram_range", (2, 4)),
            weight_prune_threshold=1e-4,
@@ -16,11 +16,10 @@ from pathlib import Path
 from typing import Any
 
 import joblib
-import numpy as np
 
-from src.common.logger import get_logger
 from src.chat.semantic_interest.features_tfidf import TfidfFeatureExtractor
 from src.chat.semantic_interest.model_lr import SemanticInterestModel
+from src.common.logger import get_logger
 
 logger = get_logger("semantic_interest.scorer")
 
@@ -611,7 +610,7 @@ class ModelManager:
         async with self._lock:
             # Check whether it has already been started
             if self._auto_training_started:
-                logger.debug(f"[ModelManager] auto-training task already started, skipping")
+                logger.debug("[ModelManager] auto-training task already started, skipping")
                 return
 
             try:
@@ -3,16 +3,15 @@
 Unified entry point for the training pipeline: data sampling, annotation, training, evaluation
 """
 
-import asyncio
 from datetime import datetime
 from pathlib import Path
 from typing import Any
 
 import joblib
 
-from src.common.logger import get_logger
 from src.chat.semantic_interest.dataset import DatasetGenerator, generate_training_dataset
 from src.chat.semantic_interest.model_lr import train_semantic_model
+from src.common.logger import get_logger
 
 logger = get_logger("semantic_interest.trainer")
 
@@ -110,7 +109,6 @@ class SemanticInterestTrainer:
         logger.info(f"Starting model training, dataset: {dataset_path}")
 
         # Load the dataset
-        from src.chat.semantic_interest.dataset import DatasetGenerator
         texts, labels = DatasetGenerator.load_dataset(dataset_path)
 
         # Train the model
@@ -13,7 +13,7 @@ from src.common.data_models.database_data_model import DatabaseUserInfo
 
 # MessageRecv has been removed; DatabaseMessages is used instead
 from src.common.logger import get_logger
-from src.common.message_repository import count_and_length_messages, count_messages, find_messages
+from src.common.message_repository import count_and_length_messages, find_messages
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
@@ -10,6 +10,7 @@ from typing import Any
 import numpy as np
 
 from src.config.config import model_config
 
 from . import BaseDataModel
 
+
@@ -9,11 +9,10 @@
 
 import asyncio
 import time
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from typing import Any
-from collections import OrderedDict
 
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,7 +29,6 @@ from enum import Enum
 from typing import Any, ClassVar, Literal
 
 import numpy as np
-
 from rich.traceback import install
 
 from src.common.logger import get_logger
@@ -7,7 +7,7 @@ import time
 import traceback
 from collections.abc import Callable, Coroutine
 from random import choices
-from typing import Any, cast
+from typing import Any
 
 from rich.traceback import install
 
@@ -11,11 +11,10 @@ import asyncio
 import json
 import re
 import uuid
-import json_repair
 from pathlib import Path
 from typing import Any
-from collections import defaultdict
 
+import json_repair
 import numpy as np
 
 from src.common.logger import get_logger
@@ -182,7 +182,7 @@ class RelationshipFetcher:
             kw_lower = kw.lower()
             # Exclude chat-interaction and emotional-need words that are not real interests
             if not any(excluded in kw_lower for excluded in [
-                '亲亲', '撒娇', '被宠', '被夸', '聊天', '互动', '关心', '专注', '需要'
+                "亲亲", "撒娇", "被宠", "被夸", "聊天", "互动", "关心", "专注", "需要"
             ]):
                 filtered_keywords.append(kw)
 
@@ -11,7 +11,6 @@ from inspect import iscoroutinefunction
 from src.chat.message_receive.chat_stream import ChatStream
 from src.plugin_system.apis.logging_api import get_logger
 from src.plugin_system.apis.permission_api import permission_api
-from src.plugin_system.apis.send_api import text_to_stream
 
 logger = get_logger(__name__)
 
@@ -292,7 +292,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             return
 
         # Guard against concurrent initialization (with a lock)
-        if not hasattr(self, '_init_lock'):
+        if not hasattr(self, "_init_lock"):
             self._init_lock = asyncio.Lock()
 
         async with self._init_lock:
@@ -354,7 +354,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
                 logger.debug("[SemanticScoring] model already present, skipping auto-training startup")
 
         except FileNotFoundError:
-            logger.warning(f"[SemanticScoring] no trained model found, training automatically...")
+            logger.warning("[SemanticScoring] no trained model found, training automatically...")
             # Trigger the first training run
             trained, model_path = await auto_trainer.auto_train_if_needed(
                 persona_info=persona_info,
@@ -464,7 +464,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         logger.info("[SemanticScoring] reloading model...")
 
         # Check whether the persona has changed
-        if hasattr(self, 'model_manager') and self.model_manager:
+        if hasattr(self, "model_manager") and self.model_manager:
             persona_info = self._get_current_persona_info()
             reloaded = await self.model_manager.check_and_reload_for_persona(persona_info)
             if reloaded:
@@ -117,7 +117,7 @@ def build_custom_decision_module() -> str:
    custom_prompt = getattr(kfc_config, "custom_decision_prompt", "")
 
    # Debug output
-    logger.debug(f"[CustomDecisionPrompt] raw value: {repr(custom_prompt)}, type: {type(custom_prompt)}")
+    logger.debug(f"[CustomDecisionPrompt] raw value: {custom_prompt!r}, type: {type(custom_prompt)}")
 
    if not custom_prompt or not custom_prompt.strip():
        logger.debug("[CustomDecisionPrompt] empty or whitespace only, skipping")
@@ -2,21 +2,28 @@
 
 from __future__ import annotations
 
+import asyncio
 import base64
 import time
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
-from mofox_wire import (
-    MessageBuilder,
-    SegPayload,
-)
+import orjson
+from mofox_wire import MessageBuilder, SegPayload
 
 from src.common.logger import get_logger
 from src.plugin_system.apis import config_api
 
 from ...event_models import ACCEPT_FORMAT, QQ_FACE, RealMessageType
-from ..utils import *
+from ..utils import (
+    get_forward_message,
+    get_group_info,
+    get_image_base64,
+    get_member_info,
+    get_message_detail,
+    get_record_detail,
+    get_self_info,
+)
 
 if TYPE_CHECKING:
     from ....plugin import NapcatAdapter
@@ -300,8 +307,7 @@ class MessageHandler:
         try:
             if file_path and Path(file_path).exists():
                 # Handle a local file
-                with open(file_path, "rb") as f:
-                    video_data = f.read()
+                video_data = await asyncio.to_thread(Path(file_path).read_bytes)
                 video_base64 = base64.b64encode(video_data).decode("utf-8")
                 logger.debug(f"Video file size: {len(video_data) / (1024 * 1024):.2f} MB")
 
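Note: this hunk replaces a blocking open()/read() pair with asyncio.to_thread, which runs the read in a worker thread so a large video file cannot stall the event loop. A minimal sketch of the pattern (the function name is illustrative):

    import asyncio
    import base64
    from pathlib import Path

    async def read_file_base64(path: str) -> str:
        # Path.read_bytes blocks; running it via asyncio.to_thread lets
        # the event loop keep servicing other coroutines meanwhile.
        data = await asyncio.to_thread(Path(path).read_bytes)
        return base64.b64encode(data).decode("utf-8")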
@@ -22,6 +22,7 @@ class MetaEventHandler:
         self.adapter = adapter
         self.plugin_config: dict[str, Any] | None = None
         self._interval_checking = False
+        self._heartbeat_task: asyncio.Task | None = None
 
     def set_plugin_config(self, config: dict[str, Any]) -> None:
         """Set the plugin config."""
@@ -41,7 +42,7 @@ class MetaEventHandler:
         self_id = raw.get("self_id")
         if not self._interval_checking and self_id:
             # Start the heartbeat check only when the first heartbeat packet arrives
-            asyncio.create_task(self.check_heartbeat(self_id))
+            self._heartbeat_task = asyncio.create_task(self.check_heartbeat(self_id))
         self.last_heart_beat = time.time()
         interval = raw.get("interval")
         if interval:
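Note: together with the new _heartbeat_task attribute above, this hunk fixes a subtle asyncio pitfall (flagged by ruff as RUF006): the event loop holds only weak references to tasks, so a task created with create_task and never stored can be garbage-collected while still running. A minimal sketch of the pattern (the class is illustrative):

    import asyncio

    class HeartbeatWatcher:
        def __init__(self) -> None:
            self._task: asyncio.Task | None = None

        def on_first_heartbeat(self) -> None:
            # Storing the task keeps a strong reference (preventing
            # mid-flight garbage collection) and allows cancellation later.
            self._task = asyncio.create_task(self._watch())

        async def _watch(self) -> None:
            while True:
                await asyncio.sleep(5)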
@@ -7,6 +7,7 @@ import asyncio
 import base64
 import hashlib
 from pathlib import Path
+from typing import ClassVar
 
 import aiohttp
 import toml
@@ -139,25 +140,34 @@ class SiliconFlowIndexTTSAction(BaseAction):
     action_description = "使用SiliconFlow API进行高质量的IndexTTS语音合成,支持零样本语音克隆"
 
     # Keyword configuration
-    activation_keywords = ["克隆语音", "模仿声音", "语音合成", "indextts", "声音克隆", "语音生成", "仿声", "变声"]
+    activation_keywords: ClassVar[list[str]] = [
+        "克隆语音",
+        "模仿声音",
+        "语音合成",
+        "indextts",
+        "声音克隆",
+        "语音生成",
+        "仿声",
+        "变声",
+    ]
     keyword_case_sensitive = False
 
     # Action parameter definitions
-    action_parameters = {
+    action_parameters: ClassVar[dict[str, str]] = {
         "text": "需要合成语音的文本内容,必填,应当清晰流畅",
-        "speed": "语速(可选),范围0.1-3.0,默认1.0"
+        "speed": "语速(可选),范围0.1-3.0,默认1.0",
     }
 
     # Action usage scenarios
-    action_require = [
+    action_require: ClassVar[list[str]] = [
         "当用户要求语音克隆或模仿某个声音时使用",
         "当用户明确要求进行语音合成时使用",
         "当需要高质量语音输出时使用",
-        "当用户要求变声或仿声时使用"
+        "当用户要求变声或仿声时使用",
     ]
 
     # Associated types - voice messages supported
-    associated_types = ["voice"]
+    associated_types: ClassVar[list[str]] = ["voice"]
 
     async def execute(self) -> tuple[bool, str]:
         """Run SiliconFlow IndexTTS speech synthesis."""
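Note: the ClassVar annotations added throughout this plugin address ruff rule RUF012: mutable class attributes should be marked as class-level, both so type checkers treat them as shared constants rather than per-instance fields and to make explicit that mutating them affects every instance. A minimal sketch:

    from typing import ClassVar

    class Action:
        # Shared class-level constant: one list for all instances.
        keywords: ClassVar[list[str]] = ["tts", "voice"]

        # Per-instance attribute: annotated here, assigned in __init__.
        text: str

        def __init__(self, text: str) -> None:
            self.text = text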
@@ -258,11 +268,11 @@ class SiliconFlowTTSCommand(BaseCommand):
 
     command_name = "sf_tts"
     command_description = "使用SiliconFlow IndexTTS进行语音合成"
-    command_aliases = ["sftts", "sf语音", "硅基语音"]
+    command_aliases: ClassVar[list[str]] = ["sftts", "sf语音", "硅基语音"]
 
-    command_parameters = {
+    command_parameters: ClassVar[dict[str, dict[str, object]]] = {
         "text": {"type": str, "required": True, "description": "要合成的文本"},
-        "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"}
+        "speed": {"type": float, "required": False, "description": "语速 (0.1-3.0)"},
     }
 
     async def execute(self, text: str, speed: float = 1.0) -> tuple[bool, str]:
@@ -341,14 +351,14 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
 
     # Required abstract attributes
     enable_plugin: bool = True
-    dependencies: list[str] = []
+    dependencies: ClassVar[list[str]] = []
     config_file_name: str = "config.toml"
 
     # Python dependencies
-    python_dependencies = ["aiohttp>=3.8.0"]
+    python_dependencies: ClassVar[list[str]] = ["aiohttp>=3.8.0"]
 
     # Config descriptions
-    config_section_descriptions = {
+    config_section_descriptions: ClassVar[dict[str, str]] = {
         "plugin": "插件基本配置",
         "components": "组件启用配置",
         "api": "SiliconFlow API配置",
@@ -356,7 +366,7 @@ class SiliconFlowIndexTTSPlugin(BasePlugin):
     }
 
     # Config schema
-    config_schema = {
+    config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
         "plugin": {
             "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
             "config_version": ConfigField(type=str, default="2.0.0", description="配置文件版本"),
@@ -43,8 +43,7 @@ class VoiceUploader:
             raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
 
         # Read the audio file and convert it to base64
-        with open(audio_path, "rb") as f:
-            audio_data = f.read()
+        audio_data = await asyncio.to_thread(audio_path.read_bytes)
 
         audio_base64 = base64.b64encode(audio_data).decode("utf-8")
 
@@ -347,8 +347,10 @@ class SystemCommand(PlusCommand):
             return
 
         response_parts = [f"🧩 已注册的提示词组件 (共 {len(components)} 个):"]
-        for comp in components:
-            response_parts.append(f"• `{comp.name}` (来自: `{comp.plugin_name}`)")
+        response_parts.extend(
+            [f"• `{comp.name}` (来自: `{comp.plugin_name}`)" for comp in components]
+        )
 
         await self._send_long_message("\n".join(response_parts))
 
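Note: this hunk and the two that follow swap a per-item append loop for a single extend with a list comprehension (ruff rule PERF401), which pushes the loop into the C-level list machinery. A minimal equivalence sketch (the sample data is illustrative):

    components = [("alpha", "core"), ("beta", "extras")]

    parts: list[str] = []
    for name, plugin in components:          # before: one append per iteration
        parts.append(f"- {name} ({plugin})")

    parts2 = [f"- {name} ({plugin})" for name, plugin in components]  # after
    assert parts == parts2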
@@ -586,8 +588,10 @@ class SystemCommand(PlusCommand):
 
         for plugin_name, comps in by_plugin.items():
             response_parts.append(f"🔌 **{plugin_name}**:")
-            for comp in comps:
-                response_parts.append(f"  ❌ `{comp.name}` ({comp.component_type.value})")
+            response_parts.extend(
+                [f"  ❌ `{comp.name}` ({comp.component_type.value})" for comp in comps]
+            )
 
         await self._send_long_message("\n".join(response_parts))
 
@@ -121,13 +121,17 @@ class SerperSearchEngine(BaseSearchEngine):
 
         # Add organic search results
         if "organic" in data:
-            for result in data["organic"][:num_results]:
-                results.append({
+            results.extend(
+                [
+                    {
                     "title": result.get("title", "无标题"),
                     "url": result.get("link", ""),
                     "snippet": result.get("snippet", ""),
                     "provider": "Serper",
-                })
+                    }
+                    for result in data["organic"][:num_results]
+                ]
+            )
 
         logger.info(f"Serper search succeeded: query='{query}', result count={len(results)}")
         return results
@@ -4,6 +4,8 @@ Web Search Tool Plugin
 A powerful web-search and URL-parsing plugin supporting multiple search engines and parsing strategies.
 """
 
+from typing import ClassVar
+
 from src.common.logger import get_logger
 from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
 from src.plugin_system.apis import config_api
@@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
     # Basic plugin info
     plugin_name: str = "web_search_tool"  # internal identifier
     enable_plugin: bool = True
-    dependencies: list[str] = []  # plugin dependency list
+    dependencies: ClassVar[list[str]] = []  # plugin dependency list
 
     def __init__(self, *args, **kwargs):
         """Initialize the plugin and eagerly load all search engines."""
@@ -80,11 +82,14 @@ class WEBSEARCHPLUGIN(BasePlugin):
     config_file_name: str = "config.toml"  # config file name
 
     # Config section descriptions
-    config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
+    config_section_descriptions: ClassVar[dict[str, str]] = {
+        "plugin": "插件基本信息",
+        "proxy": "链接本地解析代理配置",
+    }
 
     # Config schema definition
     # Note: the EXA config and component settings have moved to the [exa] and [web_search] sections of the main config file (bot_config.toml)
-    config_schema: dict = {
+    config_schema: ClassVar[dict[str, dict[str, ConfigField]]] = {
         "plugin": {
             "name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"),
             "version": ConfigField(type=str, default="1.0.0", description="插件版本"),