Optimize expression learning
@@ -7,11 +7,26 @@ import random
 import re
 from typing import Any
 
+try:
+    from sklearn.feature_extraction.text import TfidfVectorizer
+    from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity
+
+    HAS_SKLEARN = True
+except Exception:  # pragma: no cover - fall back silently when the dependency is missing
+    HAS_SKLEARN = False
+
 from src.common.logger import get_logger
 
 logger = get_logger("express_utils")
 
 
+# Precompile regexes to avoid repeated compilation overhead
+_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*")
+_RE_AT = re.compile(r"@<[^>]*>")
+_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]")
+_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]")
+
 
 def filter_message_content(content: str | None) -> str:
     """
     Filter message content, removing reply quotes, @-mentions, image markup, etc.
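Side note on the precompiled patterns: `re.sub` with a string pattern re-checks the module's internal pattern cache on every call, while a module-level `re.compile` pays the lookup and compilation cost once. A minimal sketch of how one might measure the difference (benchmark names and sample input are invented for illustration):

```python
import re
import timeit

_RE_AT = re.compile(r"@<[^>]*>")
sample = "hello @<user123> world " * 50

# The string-pattern form hits the regex cache lookup on every call;
# the precompiled pattern skips that per-call overhead entirely.
t_inline = timeit.timeit(lambda: re.sub(r"@<[^>]*>", "", sample), number=10_000)
t_compiled = timeit.timeit(lambda: _RE_AT.sub("", sample), number=10_000)
print(f"inline: {t_inline:.3f}s  precompiled: {t_compiled:.3f}s")
```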
@@ -25,29 +40,56 @@ def filter_message_content(content: str | None) -> str:
     if not content:
         return ""
 
-    # Strip segments that start with [回复 and end with ], plus the trailing ",说:" part
-    content = re.sub(r"\[回复.*?\],说:\s*", "", content)
-    # Strip @<...> mentions
-    content = re.sub(r"@<[^>]*>", "", content)
-    # Strip [图片:...] image IDs
-    content = re.sub(r"\[图片:[^\]]*\]", "", content)
-    # Strip [表情包:...] sticker markup
-    content = re.sub(r"\[表情包:[^\]]*\]", "", content)
+    # Use the precompiled regexes for better performance
+    content = _RE_REPLY.sub("", content)
+    content = _RE_AT.sub("", content)
+    content = _RE_IMAGE.sub("", content)
+    content = _RE_EMOJI.sub("", content)
 
     return content.strip()
 
 
-def calculate_similarity(text1: str, text2: str) -> float:
+def _similarity_tfidf(text1: str, text2: str) -> float | None:
+    """TF-IDF + cosine similarity; requires sklearn, returns None when it is missing."""
+    if not HAS_SKLEARN:
+        return None
+    # The classic algorithm is more robust for very short texts
+    if len(text1) < 2 or len(text2) < 2:
+        return None
+    try:
+        vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2))
+        tfidf = vec.fit_transform([text1, text2])
+        sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0])
+        return max(0.0, min(1.0, sim))
+    except Exception:
+        return None
+
+
+def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float:
     """
     Compute the similarity of two texts, returned as a value between 0 and 1
+
+    - Prefer TF-IDF vector similarity when available and the texts are long enough (more robust)
+    - Fall back to SequenceMatcher when it is unavailable or fails
 
     Args:
         text1: the first text
         text2: the second text
+        prefer_vector: whether to prefer the vectorized approach (default True)
 
     Returns:
         similarity value (0-1)
     """
     if not text1 or not text2:
         return 0.0
+    if text1 == text2:
+        return 1.0
+
+    if prefer_vector:
+        sim = _similarity_tfidf(text1, text2)
+        if sim is not None:
+            return sim
 
     return difflib.SequenceMatcher(None, text1, text2).ratio()
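For reference, a hedged sketch of the two-tier similarity in action; it inlines the commit's logic so it runs standalone (the sample texts are invented, and the vector path needs scikit-learn). Worth noting: `TfidfVectorizer`'s default token pattern treats an unsegmented CJK run as a single token, so the vector tier is mainly discriminative for text with whitespace or punctuation breaks:

```python
import difflib

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def demo_similarity(text1: str, text2: str) -> float:
    # Vector tier: TF-IDF over unigrams and bigrams, clamped to [0, 1]
    if len(text1) >= 2 and len(text2) >= 2:
        try:
            tfidf = TfidfVectorizer(max_features=1024, ngram_range=(1, 2)).fit_transform([text1, text2])
            return max(0.0, min(1.0, float(cosine_similarity(tfidf[0], tfidf[1])[0, 0])))
        except Exception:  # e.g. empty vocabulary after tokenization
            pass
    # Fallback tier: stdlib SequenceMatcher, no extra dependency
    return difflib.SequenceMatcher(None, text1, text2).ratio()


print(demo_similarity("the weather is nice today", "the weather is great today"))
print(demo_similarity("嗯", "嗯?"))  # too short for TF-IDF, falls back
```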
@@ -79,18 +121,10 @@ def weighted_sample(population: list[dict], k: int, weight_key: str | None = Non
     except (ValueError, TypeError) as e:
         logger.warning(f"加权抽样失败,使用等概率抽样: {e}")
 
-    # Equal-probability sampling
-    selected = []
+    # Equal-probability sampling (without replacement, keeps items unique)
     population_copy = population.copy()
-
-    for _ in range(k):
-        if not population_copy:
-            break
-        # Pick one element at random
-        idx = random.randint(0, len(population_copy) - 1)
-        selected.append(population_copy.pop(idx))
-
-    return selected
+    # Use random.sample for readability and performance
+    return random.sample(population_copy, k)
 
 
 def normalize_text(text: str) -> str:
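One behavioral difference worth flagging: the removed loop silently returned fewer items when `k` exceeded the population size, while `random.sample` raises `ValueError` in that case. If callers may pass an oversized `k`, a clamp restores the old tolerance (a sketch, not part of the commit):

```python
import random


def safe_sample(population: list, k: int) -> list:
    # random.sample raises ValueError when k > len(population);
    # clamping reproduces the old loop's "return what you have" behavior.
    return random.sample(population, min(k, len(population)))


print(safe_sample([1, 2, 3], 5))  # all 3 items, in random order, no exception
```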
@@ -130,8 +164,9 @@ def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
         return keywords
     except ImportError:
         logger.warning("rjieba未安装,无法提取关键词")
-        # Simple whitespace split
+        # Simple whitespace split; emit longer words first to improve the rough keyword quality
        words = text.split()
+        words.sort(key=len, reverse=True)
         return words[:max_keywords]
 
 
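To see what the fallback actually produces: `str.split()` separates only on whitespace, so unsegmented Chinese collapses into a few long tokens, and the length sort mainly helps mixed or space-delimited input. A small illustration with invented text:

```python
text = "machine learning 模型训练 很 有趣 fun topics"
words = text.split()
words.sort(key=len, reverse=True)  # stable sort: ties keep their original order
print(words[:3])  # ['learning', 'machine', 'topics'] — longest words first
```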
@@ -236,15 +271,18 @@ def merge_expressions_from_multiple_chats(
     # Collect all expressions
     for chat_id, expressions in expressions_dict.items():
         for expr in expressions:
             # Tag each entry with its source_id
             expr_with_source = expr.copy()
             expr_with_source["source_id"] = chat_id
             all_expressions.append(expr_with_source)
 
-    # Sort by count or last_active_time
-    if all_expressions and "count" in all_expressions[0]:
+    if not all_expressions:
+        return []
+
+    # Pick the sort key (prefer count, then last_active_time); otherwise keep the original order
+    sample = all_expressions[0]
+    if "count" in sample:
         all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
-    elif all_expressions and "last_active_time" in all_expressions[0]:
+    elif "last_active_time" in sample:
         all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
 
     # Deduplicate (based on situation and style)
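To make the new control flow concrete, a toy run of the sort-key selection (data invented for illustration; it mirrors the hunk above):

```python
all_expressions = [
    {"situation": "greeting", "style": "casual", "count": 3},
    {"situation": "farewell", "style": "formal", "count": 7},
]

if not all_expressions:
    merged = []
else:
    sample = all_expressions[0]
    if "count" in sample:
        # Most-used expressions first
        all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif "last_active_time" in sample:
        # Otherwise most recently active first
        all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
    merged = all_expressions

print([e["count"] for e in merged])  # [7, 3]
```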