Optimize expression learning

LuiKlee
2025-12-16 11:38:56 +08:00
parent 8c451e42fb
commit c3e2e713ef
9 changed files with 526 additions and 260 deletions

View File

@@ -7,11 +7,26 @@ import random
import re
from typing import Any
try:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity
HAS_SKLEARN = True
except Exception: # pragma: no cover - silently fall back when the dependency is missing
HAS_SKLEARN = False
from src.common.logger import get_logger
logger = get_logger("express_utils")
# Precompile regexes to avoid repeated compilation overhead
_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*")
_RE_AT = re.compile(r"@<[^>]*>")
_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]")
_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]")
def filter_message_content(content: str | None) -> str:
"""
Filter message content: strip reply quotes, @-mentions, image and sticker markers
@@ -25,29 +40,56 @@ def filter_message_content(content: str | None) -> str:
if not content:
return ""
# Strip segments that start with [回复 and end with ], including the trailing ",说:" part
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# Strip @<...> mention markers
content = re.sub(r"@<[^>]*>", "", content)
# Strip [图片:...] image IDs
content = re.sub(r"\[图片:[^\]]*\]", "", content)
# Strip [表情包:...] sticker markers
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
# Use the precompiled regexes for speed
content = _RE_REPLY.sub("", content)
content = _RE_AT.sub("", content)
content = _RE_IMAGE.sub("", content)
content = _RE_EMOJI.sub("", content)
return content.strip()
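A quick sanity check of the filter (the sample message below is hypothetical; the bracket literals mirror the regexes above):

```python
# Hypothetical input; the markers follow the formats stripped above
raw = "[回复张三],说: @<user_123> 看看这个 [图片:img_001] [表情包:doge]"
print(filter_message_content(raw))   # -> "看看这个"
print(filter_message_content(None))  # -> "" (empty/None input yields an empty string)
```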
def calculate_similarity(text1: str, text2: str) -> float:
def _similarity_tfidf(text1: str, text2: str) -> float | None:
"""使用 TF-IDF + 余弦相似度;依赖 sklearn缺失则返回 None。"""
if not HAS_SKLEARN:
return None
# The classic algorithm is more robust for very short texts
if len(text1) < 2 or len(text2) < 2:
return None
try:
vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2))
tfidf = vec.fit_transform([text1, text2])
sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0])
return max(0.0, min(1.0, sim))
except Exception:
return None
def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float:
"""
Compute the similarity of two texts, returning a value between 0 and 1.
- When available and the texts are long enough, try TF-IDF vector similarity first (more robust)
- Fall back to SequenceMatcher when unavailable or on failure
Args:
text1: first text
text2: second text
prefer_vector: whether to prefer the vectorized approach (default: yes)
Returns:
similarity value (0-1)
"""
if not text1 or not text2:
return 0.0
if text1 == text2:
return 1.0
if prefer_vector:
sim = _similarity_tfidf(text1, text2)
if sim is not None:
return sim
return difflib.SequenceMatcher(None, text1, text2).ratio()
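A minimal sketch of the fallback contract; which branch runs depends on whether sklearn is installed, so the vector score is environment-dependent:

```python
import difflib

a = "今天天气很好,适合出门散步"
b = "今天天气不错,适合散步"

score = calculate_similarity(a, b)  # TF-IDF path when sklearn is available
assert 0.0 <= score <= 1.0
assert calculate_similarity(a, a) == 1.0   # identical texts short-circuit to 1.0
assert calculate_similarity("", b) == 0.0  # empty input short-circuits to 0.0
# prefer_vector=False always takes the SequenceMatcher path
assert calculate_similarity(a, b, prefer_vector=False) == difflib.SequenceMatcher(None, a, b).ratio()
```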
@@ -79,18 +121,10 @@ def weighted_sample(population: list[dict], k: int, weight_key: str | None = Non
except (ValueError, TypeError) as e:
logger.warning(f"加权抽样失败,使用等概率抽样: {e}")
# Uniform sampling
selected = []
# Uniform sampling (without replacement, keeps items unique)
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# Pick a random element
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
return selected
# Use random.sample for readability and speed; clamp k because
# random.sample raises ValueError when k exceeds the population size,
# whereas the old loop simply stopped early
return random.sample(population_copy, min(k, len(population_copy)))
def normalize_text(text: str) -> str:
@@ -130,8 +164,9 @@ def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
return keywords
except ImportError:
logger.warning("rjieba未安装无法提取关键词")
# Simple tokenization
# Simple tokenization; emit longer words first to improve the rough keyword quality
words = text.split()
words.sort(key=len, reverse=True)
return words[:max_keywords]
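Since the fallback is plain whitespace splitting ordered by word length, a hypothetical run shows the longest tokens surfacing first (Python's sort is stable, so equal-length words keep their original order):

```python
text = "cache invalidation strategies for distributed systems"
words = text.split()
words.sort(key=len, reverse=True)
print(words[:3])  # ['invalidation', 'distributed', 'strategies']
```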
@@ -236,15 +271,18 @@ def merge_expressions_from_multiple_chats(
# Collect all expressions
for chat_id, expressions in expressions_dict.items():
for expr in expressions:
# Tag with a source_id
expr_with_source = expr.copy()
expr_with_source["source_id"] = chat_id
all_expressions.append(expr_with_source)
# Sort by count or last_active_time
if all_expressions and "count" in all_expressions[0]:
if not all_expressions:
return []
# Choose the sort key (prefer count, then last_active_time; otherwise keep the original order)
sample = all_expressions[0]
if "count" in sample:
all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
elif all_expressions and "last_active_time" in all_expressions[0]:
elif "last_active_time" in sample:
all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)
# Deduplicate based on situation and style
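The deduplication step itself is truncated by this hunk; a plausible sketch under the stated rule — keep the first (highest-ranked) entry per (situation, style) pair after the sort above. The helper name is hypothetical:

```python
def _dedup_expressions(sorted_exprs: list[dict]) -> list[dict]:
    """Keep the first occurrence of each (situation, style) pair."""
    seen: set[tuple[str, str]] = set()
    unique: list[dict] = []
    for expr in sorted_exprs:
        key = (expr.get("situation", ""), expr.get("style", ""))
        if key not in seen:
            seen.add(key)
            unique.append(expr)
    return unique
```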

View File

@@ -358,7 +358,10 @@ class ExpressionLearner:
@staticmethod
@cached(ttl=600, key_prefix="chat_expressions")
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""内部方法:从数据库获取表达方式(带缓存)"""
"""内部方法:从数据库获取表达方式(带缓存)
🔥 优化:使用列表推导式和更高效的数据处理
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
@@ -366,67 +369,91 @@ class ExpressionLearner:
crud = CRUDBase(Expression)
all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000)
# 🔥 Optimization: process all rows in a single pass to reduce per-iteration overhead
for expr in all_expressions:
# Ensure create_date exists; fall back to last_active_time when missing
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
expr_data = {
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": expr.type,
"create_date": create_date,
}
# Classify by type (one check per row)
if expr.type == "style":
learnt_style_expressions.append(expr_data)
elif expr.type == "grammar":
learnt_grammar_expressions.append(expr_data)
logger.debug(f"已加载 {len(learnt_style_expressions)} 个style和 {len(learnt_grammar_expressions)} 个grammar表达方式 (chat_id={chat_id})")
return learnt_style_expressions, learnt_grammar_expressions
async def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
Apply global decay to every expression in the database
Optimization: batch all changes via CRUD and commit once at the end
Optimization: process in fixed-size batches to bound memory use and per-transaction work
"""
try:
# Query all expressions via CRUD
crud = CRUDBase(Expression)
all_expressions = await crud.get_multi(limit=100000)  # fetch all expressions
BATCH_SIZE = 1000  # process in batches to avoid loading too much data at once
updated_count = 0
deleted_count = 0
offset = 0
# Use a session where manual operations are required
async with get_db_session() as session:
# Apply all modifications in one batch
for expr in all_expressions:
# Compute the idle time
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600)  # convert to days
# Compute the decay amount
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# Delete this expression when count has decayed to the floor
await session.delete(expr)
deleted_count += 1
else:
# Update count
expr.count = new_count
updated_count += 1
# Optimization: commit all changes at once, going from N commits down to 1
if updated_count > 0 or deleted_count > 0:
while True:
async with get_db_session() as session:
# Query expressions in batches
batch_result = await session.execute(
select(Expression)
.order_by(Expression.id)
.limit(BATCH_SIZE)
.offset(offset)
)
batch_expressions = list(batch_result.scalars())
if not batch_expressions:
break # no more data
# Process the current batch
to_delete = []
for expr in batch_expressions:
# Compute the idle time
time_diff_days = (current_time - expr.last_active_time) / (24 * 3600)
# Compute the decay amount
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# Mark for deletion
to_delete.append(expr)
else:
# Update count
expr.count = new_count
updated_count += 1
# Delete in bulk
if to_delete:
for expr in to_delete:
await session.delete(expr)
deleted_count += len(to_delete)
# Commit the current batch
await session.commit()
# A short batch means all remaining rows have been processed
if len(batch_expressions) < BATCH_SIZE:
break
# Rows deleted above vanish from later queries, so advance the offset
# only by the number of surviving rows to avoid skipping records
offset += len(batch_expressions) - len(to_delete)
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
@@ -509,88 +536,106 @@ class ExpressionLearner:
CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
# 🔥 Optimization: fetch all existing expressions in one query instead of N queries
existing_exprs_result = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
)
)
existing_exprs = list(existing_exprs_result.scalars())
# Build fast lookup indexes
exact_match_map = {} # (situation, style) -> Expression
situation_map = {} # situation -> Expression
style_map = {} # style -> Expression
for expr in existing_exprs:
key = (expr.situation, expr.style)
exact_match_map[key] = expr
# Keep only the first match (priority: exact match > situation match > style match)
if expr.situation not in situation_map:
situation_map[expr.situation] = expr
if expr.style not in style_map:
style_map[expr.style] = expr
# Process all new expressions in one pass
for new_expr in expr_list:
# 🔥 Improvement 1: check for existing rows with the same situation or the same style
# Case 1: same chat_id + type + situation (same situation, different style)
query_same_situation = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
)
)
same_situation_expr = query_same_situation.scalar()
# Case 2: same chat_id + type + style (same style, different situation)
query_same_style = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.style == new_expr["style"])
)
)
same_style_expr = query_same_style.scalar()
# Case 3: exact match (same situation + same style)
query_exact_match = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
)
exact_match_expr = query_exact_match.scalar()
situation = new_expr["situation"]
style_val = new_expr["style"]
exact_key = (situation, style_val)
# Handle exact matches first
if exact_match_expr:
if exact_key in exact_match_map:
# Exact match: bump count and refresh the timestamp
expr_obj = exact_match_expr
expr_obj = exact_match_map[exact_key]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
logger.debug(f"完全匹配更新count {expr_obj.count}")
elif same_situation_expr:
elif situation in situation_map:
# Same situation, different style: overwrite the old style
logger.info(f"Same-situation overwrite: style for '{same_situation_expr.situation}' updated from '{same_situation_expr.style}' to '{new_expr['style']}'")
same_situation_expr.style = new_expr["style"]
same_situation_expr = situation_map[situation]
logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'")
# Update the indexes
old_key = (same_situation_expr.situation, same_situation_expr.style)
if old_key in exact_match_map:
del exact_match_map[old_key]
same_situation_expr.style = style_val
same_situation_expr.count = same_situation_expr.count + 1
same_situation_expr.last_active_time = current_time
elif same_style_expr:
# Register the new exact-match key
exact_match_map[exact_key] = same_situation_expr
elif style_val in style_map:
# Same style, different situation: overwrite the old situation
logger.info(f"Same-style overwrite: situation for '{same_style_expr.style}' updated from '{same_style_expr.situation}' to '{new_expr['situation']}'")
same_style_expr.situation = new_expr["situation"]
same_style_expr = style_map[style_val]
logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'")
# Update the indexes
old_key = (same_style_expr.situation, same_style_expr.style)
if old_key in exact_match_map:
del exact_match_map[old_key]
same_style_expr.situation = situation
same_style_expr.count = same_style_expr.count + 1
same_style_expr.last_active_time = current_time
# Register the new exact-match key
exact_match_map[exact_key] = same_style_expr
situation_map[situation] = same_style_expr
else:
# Brand-new expression: create a new record
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
situation=situation,
style=style_val,
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type,
create_date=current_time,  # set the creation date explicitly
)
session.add(new_expression)
logger.debug(f"新增表达方式:{new_expr['situation']} -> {new_expr['style']}")
# 更新映射
exact_match_map[exact_key] = new_expression
situation_map[situation] = new_expression
style_map[style_val] = new_expression
logger.debug(f"新增表达方式:{situation} -> {style_val}")
# Cap the total count - use get_all_by_sorted to fetch sorted rows
exprs_result = await session.execute(
select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
exprs = list(exprs_result.scalars())
if len(exprs) > MAX_EXPRESSION_COUNT:
# Delete the surplus expressions with the smallest count
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
# 🔥 Optimization: cap the total count using already-loaded data, avoiding a re-query
# existing_exprs already holds every expression for this chat_id and type
all_current_exprs = list(exact_match_map.values())
if len(all_current_exprs) > MAX_EXPRESSION_COUNT:
# Sort by count and delete the surplus expressions with the smallest count
sorted_exprs = sorted(all_current_exprs, key=lambda e: e.count)
for expr in sorted_exprs[: len(all_current_exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# Remove from the indexes
key = (expr.situation, expr.style)
if key in exact_match_map:
del exact_match_map[key]
logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式")
# Clear related caches after committing
# Commit the database changes
await session.commit()
# 🔥 Clear the expression cache for every chat_id in the sharing group
# 🔥 Optimization: clear caches only when something was actually updated (moved outward to avoid repeated clears)
if chat_dict:  # clear caches only when there is updated data
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
@@ -602,53 +647,59 @@ class ExpressionLearner:
if len(related_chat_ids) > 1:
logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存")
# 🔥 Train the StyleLearner (sharing-group aware)
# Only style expressions train the model (grammar does not need it)
if type == "style":
try:
logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}")
# 为每个共享组内的 chat_id 训练其 StyleLearner
for target_chat_id in related_chat_ids:
learner = style_learner_manager.get_learner(target_chat_id)
# 🔥 Train the StyleLearner (sharing-group aware)
# Only style expressions train the model (grammar does not need it)
if type == "style" and chat_dict:
try:
related_chat_ids = self.get_related_chat_ids()
total_samples = sum(len(expr_list) for expr_list in chat_dict.values())
logger.debug(f"开始训练 StyleLearner: 共享组包含 {len(related_chat_ids)} 个chat_id, 总样本数={total_samples}")
# 为每个共享组内的 chat_id 训练其 StyleLearner
for target_chat_id in related_chat_ids:
learner = style_learner_manager.get_learner(target_chat_id)
# Collect all expressions for this target_chat_id
# The source chat_id uses the data in chat_dict; other members train on it too (a sharing-group property)
total_success = 0
total_samples = 0
for source_chat_id, expr_list in chat_dict.items():
# Train the model on every learned expression
# situation is the input and style is the target
# This matches the semantics best: situation -> style
success_count = 0
for expr in expr_list:
situation = expr["situation"]
style = expr["style"]
# Train the mapping: situation -> style
if learner.learn_mapping(situation, style):
success_count += 1
else:
logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}")
# Save the model
total_success += 1
total_samples += 1
# Save the model
if total_samples > 0:
if learner.save(style_learner_manager.model_save_path):
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
else:
logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
if target_chat_id == chat_id:
# Detailed log only for the source chat_id
if target_chat_id == self.chat_id:
# Detailed log only for the current chat_id
logger.info(
f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, "
f"StyleLearner 训练完成: {total_success}/{total_samples} 成功, "
f"当前风格总数={len(learner.get_all_styles())}, "
f"总样本数={learner.learning_stats['total_samples']}"
)
else:
logger.debug(
f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功"
f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {total_success}/{total_samples} 成功"
)
if len(related_chat_ids) > 1:
logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")
if len(related_chat_ids) > 1:
logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}")
except Exception as e:
logger.error(f"训练 StyleLearner 失败: {e}")
return learnt_expressions
return None

View File

@@ -207,31 +207,20 @@ class ExpressionSelector:
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
)
style_exprs = [
{
# 🔥 Optimization: define the conversion helper up front to avoid duplicated code
def expr_to_dict(expr, expr_type: str) -> dict[str, Any]:
return {
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"type": expr_type,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query.scalars()
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in grammar_query.scalars()
]
style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()]
grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
@@ -251,9 +240,14 @@ class ExpressionSelector:
@staticmethod
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库
🔥 优化:合并所有更新到一个事务中,减少数据库连接开销
"""
if not expressions_to_update:
return
# Deduplicate
updates_by_key = {}
affected_chat_ids = set()
for expr in expressions_to_update:
@@ -269,9 +263,15 @@ class ExpressionSelector:
updates_by_key[key] = expr
affected_chat_ids.add(source_id)
for chat_id, expr_type, situation, style in updates_by_key:
async with get_db_session() as session:
query = await session.execute(
if not updates_by_key:
return
# 🔥 Optimization: batch all updates through a single session
current_time = time.time()
async with get_db_session() as session:
updated_count = 0
for chat_id, expr_type, situation, style in updates_by_key:
query_result = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
@@ -279,25 +279,26 @@ class ExpressionSelector:
& (Expression.style == style)
)
)
query = query.scalar()
if query:
expr_obj = query
expr_obj = query_result.scalar()
if expr_obj:
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
expr_obj.last_active_time = current_time
updated_count += 1
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
# Commit all changes in one batch
if updated_count > 0:
await session.commit()
logger.debug(f"批量更新了 {updated_count} 个表达方式的count值")
# Clear caches for every affected chat_id
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
for chat_id in affected_chat_ids:
await cache.delete(generate_cache_key("chat_expressions", chat_id))
if affected_chat_ids:
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
for chat_id in affected_chat_ids:
await cache.delete(generate_cache_key("chat_expressions", chat_id))
async def select_suitable_expressions(
self,
@@ -518,29 +519,41 @@ class ExpressionSelector:
logger.warning("数据库中完全没有任何表达方式,需要先学习")
return []
# 🔥 Use fuzzy matching instead of exact matching
# Compute the similarity of each predicted style against each database style
# 🔥 Optimization: use a more efficient fuzzy-matching algorithm
from difflib import SequenceMatcher
# Preprocess: lowercase all predicted styles once to avoid repeated work
predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]]
matched_expressions = []
for expr in all_expressions:
db_style = expr.style or ""
db_style_lower = db_style.lower()
max_similarity = 0.0
best_predicted = ""
# Compute the similarity against each predicted style
for predicted_style, pred_score in predicted_styles[:20]:  # consider the top 20 predictions
# Compute the string similarity
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
# Also check containment (a substring relation scores higher)
if len(predicted_style) >= 2 and len(db_style) >= 2:
if predicted_style in db_style or db_style in predicted_style:
similarity = max(similarity, 0.7)
for predicted_style_lower, pred_score in predicted_styles_lower:
# Fast path: exact match
if predicted_style_lower == db_style_lower:
max_similarity = 1.0
best_predicted = predicted_style_lower
break
# Fast path: substring match
if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2:
if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower:
similarity = 0.7
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style_lower
continue
# Compute string similarity (slower; only when needed)
similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio()
if similarity > max_similarity:
max_similarity = similarity
best_predicted = predicted_style
best_predicted = predicted_style_lower
# 🔥 Lower the threshold to 30% because StyleLearner prediction quality is poor
if max_similarity >= 0.3:  # 30% similarity threshold
@@ -573,14 +586,15 @@ class ExpressionSelector:
f"(候选 {len(matched_expressions)}temperature={temperature})"
)
# Convert to dict format
# 🔥 Optimization: use a list comprehension to keep conversion overhead low
expressions = [
{
"situation": expr.situation or "",
"style": expr.style or "",
"type": expr.type or "style",
"count": float(expr.count) if expr.count else 0.0,
"last_active_time": expr.last_active_time or 0.0
"last_active_time": expr.last_active_time or 0.0,
"source_id": expr.chat_id # 添加 source_id 以便后续更新
}
for expr in expressions_objs
]
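The three match tiers above (exact -> 1.0, substring -> 0.7, otherwise SequenceMatcher ratio) can be restated in isolation; a standalone sketch of the per-pair scoring rule, with illustrative inputs:

```python
from difflib import SequenceMatcher

def match_score(predicted: str, db_style: str) -> float:
    """Tiered score mirroring the selector: exact, then substring, then ratio."""
    p, d = predicted.lower(), db_style.lower()
    if p == d:
        return 1.0
    if len(p) >= 2 and len(d) >= 2 and (p in d or d in p):
        return 0.7
    return SequenceMatcher(None, p, d).ratio()

assert match_score("开心", "开心") == 1.0          # exact tier
assert match_score("开心", "很开心的样子") == 0.7   # substring tier
assert match_score("难过", "开心") < 0.3           # falls below the 30% threshold
```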

View File

@@ -127,7 +127,8 @@ class SituationExtractor:
Returns:
A list of situation descriptions
"""
situations = []
situations: list[str] = []
seen = set()
for line in response.splitlines():
line = line.strip()
@@ -150,6 +151,11 @@ class SituationExtractor:
if any(keyword in line.lower() for keyword in ["例如", "注意", "分析", "总结"]):  # skip explanatory/meta lines; an empty keyword here would match every line
continue
# Deduplicate while preserving the original order
if line in seen:
continue
seen.add(line)
situations.append(line)
if len(situations) >= max_situations:

View File

@@ -4,6 +4,7 @@
Supports independent per-chat-room modeling and online learning
"""
import os
import pickle
import time
from src.common.logger import get_logger
@@ -16,11 +17,12 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
"""单个聊天室的表达风格学习器"""
def __init__(self, chat_id: str, model_config: dict | None = None):
def __init__(self, chat_id: str, model_config: dict | None = None, resource_limit_enabled: bool = True):
"""
Args:
chat_id: chat room ID
model_config: model configuration
resource_limit_enabled: whether to enable the resource cap (default: on)
"""
self.chat_id = chat_id
self.model_config = model_config or {
@@ -34,6 +36,9 @@ class StyleLearner:
# Initialize the expression model
self.expressor = ExpressorModel(**self.model_config)
# Resource-cap switch (on by default; can be disabled as needed)
self.resource_limit_enabled = resource_limit_enabled
# Dynamic style management
self.max_styles = 2000  # at most 2000 styles per chat_id
self.cleanup_threshold = 0.9  # trigger cleanup at 90% capacity
@@ -67,18 +72,15 @@ class StyleLearner:
if style in self.style_to_id:
return True
# Check whether cleanup is needed
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# Already at the hard limit; cleanup is mandatory
logger.warning(f"Style count hit the hard limit ({self.max_styles}); starting cleanup")
self._cleanup_styles()
elif current_count >= cleanup_trigger:
# Approaching the limit; clean up early
logger.info(f"Style count reached {current_count}/{self.max_styles}; triggering preventive cleanup")
# Check whether cleanup is needed (compute the threshold only once)
if self.resource_limit_enabled:
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理")
else:
logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理")
self._cleanup_styles()
# Generate a new style_id
@@ -95,7 +97,8 @@ class StyleLearner:
self.expressor.add_candidate(style_id, style, situation)
# Initialize stats
self.learning_stats["style_counts"][style_id] = 0
self.learning_stats.setdefault("style_counts", {})[style_id] = 0
self.learning_stats.setdefault("style_last_used", {})
logger.debug(f"添加风格成功: {style_id} -> {style}")
return True
@@ -114,64 +117,64 @@ class StyleLearner:
3. By default, clean up cleanup_ratio (20%) of the styles
"""
try:
total_styles = len(self.style_to_id)
if total_styles == 0:
return
# Run the expensive sort only once the threshold is reached
cleanup_count = max(1, int(total_styles * self.cleanup_ratio))
if cleanup_count <= 0:
return
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# Local references speed up these frequently called functions
from math import exp, log1p
# Compute a value score for every style
style_scores = []
for style_id in self.style_to_id.values():
# Usage count
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# Last-used time (the more recent, the better)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float("inf")
usage_score = log1p(usage_count)
days_unused = time_since_used / 86400
time_score = exp(-days_unused / 30)
# Combined score: more uses is better, less time since last use is better
# Use a logarithm to smooth the influence of the usage count
import math
usage_score = math.log1p(usage_count)  # log(1 + count)
# Time score: convert to days and apply exponential decay
days_unused = time_since_used / 86400  # convert to days
time_score = math.exp(-days_unused / 30)  # 30-day decay factor
# Combined score: 80% usage frequency + 20% recency
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
if not style_scores:
return
# Sort by score; the lowest scores are deleted first
style_scores.sort(key=lambda x: x[1])
# Delete the lowest-scoring styles
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
style_text = self.id_to_style.get(style_id)
if style_text:
# Remove from the mappings
del self.style_to_id[style_text]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
if not style_text:
continue
# Remove from the stats
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# Remove from the mappings
self.style_to_id.pop(style_text, None)
self.id_to_style.pop(style_id, None)
self.id_to_situation.pop(style_id, None)
# Remove from the expressor model
self.expressor.remove_candidate(style_id)
# Remove from the stats
self.learning_stats["style_counts"].pop(style_id, None)
self.learning_stats["style_last_used"].pop(style_id, None)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
# Remove from the expressor model
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格"
)
# Log the first 5 deleted styles for debugging
if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
@@ -204,7 +207,9 @@ class StyleLearner:
# Update stats
current_time = time.time()
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats.setdefault("style_counts", {})
self.learning_stats.setdefault("style_last_used", {})
self.learning_stats["style_counts"][style_id] = self.learning_stats["style_counts"].get(style_id, 0) + 1
self.learning_stats["style_last_used"][style_id] = current_time # 更新最后使用时间
self.learning_stats["last_update"] = current_time
@@ -349,11 +354,11 @@ class StyleLearner:
# Save the expressor model
model_path = os.path.join(save_dir, "expressor_model.pkl")
self.expressor.save(model_path)
# Save the mappings and stats
import pickle
tmp_model_path = f"{model_path}.tmp"
self.expressor.save(tmp_model_path)
os.replace(tmp_model_path, model_path)
# Save the mappings and stats (atomic write)
meta_path = os.path.join(save_dir, "meta.pkl")
# Ensure learning_stats contains every required field
@@ -368,8 +373,13 @@ class StyleLearner:
"learning_stats": self.learning_stats,
}
with open(meta_path, "wb") as f:
pickle.dump(meta_data, f)
tmp_meta_path = f"{meta_path}.tmp"
with open(tmp_meta_path, "wb") as f:
pickle.dump(meta_data, f, protocol=pickle.HIGHEST_PROTOCOL)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_meta_path, meta_path)
return True
@@ -401,8 +411,6 @@ class StyleLearner:
self.expressor.load(model_path)
# Load the mappings and stats
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
if os.path.exists(meta_path):
with open(meta_path, "rb") as f:
@@ -445,14 +453,16 @@ class StyleLearnerManager:
# 🔧 Maximum number of active learners
MAX_ACTIVE_LEARNERS = 50
def __init__(self, model_save_path: str = "data/expression/style_models"):
def __init__(self, model_save_path: str = "data/expression/style_models", resource_limit_enabled: bool = True):
"""
Args:
model_save_path: path where models are saved
resource_limit_enabled: whether to enable the resource cap (default: on)
"""
self.learners: dict[str, StyleLearner] = {}
self.learner_last_used: dict[str, float] = {}  # 🔧 track last-used times
self.model_save_path = model_save_path
self.resource_limit_enabled = resource_limit_enabled
# Ensure the save directory exists
os.makedirs(model_save_path, exist_ok=True)
@@ -475,7 +485,10 @@ class StyleLearnerManager:
for chat_id, last_used in sorted_by_time[:evict_count]:
if chat_id in self.learners:
# Save before evicting
self.learners[chat_id].save(self.model_save_path)
try:
self.learners[chat_id].save(self.model_save_path)
except Exception as e:
logger.error(f"LRU淘汰时保存学习器失败: chat_id={chat_id}, error={e}")
del self.learners[chat_id]
del self.learner_last_used[chat_id]
evicted.append(chat_id)
@@ -502,7 +515,11 @@ class StyleLearnerManager:
self._evict_if_needed()
# Create a new learner
learner = StyleLearner(chat_id, model_config)
learner = StyleLearner(
chat_id,
model_config,
resource_limit_enabled=self.resource_limit_enabled,
)
# Try to load a previously saved model
learner.load(self.model_save_path)
@@ -511,6 +528,12 @@ class StyleLearnerManager:
return self.learners[chat_id]
def set_resource_limit(self, enabled: bool) -> None:
"""动态开启/关闭资源上限控制(默认关闭)。"""
self.resource_limit_enabled = enabled
for learner in self.learners.values():
learner.resource_limit_enabled = enabled
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
Learn one mapping

View File

@@ -0,0 +1,38 @@
# Short-Term Memory Pressure-Relief Patch
## Background
In some scenarios the short-term memory layer piles up quickly before the automatic transfer triggers, which can drive short-term memory to its capacity limit and block subsequent writes.
## Changes (patch)
- Adds a "pressure relief" switch: optionally, when utilization reaches 100%, delete the lowest-importance, oldest short-term memories so the short-term layer cannot keep growing.
- Off by default; automatic deletion runs only after the switch is explicitly enabled.
## Switch configuration
- Entry point: `UnifiedMemoryManager` constructor parameter
  - `short_term_enable_force_cleanup: bool = False`
- Forwarded to the short-term layer: `ShortTermMemoryManager(enable_force_cleanup=True)`
- Example (disabled, which is the default):
```python
manager = UnifiedMemoryManager(
short_term_enable_force_cleanup=False,
)
```
## Behavior
- When short-term memory utilization reaches or exceeds 100% and no transfer batch is currently pending:
  - `force_cleanup_overflow()` is triggered
  - A batch of memories is deleted, lowest importance first and earliest creation time first, pressing usage back to roughly `max_memories * 0.9`
- Cleanup is persisted in the background and does not block the main flow; a sketch of the selection rule follows below.
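A hedged sketch of the selection rule (lowest importance first, oldest first within a tier, trimming back to roughly 90% of capacity); the field and method names beyond those quoted above are assumptions:

```python
def force_cleanup_overflow(self) -> int:
    """Evict low-importance, oldest short-term memories down to ~90% capacity.

    Illustrative sketch; the real manager persists deletions in the background.
    """
    target = int(self.max_memories * 0.9)
    overflow = len(self.memories) - target
    if overflow <= 0:
        return 0
    # Lowest importance first; among equals, earliest created_at first
    victims = sorted(self.memories, key=lambda m: (m.importance, m.created_at))[:overflow]
    for victim in victims:
        self.memories.remove(victim)
    return len(victims)
```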
## Scope of impact
- Default behavior is identical to before the patch (the switch defaults to `off`).
- With the switch off, the short-term layer performs no forced deletion and relies solely on the automatic transfer mechanism.
## Rollback
- Pass `short_term_enable_force_cleanup=False` at construction time to turn it off; no code rollback is needed.