Files
Mofox-Core/src/chat/semantic_interest/features_tfidf.py
LuiKlee 8f77465bc3 ruff
2025-12-13 16:39:25 +08:00

147 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""TF-IDF 特征向量化器
使用字符级 n-gram 提取中文消息的 TF-IDF 特征
"""
from sklearn.feature_extraction.text import TfidfVectorizer
from src.common.logger import get_logger
logger = get_logger("semantic_interest.features")
class TfidfFeatureExtractor:
"""TF-IDF 特征提取器
使用字符级 n-gram 策略,适合中文/多语言场景
优化说明2024.12
- max_features 从 20000 降到 10000减少计算量
- ngram_range 默认 (2, 3),对于兴趣任务足够
- min_df 提高到 3过滤低频噪声
"""
def __init__(
self,
analyzer: str = "char", # type: ignore
ngram_range: tuple[int, int] = (2, 4), # 优化:缩小 n-gram 范围
max_features: int = 10000, # 优化:减少特征数量,矩阵大小和 dot product 减半
min_df: int = 3, # 优化:过滤低频 n-gram
max_df: float = 0.95,
):
"""初始化特征提取器
Args:
analyzer: 分析器类型 ('char''word')
ngram_range: n-gram 范围,例如 (2, 4) 表示 2~4 字符的 n-gram
max_features: 词表最大大小,防止特征爆炸
min_df: 最小文档频率,至少出现在 N 个样本中才纳入词表
max_df: 最大文档频率,出现频率超过此比例的词将被过滤(如停用词)
"""
self.vectorizer = TfidfVectorizer(
analyzer=analyzer,
ngram_range=ngram_range,
max_features=max_features,
min_df=min_df,
max_df=max_df,
lowercase=True,
strip_accents=None, # 保留中文字符
sublinear_tf=True, # 使用对数 TF 缩放
norm="l2", # L2 归一化
)
self.is_fitted = False
logger.info(
f"TF-IDF 特征提取器初始化: analyzer={analyzer}, "
f"ngram_range={ngram_range}, max_features={max_features}"
)
def fit(self, texts: list[str]) -> "TfidfFeatureExtractor":
"""训练向量化器
Args:
texts: 训练文本列表
Returns:
self
"""
logger.info(f"开始训练 TF-IDF 向量化器,样本数: {len(texts)}")
self.vectorizer.fit(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化器训练完成,词表大小: {vocab_size}")
return self
def transform(self, texts: list[str]):
"""将文本转换为 TF-IDF 向量
Args:
texts: 待转换文本列表
Returns:
稀疏矩阵
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练,请先调用 fit() 方法")
return self.vectorizer.transform(texts)
def fit_transform(self, texts: list[str]):
"""训练并转换文本
Args:
texts: 训练文本列表
Returns:
稀疏矩阵
"""
logger.info(f"开始训练并转换 TF-IDF 向量,样本数: {len(texts)}")
result = self.vectorizer.fit_transform(texts)
self.is_fitted = True
vocab_size = len(self.vectorizer.vocabulary_)
logger.info(f"TF-IDF 向量化完成,词表大小: {vocab_size}")
return result
def get_feature_names(self) -> list[str]:
"""获取特征名称列表
Returns:
特征名称列表
"""
if not self.is_fitted:
raise ValueError("向量化器尚未训练")
return self.vectorizer.get_feature_names_out().tolist()
def get_vocabulary_size(self) -> int:
"""获取词表大小
Returns:
词表大小
"""
if not self.is_fitted:
return 0
return len(self.vectorizer.vocabulary_)
def get_config(self) -> dict:
"""获取配置信息
Returns:
配置字典
"""
params = self.vectorizer.get_params()
return {
"analyzer": params["analyzer"],
"ngram_range": params["ngram_range"],
"max_features": params["max_features"],
"min_df": params["min_df"],
"max_df": params["max_df"],
"vocabulary_size": self.get_vocabulary_size() if self.is_fitted else 0,
"is_fitted": self.is_fitted,
}