Mofox-Core/src/chat/express/express_utils.py
"""
表达系统工具函数
提供消息过滤、文本相似度计算、加权随机抽样等功能
"""
import difflib
import random
import re
from typing import Any
try:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity
HAS_SKLEARN = True
except Exception: # pragma: no cover - 依赖缺失时静默回退
HAS_SKLEARN = False
from src.common.logger import get_logger
logger = get_logger("express_utils")
# 预编译正则,减少重复编译开销
_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*")
_RE_AT = re.compile(r"@<[^>]*>")
_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]")
_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]")


def filter_message_content(content: str | None) -> str:
    """
    Filter message content, stripping reply headers, @-mentions, image and
    sticker placeholders.

    Args:
        content: Raw message content.

    Returns:
        The filtered plain-text content.
    """
    if not content:
        return ""
    # Use the precompiled regexes for speed.
    content = _RE_REPLY.sub("", content)
    content = _RE_AT.sub("", content)
    content = _RE_IMAGE.sub("", content)
    content = _RE_EMOJI.sub("", content)
    return content.strip()
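
# Usage sketch (illustrative values; the markup shapes follow the patterns
# this module defines above):
#   filter_message_content('[回复张三],说: @<李四> 看一下 [图片:abc]')
#   -> '看一下'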


def _similarity_tfidf(text1: str, text2: str) -> float | None:
    """TF-IDF + cosine similarity; returns None when sklearn is unavailable."""
    if not HAS_SKLEARN:
        return None
    # For very short texts the classic algorithm is more robust.
    if len(text1) < 2 or len(text2) < 2:
        return None
    try:
        vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2))
        tfidf = vec.fit_transform([text1, text2])
        sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0])
        return max(0.0, min(1.0, sim))
    except Exception:
        return None


def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float:
    """
    Compute the similarity of two texts as a value in [0, 1].

    - When available and the texts are long enough, TF-IDF vector similarity
      is tried first (more robust).
    - When unavailable, or on failure, falls back to difflib.SequenceMatcher.

    Args:
        text1: First text.
        text2: Second text.
        prefer_vector: Whether to prefer the vectorized approach (default True).

    Returns:
        Similarity value in [0, 1].
    """
    if not text1 or not text2:
        return 0.0
    if text1 == text2:
        return 1.0
    if prefer_vector:
        sim = _similarity_tfidf(text1, text2)
        if sim is not None:
            return sim
    return difflib.SequenceMatcher(None, text1, text2).ratio()
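
# Usage sketch (illustrative):
#   calculate_similarity("hello world", "hello world")  -> 1.0
#   calculate_similarity("hello world", "")             -> 0.0
#   calculate_similarity("a", "b")
#   -> uses the SequenceMatcher fallback, because both texts are shorter
#      than the 2-character threshold in _similarity_tfidf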


def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
    """
    Weighted random sampling.

    Args:
        population: List of items to sample from.
        k: Number of samples to draw.
        weight_key: Name of the weight field; if None, sample uniformly.

    Returns:
        List of sampled items.
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # If a weight field was specified and every item carries it:
    if weight_key and all(weight_key in item for item in population):
        try:
            # Collect the weights.
            weights = [float(item.get(weight_key, 1.0)) for item in population]
            # Weighted sampling via random.choices (note: samples with
            # replacement, so the result may contain duplicates).
            return random.choices(population, weights=weights, k=k)
        except (ValueError, TypeError) as e:
            logger.warning(f"Weighted sampling failed, falling back to uniform sampling: {e}")
    # Uniform sampling (without replacement, keeps items unique).
    return random.sample(population, k)
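
# Usage sketch (illustrative; "count" is just an example weight field):
#   items = [{"id": 1, "count": 10}, {"id": 2, "count": 1}, {"id": 3, "count": 1}]
#   weighted_sample(items, k=2, weight_key="count")
#   -> two items, biased toward id=1; duplicates are possible because
#      random.choices samples with replacement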


def normalize_text(text: str) -> str:
    """
    Normalize text by collapsing redundant whitespace.

    Args:
        text: Input text.

    Returns:
        The normalized text.
    """
    # Replace runs of consecutive whitespace with a single space.
    text = re.sub(r"\s+", " ", text)
    return text.strip()
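
# Usage sketch (illustrative):
#   normalize_text("  hello\t\n  world  ")  -> "hello world"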


def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
    """
    Simple keyword extraction (TF-IDF via rjieba when available).

    Args:
        text: Input text.
        max_keywords: Maximum number of keywords.

    Returns:
        List of keywords.
    """
    if not text:
        return []
    try:
        import rjieba.analyse

        # Extract keywords with TF-IDF.
        return rjieba.analyse.extract_tags(text, topK=max_keywords)
    except ImportError:
        logger.warning("rjieba is not installed; cannot extract keywords")
        # Naive tokenization: sort by length, descending, so longer words come
        # first, which slightly improves these rough keywords.
        words = text.split()
        words.sort(key=len, reverse=True)
        return words[:max_keywords]
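
# Usage sketch (illustrative; output depends on whether rjieba is installed):
#   extract_keywords("machine learning makes machine translation better", max_keywords=2)
#   -> without rjieba, falls back to the two longest whitespace-split tokens:
#      ["translation", "learning"]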


def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
    """
    Format a (situation, style) expression pair.

    Args:
        situation: The situation.
        style: The style.
        index: Optional sequence number.

    Returns:
        The formatted string.
    """
    if index is not None:
        return f'{index}. 当"{situation}"时,使用"{style}"'
    return f'"{situation}"时,使用"{style}"'


def parse_expression_pair(text: str) -> tuple[str, str] | None:
    """
    Parse a formatted expression-pair string.

    Args:
        text: A formatted expression-pair string.

    Returns:
        (situation, style), or None if the text does not match.
    """
    # Matches the format: 当"..."时,使用"..."
    match = re.search(r'"(.+?)"时,使用"(.+?)"', text)
    if match:
        return match.group(1), match.group(2)
    return None
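
# Usage sketch (illustrative round trip between the two helpers above):
#   s = format_expression_pair("被调侃", "用自嘲化解", index=3)
#   -> '3. 当"被调侃"时,使用"用自嘲化解"'
#   parse_expression_pair(s) -> ("被调侃", "用自嘲化解")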


def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
    """
    Batch-deduplicate expressions.

    Args:
        expressions: List of expressions.
        key_fields: Names of the fields used to build the deduplication key.

    Returns:
        The deduplicated list of expressions (first occurrence wins).
    """
    seen = set()
    unique_expressions = []
    for expr in expressions:
        # Build the deduplication key.
        key_values = tuple(expr.get(field, "") for field in key_fields)
        if key_values not in seen:
            seen.add(key_values)
            unique_expressions.append(expr)
    return unique_expressions
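
# Usage sketch (illustrative):
#   rows = [{"situation": "a", "style": "x"}, {"situation": "a", "style": "x"},
#           {"situation": "a", "style": "y"}]
#   batch_filter_duplicates(rows, ["situation", "style"])
#   -> keeps the first ("a", "x") and ("a", "y"); the duplicate is dropped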


def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
    """
    Compute a time-decayed weight.

    Args:
        last_active_time: Timestamp of the last activity.
        current_time: Current timestamp.
        half_life_days: Half-life in days.

    Returns:
        Weight value in (0, 1].
    """
    time_diff_days = (current_time - last_active_time) / 86400  # convert seconds to days
    if time_diff_days < 0:
        return 1.0
    # Exponential decay: the weight halves every half_life_days. (The previous
    # 2 ** (-decay_rate * t) double-counted the base-2 conversion, yielding
    # ~0.618 instead of 0.5 at one half-life.)
    decay_rate = math.log(2) / half_life_days  # ln(2) / half-life
    weight = max(0.01, min(1.0, math.exp(-decay_rate * time_diff_days)))
    return weight
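
# Worked example:
#   t = half_life_days      -> exp(-ln(2))     = 0.5
#   t = 2 * half_life_days  -> exp(-2 * ln(2)) = 0.25
#   the 0.01 floor kicks in after roughly 6.6 half-lives (0.5 ** 6.64 ≈ 0.01)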


def merge_expressions_from_multiple_chats(
    expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
    """
    Merge expressions from multiple chat rooms.

    Args:
        expressions_dict: Mapping of {chat_id: [expressions]}.
        max_total: Maximum number of merged expressions.

    Returns:
        The merged list of expressions.
    """
    all_expressions = []
    # Collect every expression, tagging each with its source chat.
    for chat_id, expressions in expressions_dict.items():
        for expr in expressions:
            expr_with_source = expr.copy()
            expr_with_source["source_id"] = chat_id
            all_expressions.append(expr_with_source)
    if not all_expressions:
        return []
    # Pick a sort key: prefer "count", then "last_active_time";
    # otherwise keep the original order.
    sample = all_expressions[0]
    if "count" in sample:
        all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif "last_active_time" in sample:
        all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
    # Deduplicate on (situation, style).
    all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"])
    # Cap the total count.
    return all_expressions[:max_total]
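
# Usage sketch (illustrative):
#   merged = merge_expressions_from_multiple_chats(
#       {
#           "chat_a": [{"situation": "s1", "style": "x", "count": 5}],
#           "chat_b": [{"situation": "s1", "style": "x", "count": 2},
#                      {"situation": "s2", "style": "y", "count": 9}],
#       },
#       max_total=10,
#   )
#   -> sorted by count desc, deduplicated on (situation, style), each item
#      carrying a "source_id"; here: s2/y from chat_b, then one s1/x entry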