"""
|
||
表达系统工具函数
|
||
提供消息过滤、文本相似度计算、加权随机抽样等功能
|
||
"""
|
||
import difflib
|
||
import random
|
||
import re
|
||
from typing import Any
|
||
|
||
try:
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity
|
||
|
||
HAS_SKLEARN = True
|
||
except Exception: # pragma: no cover - 依赖缺失时静默回退
|
||
HAS_SKLEARN = False
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("express_utils")
|
||
|
||
|
||
# 预编译正则,减少重复编译开销
|
||
_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*")
|
||
_RE_AT = re.compile(r"@<[^>]*>")
|
||
_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]")
|
||
_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]")
|
||
|
||
|
||
def filter_message_content(content: str | None) -> str:
    """
    Filter message content, removing reply, @-mention, image, and sticker markup.

    Args:
        content: Raw message content

    Returns:
        Filtered plain-text content
    """
    if not content:
        return ""

    # Use the precompiled regexes for performance
    content = _RE_REPLY.sub("", content)
    content = _RE_AT.sub("", content)
    content = _RE_IMAGE.sub("", content)
    content = _RE_EMOJI.sub("", content)

    return content.strip()


def _similarity_tfidf(text1: str, text2: str) -> float | None:
    """TF-IDF + cosine similarity; requires sklearn and returns None when it is unavailable."""
    if not HAS_SKLEARN:
        return None
    # Very short texts are handled more robustly by the classic algorithm
    if len(text1) < 2 or len(text2) < 2:
        return None
    try:
        vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2))
        tfidf = vec.fit_transform([text1, text2])
        sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0])
        return max(0.0, min(1.0, sim))
    except Exception:
        return None


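# Note on _similarity_tfidf (an observation, not part of the original code): TfidfVectorizer's
# default word tokenizer does not segment unbroken CJK text, so TF-IDF similarity may be
# uninformative for Chinese input unless it is pre-segmented (e.g. space-joined words); when
# vectorization fails outright, the function returns None and the caller falls back to
# SequenceMatcher.
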
def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float:
    """
    Calculate the similarity of two texts, returning a value between 0 and 1.

    - When available and the texts are long enough, TF-IDF vector similarity is tried first (more robust)
    - When unavailable or on failure, falls back to SequenceMatcher

    Args:
        text1: First text
        text2: Second text
        prefer_vector: Whether to prefer the vectorized approach (default: True)

    Returns:
        Similarity value (0-1)
    """
    if not text1 or not text2:
        return 0.0
    if text1 == text2:
        return 1.0

    if prefer_vector:
        sim = _similarity_tfidf(text1, text2)
        if sim is not None:
            return sim

    return difflib.SequenceMatcher(None, text1, text2).ratio()


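# Example: with sklearn unavailable (or the TF-IDF path failing), calculate_similarity("abcd",
# "abce") is the SequenceMatcher ratio 2 * 3 / 8 = 0.75; identical inputs short-circuit to 1.0
# and empty inputs return 0.0.
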
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
    """
    Weighted random sampling.

    Args:
        population: List of items to sample from
        k: Number of samples
        weight_key: Name of the weight field; if None, sample with equal probability

    Returns:
        List of sampled items
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    # A weight field was specified and every item carries it
    if weight_key and all(weight_key in item for item in population):
        try:
            # Collect the weights
            weights = [float(item.get(weight_key, 1.0)) for item in population]
            # Weighted sampling via random.choices
            return random.choices(population, weights=weights, k=k)
        except (ValueError, TypeError) as e:
            logger.warning(f"Weighted sampling failed, falling back to uniform sampling: {e}")

    # Uniform sampling (without replacement, keeps items unique)
    return random.sample(population, k)


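# Note on weighted_sample: when weight_key is set and present on every item, random.choices
# samples WITH replacement, so the same item can be returned more than once; the uniform
# fallback uses random.sample, which is without replacement.
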
def normalize_text(text: str) -> str:
    """
    Normalize text by collapsing extra whitespace.

    Args:
        text: Input text

    Returns:
        Normalized text
    """
    # Replace runs of consecutive whitespace with a single space
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
    """
    Simple keyword extraction (frequency based).

    Args:
        text: Input text
        max_keywords: Maximum number of keywords

    Returns:
        List of keywords
    """
    if not text:
        return []

    try:
        import rjieba.analyse

        # Extract keywords via TF-IDF
        keywords = rjieba.analyse.extract_tags(text, topK=max_keywords)
        return keywords
    except ImportError:
        logger.warning("rjieba is not installed, cannot extract keywords")
        # Naive tokenization: prefer longer words first to improve the rough keyword quality
        words = text.split()
        words.sort(key=len, reverse=True)
        return words[:max_keywords]


def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
    """
    Format a situation/style expression pair.

    Args:
        situation: Situation
        style: Style
        index: Optional ordinal number

    Returns:
        Formatted string
    """
    if index is not None:
        return f'{index}. 当"{situation}"时,使用"{style}"'
    else:
        return f'当"{situation}"时,使用"{style}"'


def parse_expression_pair(text: str) -> tuple[str, str] | None:
    """
    Parse a formatted expression-pair string.

    Args:
        text: Formatted expression-pair text

    Returns:
        (situation, style) or None
    """
    # Match the format: 当"..."时,使用"..."
    match = re.search(r'当"(.+?)"时,使用"(.+?)"', text)
    if match:
        return match.group(1), match.group(2)
    return None


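# Round-trip example: parse_expression_pair(format_expression_pair("问候", "轻松的语气"))
# returns ("问候", "轻松的语气"); the numbered variant (with a leading "1. ") also parses,
# because re.search matches the pattern anywhere in the string.
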
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
    """
    Deduplicate a batch of expressions.

    Args:
        expressions: List of expressions
        key_fields: Field names used to build the deduplication key

    Returns:
        Deduplicated list of expressions
    """
    seen = set()
    unique_expressions = []

    for expr in expressions:
        # Build the deduplication key
        key_values = tuple(expr.get(field, "") for field in key_fields)

        if key_values not in seen:
            seen.add(key_values)
            unique_expressions.append(expr)

    return unique_expressions


def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
    """
    Calculate a weight from elapsed time (time decay).

    Args:
        last_active_time: Timestamp of last activity
        current_time: Current timestamp
        half_life_days: Half-life in days

    Returns:
        Weight value (0-1)
    """
    time_diff_days = (current_time - last_active_time) / 86400  # convert to days
    if time_diff_days < 0:
        return 1.0

    # Exponential decay: weight = e^(-decay_rate * t), halving every half_life_days
    decay_rate = 0.693 / half_life_days  # ln(2) / half_life
    weight = max(0.01, min(1.0, math.exp(-decay_rate * time_diff_days)))

    return weight


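# Example: with the default half_life_days=30, calculate_time_weight yields ~1.0 for a
# just-active expression, ~0.5 after 30 days and ~0.25 after 60 days, clamped to a floor
# of 0.01; a future last_active_time (negative diff) returns 1.0.
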
def merge_expressions_from_multiple_chats(
    expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
    """
    Merge expressions from multiple chats.

    Args:
        expressions_dict: {chat_id: [expressions]}
        max_total: Maximum number of merged entries

    Returns:
        Merged list of expressions
    """
    all_expressions = []

    # Collect all expressions, tagging each with its source chat
    for chat_id, expressions in expressions_dict.items():
        for expr in expressions:
            expr_with_source = expr.copy()
            expr_with_source["source_id"] = chat_id
            all_expressions.append(expr_with_source)

    if not all_expressions:
        return []

    # Pick a sort key (prefer count, then last_active_time); otherwise keep the original order
    sample = all_expressions[0]
    if "count" in sample:
        all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif "last_active_time" in sample:
        all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)

    # Deduplicate (by situation and style)
    all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"])

    # Cap the number of results
    return all_expressions[:max_total]


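# Minimal usage sketch (illustrative, made-up data only), exercised when this file is run
# directly; it is not part of the module's public surface.
if __name__ == "__main__":
    import time

    raw = '[回复<某人>],说: @<某人> 看看这个 [图片:cat.png] [表情包:开心]'
    print(filter_message_content(raw))  # expected: 看看这个

    print(calculate_similarity("今天 天气 很好", "今天 天气 不错"))

    items = [
        {"situation": "问候", "style": "轻松", "count": 3},
        {"situation": "道歉", "style": "诚恳", "count": 1},
    ]
    print(weighted_sample(items, k=1, weight_key="count"))

    now = time.time()
    print(calculate_time_weight(now - 30 * 86400, now))  # ~0.5 with the default 30-day half-life

    pair = format_expression_pair("问候", "轻松的语气", index=1)
    print(pair, "->", parse_expression_pair(pair))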