feat(expression): 添加表达方式选择模式支持与DatabaseMessages兼容性改进
- 新增统一的表达方式选择入口,支持classic和exp_model两种模式 - 添加StyleLearner模型预测模式,可基于机器学习模型选择表达风格 - 改进多个模块对DatabaseMessages数据模型的兼容性处理 - 优化消息处理逻辑,统一处理字典和DatabaseMessages对象 - 在配置中添加expression.mode字段控制表达选择模式
This commit is contained in:
254
src/chat/express/express_utils.py
Normal file
254
src/chat/express/express_utils.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
表达系统工具函数
|
||||
提供消息过滤、文本相似度计算、加权随机抽样等功能
|
||||
"""
|
||||
import difflib
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("express_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
过滤后的纯文本内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[图片:...]格式的图片ID
|
||||
content = re.sub(r"\[图片:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
相似度值 (0-1)
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
|
||||
"""
|
||||
加权随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 待抽样的数据列表
|
||||
k: 抽样数量
|
||||
weight_key: 权重字段名,如果为None则等概率抽样
|
||||
|
||||
Returns:
|
||||
抽样结果列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 如果指定了权重字段
|
||||
if weight_key and all(weight_key in item for item in population):
|
||||
try:
|
||||
# 获取权重
|
||||
weights = [float(item.get(weight_key, 1.0)) for item in population]
|
||||
# 使用random.choices进行加权抽样
|
||||
return random.choices(population, weights=weights, k=k)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"加权抽样失败,使用等概率抽样: {e}")
|
||||
|
||||
# 等概率抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
# 随机选择一个元素
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
标准化文本,移除多余空白字符
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
标准化后的文本
|
||||
"""
|
||||
# 替换多个连续空白字符为单个空格
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
||||
"""
|
||||
简单的关键词提取(基于词频)
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
max_keywords: 最大关键词数量
|
||||
|
||||
Returns:
|
||||
关键词列表
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
try:
|
||||
import jieba.analyse
|
||||
|
||||
# 使用TF-IDF提取关键词
|
||||
keywords = jieba.analyse.extract_tags(text, topK=max_keywords)
|
||||
return keywords
|
||||
except ImportError:
|
||||
logger.warning("jieba未安装,无法提取关键词")
|
||||
# 简单分词
|
||||
words = text.split()
|
||||
return words[:max_keywords]
|
||||
|
||||
|
||||
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
|
||||
"""
|
||||
格式化表达方式对
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
index: 序号(可选)
|
||||
|
||||
Returns:
|
||||
格式化后的字符串
|
||||
"""
|
||||
if index is not None:
|
||||
return f'{index}. 当"{situation}"时,使用"{style}"'
|
||||
else:
|
||||
return f'当"{situation}"时,使用"{style}"'
|
||||
|
||||
|
||||
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
解析表达方式对文本
|
||||
|
||||
Args:
|
||||
text: 格式化的表达方式对文本
|
||||
|
||||
Returns:
|
||||
(situation, style) 或 None
|
||||
"""
|
||||
# 匹配格式:当"..."时,使用"..."
|
||||
match = re.search(r'当"(.+?)"时,使用"(.+?)"', text)
|
||||
if match:
|
||||
return match.group(1), match.group(2)
|
||||
return None
|
||||
|
||||
|
||||
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量去重表达方式
|
||||
|
||||
Args:
|
||||
expressions: 表达方式列表
|
||||
key_fields: 用于去重的字段名列表
|
||||
|
||||
Returns:
|
||||
去重后的表达方式列表
|
||||
"""
|
||||
seen = set()
|
||||
unique_expressions = []
|
||||
|
||||
for expr in expressions:
|
||||
# 构建去重key
|
||||
key_values = tuple(expr.get(field, "") for field in key_fields)
|
||||
|
||||
if key_values not in seen:
|
||||
seen.add(key_values)
|
||||
unique_expressions.append(expr)
|
||||
|
||||
return unique_expressions
|
||||
|
||||
|
||||
def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
|
||||
"""
|
||||
根据时间计算权重(时间衰减)
|
||||
|
||||
Args:
|
||||
last_active_time: 最后活跃时间戳
|
||||
current_time: 当前时间戳
|
||||
half_life_days: 半衰期天数
|
||||
|
||||
Returns:
|
||||
权重值 (0-1)
|
||||
"""
|
||||
time_diff_days = (current_time - last_active_time) / 86400 # 转换为天数
|
||||
if time_diff_days < 0:
|
||||
return 1.0
|
||||
|
||||
# 使用指数衰减公式
|
||||
decay_rate = 0.693 / half_life_days # ln(2) / half_life
|
||||
weight = max(0.01, min(1.0, 2 ** (-decay_rate * time_diff_days)))
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def merge_expressions_from_multiple_chats(
|
||||
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
合并多个聊天室的表达方式
|
||||
|
||||
Args:
|
||||
expressions_dict: {chat_id: [expressions]}
|
||||
max_total: 最大合并数量
|
||||
|
||||
Returns:
|
||||
合并后的表达方式列表
|
||||
"""
|
||||
all_expressions = []
|
||||
|
||||
# 收集所有表达方式
|
||||
for chat_id, expressions in expressions_dict.items():
|
||||
for expr in expressions:
|
||||
# 添加source_id标识
|
||||
expr_with_source = expr.copy()
|
||||
expr_with_source["source_id"] = chat_id
|
||||
all_expressions.append(expr_with_source)
|
||||
|
||||
# 按count或last_active_time排序
|
||||
if all_expressions and "count" in all_expressions[0]:
|
||||
all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
|
||||
elif all_expressions and "last_active_time" in all_expressions[0]:
|
||||
all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
|
||||
|
||||
# 去重(基于situation和style)
|
||||
all_expressions = batch_filter_duplicates(all_expressions, ["situation", "style"])
|
||||
|
||||
# 限制数量
|
||||
return all_expressions[:max_total]
|
||||
Reference in New Issue
Block a user