This commit applies a broad refactoring of the project code, focused on the following areas:
1. **Type hint modernization**:
   - Legacy `typing` hints such as `Optional[T]`, `List[T]`, and `Dict[K, V]` were updated to the modern `T | None`, `list[T]`, and `dict[K, V]` syntax (see the short example after this list).
   - This improves readability and keeps the code in line with the style of newer Python versions.
2. **Code style unification**:
   - Redundant blank lines and unnecessary whitespace were removed, making the code more compact and consistent.
   - The format of some log output was unified, improving log readability.
3. **Import statement cleanup**:
   - The `import` statements of several modules were reordered to conform to PEP 8.
These changes involve no functional behavior changes; they aim to improve the overall quality, maintainability, and developer experience of the codebase.
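For example, a minimal illustration of the syntax change (the function and its names are hypothetical, not code from this repository):

```python
from typing import Dict, List, Optional

# Before: typing-module generics
def load_scores(path: Optional[str] = None) -> Dict[str, List[int]]:
    ...

# After: built-in generics (Python 3.9+) and the `X | None` union syntax (Python 3.10+)
def load_scores(path: str | None = None) -> dict[str, list[int]]:
    ...
```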
255 lines · 6.6 KiB · Python
"""
|
||
表达系统工具函数
|
||
提供消息过滤、文本相似度计算、加权随机抽样等功能
|
||
"""
|
||
import difflib
|
||
import random
|
||
import re
|
||
from typing import Any
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("express_utils")
|
||
|
||
|
||
def filter_message_content(content: str | None) -> str:
    """
    Filter message content, removing reply quotes, @-mentions, images, and similar markup.

    Args:
        content: raw message content

    Returns:
        The filtered plain-text content.
    """
    if not content:
        return ""

    # Remove reply segments that start with [回复 and end with ], including the trailing ",说:" part
    content = re.sub(r"\[回复.*?\],说:\s*", "", content)
    # Remove @<...> mention tags
    content = re.sub(r"@<[^>]*>", "", content)
    # Remove [图片:...] image IDs
    content = re.sub(r"\[图片:[^\]]*\]", "", content)
    # Remove [表情包:...] sticker tags
    content = re.sub(r"\[表情包:[^\]]*\]", "", content)

    return content.strip()

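
# Illustrative example (not executed here): an incoming message such as
#   '[回复 Alice],说: 你好 @<Alice> [图片:12345]'
# would be reduced to just '你好' once the reply prefix, the mention tag,
# and the image placeholder are stripped.
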
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Calculate the similarity of two texts as a value between 0 and 1.

    Args:
        text1: the first text
        text2: the second text

    Returns:
        Similarity ratio (0-1).
    """
    return difflib.SequenceMatcher(None, text1, text2).ratio()


def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
    """
    Weighted random sampling.

    Args:
        population: list of items to sample from
        k: number of samples to draw
        weight_key: name of the weight field; if None, sample with equal probability

    Returns:
        List of sampled items.
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    # A weight field was specified and is present on every item
    if weight_key and all(weight_key in item for item in population):
        try:
            # Collect the weights
            weights = [float(item.get(weight_key, 1.0)) for item in population]
            # Weighted sampling via random.choices
            return random.choices(population, weights=weights, k=k)
        except (ValueError, TypeError) as e:
            logger.warning(f"Weighted sampling failed, falling back to uniform sampling: {e}")

    # Uniform-probability sampling
    selected = []
    population_copy = population.copy()

    for _ in range(k):
        if not population_copy:
            break
        # Randomly pick one element and remove it from the pool
        idx = random.randint(0, len(population_copy) - 1)
        selected.append(population_copy.pop(idx))

    return selected

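
# Usage sketch (illustrative; the items and fields below are hypothetical):
#   weighted_sample(
#       [{"name": "a", "count": 10}, {"name": "b", "count": 1}, {"name": "c", "count": 1}],
#       k=2,
#       weight_key="count",
#   )
# Note that the weighted branch uses random.choices, which samples *with* replacement,
# so the same item can appear more than once; the uniform fallback pops items from the
# pool and therefore samples without replacement.
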
def normalize_text(text: str) -> str:
    """
    Normalize text by collapsing redundant whitespace.

    Args:
        text: input text

    Returns:
        The normalized text.
    """
    # Collapse runs of consecutive whitespace into a single space
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
    """
    Simple keyword extraction (frequency based).

    Args:
        text: input text
        max_keywords: maximum number of keywords

    Returns:
        List of keywords.
    """
    if not text:
        return []

    try:
        import jieba.analyse

        # Extract keywords via TF-IDF
        keywords = jieba.analyse.extract_tags(text, topK=max_keywords)
        return keywords
    except ImportError:
        logger.warning("jieba is not installed; keywords cannot be extracted")
        # Naive fallback: split on whitespace
        words = text.split()
        return words[:max_keywords]

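
# Note: jieba is treated as an optional dependency here; the TF-IDF path is only used
# when it can be imported. The whitespace-split fallback does not segment Chinese text,
# so for Chinese input it typically returns the first few whitespace-separated chunks
# rather than real keywords.
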
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
    """
    Format a (situation, style) expression pair.

    Args:
        situation: the situation
        style: the style to use
        index: optional ordinal number

    Returns:
        The formatted string.
    """
    if index is not None:
        return f'{index}. 当"{situation}"时,使用"{style}"'
    else:
        return f'当"{situation}"时,使用"{style}"'

def parse_expression_pair(text: str) -> tuple[str, str] | None:
    """
    Parse a formatted expression-pair string.

    Args:
        text: formatted expression-pair text

    Returns:
        (situation, style), or None if the text does not match.
    """
    # Match the format produced by format_expression_pair: 当"..."时,使用"..."
    match = re.search(r'当"(.+?)"时,使用"(.+?)"', text)
    if match:
        return match.group(1), match.group(2)
    return None

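
# Round-trip sketch (illustrative values):
#   format_expression_pair("被调侃", "用玩笑话回应")        ->  '当"被调侃"时,使用"用玩笑话回应"'
#   parse_expression_pair('当"被调侃"时,使用"用玩笑话回应"')  ->  ("被调侃", "用玩笑话回应")
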
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
    """
    Deduplicate a batch of expressions.

    Args:
        expressions: list of expressions
        key_fields: field names used to build the deduplication key

    Returns:
        The deduplicated list of expressions.
    """
    seen = set()
    unique_expressions = []

    for expr in expressions:
        # Build the deduplication key; the first occurrence of each key wins
        key_values = tuple(expr.get(field, "") for field in key_fields)

        if key_values not in seen:
            seen.add(key_values)
            unique_expressions.append(expr)

    return unique_expressions


def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
    """
    Calculate a time-decayed weight from the last-active time.

    Args:
        last_active_time: last-active timestamp
        current_time: current timestamp
        half_life_days: half-life in days

    Returns:
        Weight value (0-1).
    """
    time_diff_days = (current_time - last_active_time) / 86400  # seconds -> days
    if time_diff_days < 0:
        return 1.0

    # Exponential decay, clamped to the range [0.01, 1.0]
    decay_rate = 0.693 / half_life_days  # ln(2) / half_life
    weight = max(0.01, min(1.0, 2 ** (-decay_rate * time_diff_days)))

    return weight

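
# Worked example (illustrative numbers): with half_life_days=30, decay_rate = 0.693 / 30 ≈ 0.0231,
# so an item last active 30 days ago gets weight 2 ** (-0.693) ≈ 0.62, one last active 90 days ago
# gets roughly 0.24, and very stale items bottom out at the 0.01 floor.
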
def merge_expressions_from_multiple_chats(
    expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
) -> list[dict[str, Any]]:
    """
    Merge expressions collected from multiple chats.

    Args:
        expressions_dict: {chat_id: [expressions]}
        max_total: maximum number of merged expressions

    Returns:
        The merged list of expressions.
    """
    all_expressions = []

    # Collect all expressions, tagging each with its source chat
    for chat_id, expressions in expressions_dict.items():
        for expr in expressions:
            expr_with_source = expr.copy()
            expr_with_source["source_id"] = chat_id
            all_expressions.append(expr_with_source)

    # Sort by count or last_active_time, whichever the items carry
    if all_expressions and "count" in all_expressions[0]:
        all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif all_expressions and "last_active_time" in all_expressions[0]:
        all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)

    # Deduplicate on (situation, style)
    all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"])

    # Cap the total number of merged expressions
    return all_expressions[:max_total]
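
# Usage sketch (illustrative; chat IDs and fields are hypothetical):
#   merge_expressions_from_multiple_chats(
#       {
#           "chat_a": [{"situation": "被夸奖", "style": "谦虚回应", "count": 5}],
#           "chat_b": [{"situation": "被夸奖", "style": "谦虚回应", "count": 2}],
#       },
#       max_total=50,
#   )
# keeps a single entry (duplicates on situation/style are dropped after sorting by count),
# tagged with source_id="chat_a".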