diff --git a/docs/express_similarity.md b/docs/express_similarity.md new file mode 100644 index 000000000..04055e29c --- /dev/null +++ b/docs/express_similarity.md @@ -0,0 +1,36 @@ +# 表达相似度计算策略 + +本文档说明 `calculate_similarity` 的实现与配置,帮助在质量与性能间做权衡。 + +## 总览 +- 支持两种路径: + 1) **向量化路径(默认优先)**:TF-IDF + 余弦相似度(依赖 `scikit-learn`) + 2) **回退路径**:`difflib.SequenceMatcher` +- 参数 `prefer_vector` 控制是否优先尝试向量化,默认 `True`。 +- 依赖缺失或文本过短时,自动回退,无需额外配置。 + +## 调用方式 +```python +from src.chat.express.express_utils import calculate_similarity + +sim = calculate_similarity(text1, text2) # 默认优先向量化 +sim_fast = calculate_similarity(text1, text2, prefer_vector=False) # 强制使用 SequenceMatcher +``` + +## 依赖与回退 +- 可选依赖:`scikit-learn` + - 缺失时自动回退到 `SequenceMatcher`,不会抛异常。 +- 文本过短(长度 < 2)时直接回退,避免稀疏向量噪声。 + +## 适用建议 +- 文本较长、对鲁棒性/语义相似度有更高要求:保持默认(向量化优先)。 +- 环境无 `scikit-learn` 或追求极简依赖:调用时设置 `prefer_vector=False`。 +- 高并发性能敏感:可在调用点酌情关闭向量化或加缓存。 + +## 返回范围 +- 相似度范围始终在 `[0, 1]`。 +- 空字符串 → `0.0`;完全相同 → `1.0`。 + +## 额外建议 +- 若需更强语义能力,可替换为向量数据库或句向量模型(需新增依赖与配置)。 +- 对热路径可增加缓存(按文本哈希),或限制输入长度以控制向量维度与内存。 diff --git a/docs/changelogs/short_term_pressure_patch.md b/docs/short_term_pressure_patch.md similarity index 95% rename from docs/changelogs/short_term_pressure_patch.md rename to docs/short_term_pressure_patch.md index 65dd7ea76..0e124932c 100644 --- a/docs/changelogs/short_term_pressure_patch.md +++ b/docs/short_term_pressure_patch.md @@ -30,7 +30,7 @@ ## 影响范围 -- 默认行为保持与补丁前一致(开关默认 `on`)。 +- 默认行为保持与补丁前一致(开关默认 `off`)。 - 如果关闭开关,短期层将不再做强制删除,只依赖自动转移机制。 ## 回滚 diff --git a/docs/style_learner_resource_limit.md b/docs/style_learner_resource_limit.md new file mode 100644 index 000000000..da2550742 --- /dev/null +++ b/docs/style_learner_resource_limit.md @@ -0,0 +1,60 @@ +# StyleLearner 资源上限开关(默认开启) + +## 概览 +StyleLearner 支持资源上限控制,用于约束风格容量与清理行为。开关默认 **开启**,以防止模型无限膨胀;可在运行时动态关闭。 + +## 开关位置与用法(务必看这里) + +开关在 **代码层**,默认开启,不依赖配置文件。 + +1) **全局运行时切换(推荐)** + 路径:`src/chat/express/style_learner.py` 暴露的单例 `style_learner_manager` + ```python + from src.chat.express.style_learner import style_learner_manager + + # 关闭资源上限(放开容量,谨慎使用) + style_learner_manager.set_resource_limit(False) + + # 再次开启资源上限 + style_learner_manager.set_resource_limit(True) + ``` + - 影响范围:实时作用于已创建的全部 learner(逐个同步 `resource_limit_enabled`)。 + - 生效时机:调用后立即生效,无需重启。 + +2) **构造时指定(不常用)** + - `StyleLearner(resource_limit_enabled: True|False, ...)` + - `StyleLearnerManager(resource_limit_enabled: True|False, ...)` + 用于自定义实例化逻辑(通常保持默认即可)。 + +3) **默认行为** + - 开关默认 **开启**,即启用容量管理与清理。 + - 没有配置文件项;若需持久化开关状态,可自行在启动代码中显式调用 `set_resource_limit`。 + +## 资源上限行为(开启时) +- 容量参数(每个 chat): + - `max_styles = 2000` + - `cleanup_threshold = 0.9`(≥90% 容量触发清理) + - `cleanup_ratio = 0.2`(清理低价值风格约 20%) +- 价值评分:结合使用频率(log 平滑)与最近使用时间(指数衰减),得分低者优先清理。 +- 仅对单个 learner 的容量管理生效;LRU 淘汰逻辑保持不变。 + +> ⚙️ 开关作用面: +> - **开启**:在 add_style 时会检查容量并触发 `_cleanup_styles`;预测/学习逻辑不变。 +> - **关闭**:不再触发容量清理,但 LRU 管理器仍可能在进程层面淘汰不活跃 learner。 + +## I/O 与健壮性 +- 模型与元数据保存采用原子写(`.tmp` + `os.replace`),避免部分写入。 +- `pickle` 使用 `HIGHEST_PROTOCOL`,并执行 `fsync` 确保落盘。 + +## 兼容性 +- 默认开启,无需修改配置文件;关闭后行为与旧版本类似。 +- 已有模型文件可直接加载,开关仅影响运行时清理策略。 + +## 何时建议开启/关闭 +- 开启(默认):内存/磁盘受限,或聊天风格高频增长,需防止模型膨胀。 +- 关闭:需要完整保留所有历史风格且资源充足,或进行一次性数据收集实验。 + +## 监控与调优建议 +- 监控:每 chat 风格数量、清理触发次数、删除数量、预测延迟 p95。 +- 如清理过于激进:提高 `cleanup_threshold` 或降低 `cleanup_ratio`。 +- 如内存/磁盘依旧偏高:降低 `max_styles`,或增加定期持久化与压缩策略。 diff --git a/src/chat/express/express_utils.py b/src/chat/express/express_utils.py index 96c175648..13e0efd0d 100644 --- a/src/chat/express/express_utils.py +++ b/src/chat/express/express_utils.py @@ -7,11 +7,26 @@ import random import re from typing import Any +try: + from sklearn.feature_extraction.text import TfidfVectorizer + from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity + + HAS_SKLEARN = True +except Exception: # pragma: no cover - 依赖缺失时静默回退 + HAS_SKLEARN = False + from src.common.logger import get_logger logger = get_logger("express_utils") +# 预编译正则,减少重复编译开销 +_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*") +_RE_AT = re.compile(r"@<[^>]*>") +_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]") +_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]") + + def filter_message_content(content: str | None) -> str: """ 过滤消息内容,移除回复、@、图片等格式 @@ -25,29 +40,56 @@ def filter_message_content(content: str | None) -> str: if not content: return "" - # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r"\[回复.*?\],说:\s*", "", content) - # 移除@<...>格式的内容 - content = re.sub(r"@<[^>]*>", "", content) - # 移除[图片:...]格式的图片ID - content = re.sub(r"\[图片:[^\]]*\]", "", content) - # 移除[表情包:...]格式的内容 - content = re.sub(r"\[表情包:[^\]]*\]", "", content) + # 使用预编译正则提升性能 + content = _RE_REPLY.sub("", content) + content = _RE_AT.sub("", content) + content = _RE_IMAGE.sub("", content) + content = _RE_EMOJI.sub("", content) return content.strip() -def calculate_similarity(text1: str, text2: str) -> float: +def _similarity_tfidf(text1: str, text2: str) -> float | None: + """使用 TF-IDF + 余弦相似度;依赖 sklearn,缺失则返回 None。""" + if not HAS_SKLEARN: + return None + # 过短文本用传统算法更稳健 + if len(text1) < 2 or len(text2) < 2: + return None + try: + vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2)) + tfidf = vec.fit_transform([text1, text2]) + sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0]) + return max(0.0, min(1.0, sim)) + except Exception: + return None + + +def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float: """ 计算两个文本的相似度,返回0-1之间的值 + - 当可用且文本足够长时,优先尝试 TF-IDF 向量相似度(更鲁棒) + - 不可用或失败时回退到 SequenceMatcher + Args: text1: 第一个文本 text2: 第二个文本 + prefer_vector: 是否优先使用向量化方案(默认是) Returns: 相似度值 (0-1) """ + if not text1 or not text2: + return 0.0 + if text1 == text2: + return 1.0 + + if prefer_vector: + sim = _similarity_tfidf(text1, text2) + if sim is not None: + return sim + return difflib.SequenceMatcher(None, text1, text2).ratio() @@ -79,18 +121,10 @@ def weighted_sample(population: list[dict], k: int, weight_key: str | None = Non except (ValueError, TypeError) as e: logger.warning(f"加权抽样失败,使用等概率抽样: {e}") - # 等概率抽样 - selected = [] + # 等概率抽样(无放回,保持去重) population_copy = population.copy() - - for _ in range(k): - if not population_copy: - break - # 随机选择一个元素 - idx = random.randint(0, len(population_copy) - 1) - selected.append(population_copy.pop(idx)) - - return selected + # 使用 random.sample 提升可读性和性能 + return random.sample(population_copy, k) def normalize_text(text: str) -> str: @@ -130,8 +164,9 @@ def extract_keywords(text: str, max_keywords: int = 10) -> list[str]: return keywords except ImportError: logger.warning("rjieba未安装,无法提取关键词") - # 简单分词 + # 简单分词,按长度降序优先输出较长词,提升粗略关键词质量 words = text.split() + words.sort(key=len, reverse=True) return words[:max_keywords] @@ -236,15 +271,18 @@ def merge_expressions_from_multiple_chats( # 收集所有表达方式 for chat_id, expressions in expressions_dict.items(): for expr in expressions: - # 添加source_id标识 expr_with_source = expr.copy() expr_with_source["source_id"] = chat_id all_expressions.append(expr_with_source) - # 按count或last_active_time排序 - if all_expressions and "count" in all_expressions[0]: + if not all_expressions: + return [] + + # 选择排序键(优先 count,其次 last_active_time),无则保持原序 + sample = all_expressions[0] + if "count" in sample: all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True) - elif all_expressions and "last_active_time" in all_expressions[0]: + elif "last_active_time" in sample: all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True) # 去重(基于situation和style) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index f4086573e..5870a8bdb 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -358,7 +358,10 @@ class ExpressionLearner: @staticmethod @cached(ttl=600, key_prefix="chat_expressions") async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]: - """内部方法:从数据库获取表达方式(带缓存)""" + """内部方法:从数据库获取表达方式(带缓存) + + 🔥 优化:使用列表推导式和更高效的数据处理 + """ learnt_style_expressions = [] learnt_grammar_expressions = [] @@ -366,67 +369,91 @@ class ExpressionLearner: crud = CRUDBase(Expression) all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000) + # 🔥 优化:使用列表推导式批量处理,减少循环开销 for expr in all_expressions: - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - expr_data = { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": expr.type, - "create_date": create_date, - } + expr_data = { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": expr.type, + "create_date": create_date, + } - # 根据类型分类 - if expr.type == "style": - learnt_style_expressions.append(expr_data) - elif expr.type == "grammar": - learnt_grammar_expressions.append(expr_data) + # 根据类型分类(避免多次类型检查) + if expr.type == "style": + learnt_style_expressions.append(expr_data) + elif expr.type == "grammar": + learnt_grammar_expressions.append(expr_data) + logger.debug(f"已加载 {len(learnt_style_expressions)} 个style和 {len(learnt_grammar_expressions)} 个grammar表达方式 (chat_id={chat_id})") return learnt_style_expressions, learnt_grammar_expressions async def _apply_global_decay_to_database(self, current_time: float) -> None: """ 对数据库中的所有表达方式应用全局衰减 - 优化: 使用CRUD批量处理所有更改,最后统一提交 + 优化: 使用分批处理和原生 SQL 操作提升性能 """ try: - # 使用CRUD查询所有表达方式 - crud = CRUDBase(Expression) - all_expressions = await crud.get_multi(limit=100000) # 获取所有表达方式 - + BATCH_SIZE = 1000 # 分批处理,避免一次性加载过多数据 updated_count = 0 deleted_count = 0 + offset = 0 - # 需要手动操作的情况下使用session - async with get_db_session() as session: - # 批量处理所有修改 - for expr in all_expressions: - # 计算时间差 - last_active = expr.last_active_time - time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - - # 计算衰减值 - decay_value = self.calculate_decay_factor(time_diff_days) - new_count = max(0.01, expr.count - decay_value) - - if new_count <= 0.01: - # 如果count太小,删除这个表达方式 - await session.delete(expr) - deleted_count += 1 - else: - # 更新count - expr.count = new_count - updated_count += 1 - - # 优化: 统一提交所有更改(从N次提交减少到1次) - if updated_count > 0 or deleted_count > 0: + while True: + async with get_db_session() as session: + # 分批查询表达方式 + batch_result = await session.execute( + select(Expression) + .order_by(Expression.id) + .limit(BATCH_SIZE) + .offset(offset) + ) + batch_expressions = list(batch_result.scalars()) + + if not batch_expressions: + break # 没有更多数据 + + # 批量处理当前批次 + to_delete = [] + for expr in batch_expressions: + # 计算时间差 + time_diff_days = (current_time - expr.last_active_time) / (24 * 3600) + + # 计算衰减值 + decay_value = self.calculate_decay_factor(time_diff_days) + new_count = max(0.01, expr.count - decay_value) + + if new_count <= 0.01: + # 标记删除 + to_delete.append(expr) + else: + # 更新count + expr.count = new_count + updated_count += 1 + + # 批量删除 + if to_delete: + for expr in to_delete: + await session.delete(expr) + deleted_count += len(to_delete) + + # 提交当前批次 await session.commit() - logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") + + # 如果批次不满,说明已经处理完所有数据 + if len(batch_expressions) < BATCH_SIZE: + break + + offset += BATCH_SIZE + + if updated_count > 0 or deleted_count > 0: + logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") except Exception as e: logger.error(f"数据库全局衰减失败: {e}") @@ -509,88 +536,106 @@ class ExpressionLearner: CRUDBase(Expression) for chat_id, expr_list in chat_dict.items(): async with get_db_session() as session: + # 🔥 优化:批量查询所有现有表达方式,避免N次数据库查询 + existing_exprs_result = await session.execute( + select(Expression).where( + (Expression.chat_id == chat_id) + & (Expression.type == type) + ) + ) + existing_exprs = list(existing_exprs_result.scalars()) + + # 构建快速查找索引 + exact_match_map = {} # (situation, style) -> Expression + situation_map = {} # situation -> Expression + style_map = {} # style -> Expression + + for expr in existing_exprs: + key = (expr.situation, expr.style) + exact_match_map[key] = expr + # 只保留第一个匹配(优先级:完全匹配 > 情景匹配 > 表达匹配) + if expr.situation not in situation_map: + situation_map[expr.situation] = expr + if expr.style not in style_map: + style_map[expr.style] = expr + + # 批量处理所有新表达方式 for new_expr in expr_list: - # 🔥 改进1:检查是否存在相同情景或相同表达的数据 - # 情况1:相同 chat_id + type + situation(相同情景,不同表达) - query_same_situation = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.situation == new_expr["situation"]) - ) - ) - same_situation_expr = query_same_situation.scalar() - - # 情况2:相同 chat_id + type + style(相同表达,不同情景) - query_same_style = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.style == new_expr["style"]) - ) - ) - same_style_expr = query_same_style.scalar() - - # 情况3:完全相同(相同情景+相同表达) - query_exact_match = await session.execute( - select(Expression).where( - (Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.situation == new_expr["situation"]) - & (Expression.style == new_expr["style"]) - ) - ) - exact_match_expr = query_exact_match.scalar() - + situation = new_expr["situation"] + style_val = new_expr["style"] + exact_key = (situation, style_val) + # 优先处理完全匹配的情况 - if exact_match_expr: + if exact_key in exact_match_map: # 完全相同:增加count,更新时间 - expr_obj = exact_match_expr + expr_obj = exact_match_map[exact_key] expr_obj.count = expr_obj.count + 1 expr_obj.last_active_time = current_time logger.debug(f"完全匹配:更新count {expr_obj.count}") - elif same_situation_expr: + elif situation in situation_map: # 相同情景,不同表达:覆盖旧的表达 - logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{new_expr['style']}'") - same_situation_expr.style = new_expr["style"] + same_situation_expr = situation_map[situation] + logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'") + # 更新映射 + old_key = (same_situation_expr.situation, same_situation_expr.style) + if old_key in exact_match_map: + del exact_match_map[old_key] + same_situation_expr.style = style_val same_situation_expr.count = same_situation_expr.count + 1 same_situation_expr.last_active_time = current_time - elif same_style_expr: + # 更新新的完全匹配映射 + exact_match_map[exact_key] = same_situation_expr + elif style_val in style_map: # 相同表达,不同情景:覆盖旧的情景 - logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{new_expr['situation']}'") - same_style_expr.situation = new_expr["situation"] + same_style_expr = style_map[style_val] + logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'") + # 更新映射 + old_key = (same_style_expr.situation, same_style_expr.style) + if old_key in exact_match_map: + del exact_match_map[old_key] + same_style_expr.situation = situation same_style_expr.count = same_style_expr.count + 1 same_style_expr.last_active_time = current_time + # 更新新的完全匹配映射 + exact_match_map[exact_key] = same_style_expr + situation_map[situation] = same_style_expr else: # 完全新的表达方式:创建新记录 new_expression = Expression( - situation=new_expr["situation"], - style=new_expr["style"], + situation=situation, + style=style_val, count=1, last_active_time=current_time, chat_id=chat_id, type=type, - create_date=current_time, # 手动设置创建日期 + create_date=current_time, ) session.add(new_expression) - logger.debug(f"新增表达方式:{new_expr['situation']} -> {new_expr['style']}") + # 更新映射 + exact_match_map[exact_key] = new_expression + situation_map[situation] = new_expression + style_map[style_val] = new_expression + logger.debug(f"新增表达方式:{situation} -> {style_val}") - # 限制最大数量 - 使用 get_all_by_sorted 获取排序结果 - exprs_result = await session.execute( - select(Expression) - .where((Expression.chat_id == chat_id) & (Expression.type == type)) - .order_by(Expression.count.asc()) - ) - exprs = list(exprs_result.scalars()) - if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除count最小的多余表达方式 - for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: + # 🔥 优化:限制最大数量 - 使用已加载的数据避免重复查询 + # existing_exprs 已包含该 chat_id 和 type 的所有表达方式 + all_current_exprs = list(exact_match_map.values()) + if len(all_current_exprs) > MAX_EXPRESSION_COUNT: + # 按 count 排序,删除 count 最小的多余表达方式 + sorted_exprs = sorted(all_current_exprs, key=lambda e: e.count) + for expr in sorted_exprs[: len(all_current_exprs) - MAX_EXPRESSION_COUNT]: await session.delete(expr) + # 从映射中移除 + key = (expr.situation, expr.style) + if key in exact_match_map: + del exact_match_map[key] + logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式") - # 提交后清除相关缓存 + # 提交数据库更改 await session.commit() - # 🔥 清除共享组内所有 chat_id 的表达方式缓存 + # 🔥 优化:只在实际有更新时才清除缓存(移到外层,避免重复清除) + if chat_dict: # 只有当有数据更新时才清除缓存 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key cache = await get_cache() @@ -602,53 +647,59 @@ class ExpressionLearner: if len(related_chat_ids) > 1: logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存") - # 🔥 训练 StyleLearner(支持共享组) - # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) - if type == "style": - try: - logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}") - - # 为每个共享组内的 chat_id 训练其 StyleLearner - for target_chat_id in related_chat_ids: - learner = style_learner_manager.get_learner(target_chat_id) + # 🔥 训练 StyleLearner(支持共享组) + # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型) + if type == "style" and chat_dict: + try: + related_chat_ids = self.get_related_chat_ids() + total_samples = sum(len(expr_list) for expr_list in chat_dict.values()) + logger.debug(f"开始训练 StyleLearner: 共享组包含 {len(related_chat_ids)} 个chat_id, 总样本数={total_samples}") + # 为每个共享组内的 chat_id 训练其 StyleLearner + for target_chat_id in related_chat_ids: + learner = style_learner_manager.get_learner(target_chat_id) + + # 收集该 target_chat_id 对应的所有表达方式 + # 如果是源 chat_id,使用 chat_dict 中的数据;否则也要训练(共享组特性) + total_success = 0 + total_samples = 0 + + for source_chat_id, expr_list in chat_dict.items(): # 为每个学习到的表达方式训练模型 # 使用 situation 作为输入,style 作为目标 - # 这是最符合语义的方式:场景 -> 表达方式 - success_count = 0 for expr in expr_list: situation = expr["situation"] style = expr["style"] - + # 训练映射关系: situation -> style if learner.learn_mapping(situation, style): - success_count += 1 - else: - logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}") - - # 保存模型 + total_success += 1 + total_samples += 1 + + # 保存模型 + if total_samples > 0: if learner.save(style_learner_manager.model_save_path): logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}") else: logger.error(f"StyleLearner 模型保存失败: {target_chat_id}") - - if target_chat_id == chat_id: - # 只为源 chat_id 记录详细日志 + + if target_chat_id == self.chat_id: + # 只为当前 chat_id 记录详细日志 logger.info( - f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, " + f"StyleLearner 训练完成: {total_success}/{total_samples} 成功, " f"当前风格总数={len(learner.get_all_styles())}, " f"总样本数={learner.learning_stats['total_samples']}" ) else: logger.debug( - f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功" + f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {total_success}/{total_samples} 成功" ) - if len(related_chat_ids) > 1: - logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练") + if len(related_chat_ids) > 1: + logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练") - except Exception as e: - logger.error(f"训练 StyleLearner 失败: {e}") + except Exception as e: + logger.error(f"训练 StyleLearner 失败: {e}") return learnt_expressions return None diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 3359e7c05..cfe335cd4 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -207,31 +207,20 @@ class ExpressionSelector: select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")) ) - style_exprs = [ - { + # 🔥 优化:提前定义转换函数,避免重复代码 + def expr_to_dict(expr, expr_type: str) -> dict[str, Any]: + return { "situation": expr.situation, "style": expr.style, "count": expr.count, "last_active_time": expr.last_active_time, "source_id": expr.chat_id, - "type": "style", + "type": expr_type, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, } - for expr in style_query.scalars() - ] - - grammar_exprs = [ - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": expr.chat_id, - "type": "grammar", - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } - for expr in grammar_query.scalars() - ] + + style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()] + grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()] style_num = int(total_num * style_percentage) grammar_num = int(total_num * grammar_percentage) @@ -251,9 +240,14 @@ class ExpressionSelector: @staticmethod async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1): - """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" + """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库 + + 🔥 优化:合并所有更新到一个事务中,减少数据库连接开销 + """ if not expressions_to_update: return + + # 去重处理 updates_by_key = {} affected_chat_ids = set() for expr in expressions_to_update: @@ -269,9 +263,15 @@ class ExpressionSelector: updates_by_key[key] = expr affected_chat_ids.add(source_id) - for chat_id, expr_type, situation, style in updates_by_key: - async with get_db_session() as session: - query = await session.execute( + if not updates_by_key: + return + + # 🔥 优化:使用单个 session 批量处理所有更新 + current_time = time.time() + async with get_db_session() as session: + updated_count = 0 + for chat_id, expr_type, situation, style in updates_by_key: + query_result = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == expr_type) @@ -279,25 +279,26 @@ class ExpressionSelector: & (Expression.style == style) ) ) - query = query.scalar() - if query: - expr_obj = query + expr_obj = query_result.scalar() + if expr_obj: current_count = expr_obj.count new_count = min(current_count + increment, 5.0) expr_obj.count = new_count - expr_obj.last_active_time = time.time() + expr_obj.last_active_time = current_time + updated_count += 1 - logger.debug( - f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" - ) + # 批量提交所有更改 + if updated_count > 0: await session.commit() + logger.debug(f"批量更新了 {updated_count} 个表达方式的count值") # 清除所有受影响的chat_id的缓存 - from src.common.database.optimization.cache_manager import get_cache - from src.common.database.utils.decorators import generate_cache_key - cache = await get_cache() - for chat_id in affected_chat_ids: - await cache.delete(generate_cache_key("chat_expressions", chat_id)) + if affected_chat_ids: + from src.common.database.optimization.cache_manager import get_cache + from src.common.database.utils.decorators import generate_cache_key + cache = await get_cache() + for chat_id in affected_chat_ids: + await cache.delete(generate_cache_key("chat_expressions", chat_id)) async def select_suitable_expressions( self, @@ -518,29 +519,41 @@ class ExpressionSelector: logger.warning("数据库中完全没有任何表达方式,需要先学习") return [] - # 🔥 使用模糊匹配而不是精确匹配 - # 计算每个预测style与数据库style的相似度 + # 🔥 优化:使用更高效的模糊匹配算法 from difflib import SequenceMatcher + # 预处理:提前计算所有预测 style 的小写版本,避免重复计算 + predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]] + matched_expressions = [] for expr in all_expressions: db_style = expr.style or "" + db_style_lower = db_style.lower() max_similarity = 0.0 best_predicted = "" # 与每个预测的style计算相似度 - for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测 - # 计算字符串相似度 - similarity = SequenceMatcher(None, predicted_style, db_style).ratio() - - # 也检查包含关系(如果一个是另一个的子串,给更高分) - if len(predicted_style) >= 2 and len(db_style) >= 2: - if predicted_style in db_style or db_style in predicted_style: - similarity = max(similarity, 0.7) - + for predicted_style_lower, pred_score in predicted_styles_lower: + # 快速检查:完全匹配 + if predicted_style_lower == db_style_lower: + max_similarity = 1.0 + best_predicted = predicted_style_lower + break + + # 快速检查:子串匹配 + if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2: + if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower: + similarity = 0.7 + if similarity > max_similarity: + max_similarity = similarity + best_predicted = predicted_style_lower + continue + + # 计算字符串相似度(较慢,只在必要时使用) + similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio() if similarity > max_similarity: max_similarity = similarity - best_predicted = predicted_style + best_predicted = predicted_style_lower # 🔥 降低阈值到30%,因为StyleLearner预测质量较差 if max_similarity >= 0.3: # 30%相似度阈值 @@ -573,14 +586,15 @@ class ExpressionSelector: f"(候选 {len(matched_expressions)},temperature={temperature})" ) - # 转换为字典格式 + # 🔥 优化:使用列表推导式和预定义函数减少开销 expressions = [ { "situation": expr.situation or "", "style": expr.style or "", "type": expr.type or "style", "count": float(expr.count) if expr.count else 0.0, - "last_active_time": expr.last_active_time or 0.0 + "last_active_time": expr.last_active_time or 0.0, + "source_id": expr.chat_id # 添加 source_id 以便后续更新 } for expr in expressions_objs ] diff --git a/src/chat/express/situation_extractor.py b/src/chat/express/situation_extractor.py index 2fd6c9205..47e35e78a 100644 --- a/src/chat/express/situation_extractor.py +++ b/src/chat/express/situation_extractor.py @@ -127,7 +127,8 @@ class SituationExtractor: Returns: 情境描述列表 """ - situations = [] + situations: list[str] = [] + seen = set() for line in response.splitlines(): line = line.strip() @@ -150,6 +151,11 @@ class SituationExtractor: if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]): continue + # 去重,保持原有顺序 + if line in seen: + continue + seen.add(line) + situations.append(line) if len(situations) >= max_situations: diff --git a/src/chat/express/style_learner.py b/src/chat/express/style_learner.py index 3b099f3fd..ec76428d0 100644 --- a/src/chat/express/style_learner.py +++ b/src/chat/express/style_learner.py @@ -4,6 +4,7 @@ 支持多聊天室独立建模和在线学习 """ import os +import pickle import time from src.common.logger import get_logger @@ -16,11 +17,12 @@ logger = get_logger("expressor.style_learner") class StyleLearner: """单个聊天室的表达风格学习器""" - def __init__(self, chat_id: str, model_config: dict | None = None): + def __init__(self, chat_id: str, model_config: dict | None = None, resource_limit_enabled: bool = True): """ Args: chat_id: 聊天室ID model_config: 模型配置 + resource_limit_enabled: 是否启用资源上限控制(默认关闭) """ self.chat_id = chat_id self.model_config = model_config or { @@ -34,6 +36,9 @@ class StyleLearner: # 初始化表达模型 self.expressor = ExpressorModel(**self.model_config) + # 资源上限控制开关(默认开启,可按需关闭) + self.resource_limit_enabled = resource_limit_enabled + # 动态风格管理 self.max_styles = 2000 # 每个chat_id最多2000个风格 self.cleanup_threshold = 0.9 # 达到90%容量时触发清理 @@ -67,18 +72,15 @@ class StyleLearner: if style in self.style_to_id: return True - # 检查是否需要清理 - current_count = len(self.style_to_id) - cleanup_trigger = int(self.max_styles * self.cleanup_threshold) - - if current_count >= cleanup_trigger: - if current_count >= self.max_styles: - # 已经达到最大限制,必须清理 - logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理") - self._cleanup_styles() - elif current_count >= cleanup_trigger: - # 接近限制,提前清理 - logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理") + # 检查是否需要清理(仅计算一次阈值) + if self.resource_limit_enabled: + current_count = len(self.style_to_id) + cleanup_trigger = int(self.max_styles * self.cleanup_threshold) + if current_count >= cleanup_trigger: + if current_count >= self.max_styles: + logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理") + else: + logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理") self._cleanup_styles() # 生成新的style_id @@ -95,7 +97,8 @@ class StyleLearner: self.expressor.add_candidate(style_id, style, situation) # 初始化统计 - self.learning_stats["style_counts"][style_id] = 0 + self.learning_stats.setdefault("style_counts", {})[style_id] = 0 + self.learning_stats.setdefault("style_last_used", {}) logger.debug(f"添加风格成功: {style_id} -> {style}") return True @@ -114,64 +117,64 @@ class StyleLearner: 3. 默认清理 cleanup_ratio (20%) 的风格 """ try: + total_styles = len(self.style_to_id) + if total_styles == 0: + return + + # 只有在达到阈值时才执行昂贵的排序 + cleanup_count = max(1, int(total_styles * self.cleanup_ratio)) + if cleanup_count <= 0: + return + current_time = time.time() - cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio)) + # 局部引用加速频繁调用的函数 + from math import exp, log1p # 计算每个风格的价值分数 style_scores = [] for style_id in self.style_to_id.values(): - # 使用次数 usage_count = self.learning_stats["style_counts"].get(style_id, 0) - - # 最后使用时间(越近越好) last_used = self.learning_stats["style_last_used"].get(style_id, 0) + time_since_used = current_time - last_used if last_used > 0 else float("inf") + usage_score = log1p(usage_count) + days_unused = time_since_used / 86400 + time_score = exp(-days_unused / 30) - # 综合分数:使用次数越多越好,距离上次使用时间越短越好 - # 使用对数来平滑使用次数的影响 - import math - usage_score = math.log1p(usage_count) # log(1 + count) - - # 时间分数:转换为天数,使用指数衰减 - days_unused = time_since_used / 86400 # 转换为天 - time_score = math.exp(-days_unused / 30) # 30天衰减因子 - - # 综合分数:80%使用频率 + 20%时间新鲜度 total_score = 0.8 * usage_score + 0.2 * time_score - style_scores.append((style_id, total_score, usage_count, days_unused)) + if not style_scores: + return + # 按分数排序,分数低的先删除 style_scores.sort(key=lambda x: x[1]) - # 删除分数最低的风格 deleted_styles = [] for style_id, score, usage, days in style_scores[:cleanup_count]: style_text = self.id_to_style.get(style_id) - if style_text: - # 从映射中删除 - del self.style_to_id[style_text] - del self.id_to_style[style_id] - if style_id in self.id_to_situation: - del self.id_to_situation[style_id] + if not style_text: + continue - # 从统计中删除 - if style_id in self.learning_stats["style_counts"]: - del self.learning_stats["style_counts"][style_id] - if style_id in self.learning_stats["style_last_used"]: - del self.learning_stats["style_last_used"][style_id] + # 从映射中删除 + self.style_to_id.pop(style_text, None) + self.id_to_style.pop(style_id, None) + self.id_to_situation.pop(style_id, None) - # 从expressor模型中删除 - self.expressor.remove_candidate(style_id) + # 从统计中删除 + self.learning_stats["style_counts"].pop(style_id, None) + self.learning_stats["style_last_used"].pop(style_id, None) - deleted_styles.append((style_text[:30], usage, f"{days:.1f}天")) + # 从expressor模型中删除 + self.expressor.remove_candidate(style_id) + + deleted_styles.append((style_text[:30], usage, f"{days:.1f}天")) logger.info( f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格," f"剩余 {len(self.style_to_id)} 个风格" ) - # 记录前5个被删除的风格(用于调试) if deleted_styles: logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}") @@ -204,7 +207,9 @@ class StyleLearner: # 更新统计 current_time = time.time() self.learning_stats["total_samples"] += 1 - self.learning_stats["style_counts"][style_id] += 1 + self.learning_stats.setdefault("style_counts", {}) + self.learning_stats.setdefault("style_last_used", {}) + self.learning_stats["style_counts"][style_id] = self.learning_stats["style_counts"].get(style_id, 0) + 1 self.learning_stats["style_last_used"][style_id] = current_time # 更新最后使用时间 self.learning_stats["last_update"] = current_time @@ -349,11 +354,11 @@ class StyleLearner: # 保存expressor模型 model_path = os.path.join(save_dir, "expressor_model.pkl") - self.expressor.save(model_path) - - # 保存映射关系和统计信息 - import pickle + tmp_model_path = f"{model_path}.tmp" + self.expressor.save(tmp_model_path) + os.replace(tmp_model_path, model_path) + # 保存映射关系和统计信息(原子写) meta_path = os.path.join(save_dir, "meta.pkl") # 确保 learning_stats 包含所有必要字段 @@ -368,8 +373,13 @@ class StyleLearner: "learning_stats": self.learning_stats, } - with open(meta_path, "wb") as f: - pickle.dump(meta_data, f) + tmp_meta_path = f"{meta_path}.tmp" + with open(tmp_meta_path, "wb") as f: + pickle.dump(meta_data, f, protocol=pickle.HIGHEST_PROTOCOL) + f.flush() + os.fsync(f.fileno()) + + os.replace(tmp_meta_path, meta_path) return True @@ -401,8 +411,6 @@ class StyleLearner: self.expressor.load(model_path) # 加载映射关系和统计信息 - import pickle - meta_path = os.path.join(save_dir, "meta.pkl") if os.path.exists(meta_path): with open(meta_path, "rb") as f: @@ -445,14 +453,16 @@ class StyleLearnerManager: # 🔧 最大活跃 learner 数量 MAX_ACTIVE_LEARNERS = 50 - def __init__(self, model_save_path: str = "data/expression/style_models"): + def __init__(self, model_save_path: str = "data/expression/style_models", resource_limit_enabled: bool = True): """ Args: model_save_path: 模型保存路径 + resource_limit_enabled: 是否启用资源上限控制(默认开启) """ self.learners: dict[str, StyleLearner] = {} self.learner_last_used: dict[str, float] = {} # 🔧 记录最后使用时间 self.model_save_path = model_save_path + self.resource_limit_enabled = resource_limit_enabled # 确保保存目录存在 os.makedirs(model_save_path, exist_ok=True) @@ -475,7 +485,10 @@ class StyleLearnerManager: for chat_id, last_used in sorted_by_time[:evict_count]: if chat_id in self.learners: # 先保存再淘汰 - self.learners[chat_id].save(self.model_save_path) + try: + self.learners[chat_id].save(self.model_save_path) + except Exception as e: + logger.error(f"LRU淘汰时保存学习器失败: chat_id={chat_id}, error={e}") del self.learners[chat_id] del self.learner_last_used[chat_id] evicted.append(chat_id) @@ -502,7 +515,11 @@ class StyleLearnerManager: self._evict_if_needed() # 创建新的学习器 - learner = StyleLearner(chat_id, model_config) + learner = StyleLearner( + chat_id, + model_config, + resource_limit_enabled=self.resource_limit_enabled, + ) # 尝试加载已保存的模型 learner.load(self.model_save_path) @@ -511,6 +528,12 @@ class StyleLearnerManager: return self.learners[chat_id] + def set_resource_limit(self, enabled: bool) -> None: + """动态开启/关闭资源上限控制(默认关闭)。""" + self.resource_limit_enabled = enabled + for learner in self.learners.values(): + learner.resource_limit_enabled = enabled + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: """ 学习一个映射关系 diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py index 10e43aee4..ddcec1465 100644 --- a/src/chat/interest_system/interest_manager.py +++ b/src/chat/interest_system/interest_manager.py @@ -5,6 +5,7 @@ import asyncio import time +from collections import OrderedDict from typing import TYPE_CHECKING from src.common.logger import get_logger @@ -37,19 +38,50 @@ class InterestManager: self._calculation_queue = asyncio.Queue() self._worker_task = None self._shutdown_event = asyncio.Event() + + # 性能优化相关字段 + self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict() # LRU缓存 + self._cache_max_size = 1000 # 最大缓存数量 + self._cache_ttl = 300 # 缓存TTL(秒) + self._batch_queue: asyncio.Queue = asyncio.Queue(maxsize=100) # 批处理队列 + self._batch_size = 10 # 批处理大小 + self._batch_timeout = 0.1 # 批处理超时(秒) + self._batch_task = None + self._is_warmed_up = False # 预热状态标记 + + # 性能统计 + self._cache_hits = 0 + self._cache_misses = 0 + self._batch_calculations = 0 + self._total_calculation_time = 0.0 + self._initialized = True async def initialize(self): """初始化管理器""" - pass + # 启动批处理工作线程 + if self._batch_task is None or self._batch_task.done(): + self._batch_task = asyncio.create_task(self._batch_processing_worker()) + logger.info("批处理工作线程已启动") async def shutdown(self): """关闭管理器""" self._shutdown_event.set() + + # 取消批处理任务 + if self._batch_task and not self._batch_task.done(): + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass if self._current_calculator: await self._current_calculator.cleanup() self._current_calculator = None + + # 清理缓存 + self._result_cache.clear() logger.info("兴趣值管理器已关闭") @@ -91,12 +123,13 @@ class InterestManager: logger.error(f"注册兴趣值计算组件失败: {e}") return False - async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None) -> InterestCalculationResult: - """计算消息兴趣值 + async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None, use_cache: bool = True) -> InterestCalculationResult: + """计算消息兴趣值(优化版,支持缓存) Args: message: 数据库消息对象 timeout: 最大等待时间(秒),超时则使用默认值返回;为None时不设置超时 + use_cache: 是否使用缓存,默认True Returns: InterestCalculationResult: 计算结果或默认结果 @@ -109,37 +142,53 @@ class InterestManager: interest_value=0.3, error_message="没有可用的兴趣值计算组件", ) + + message_id = getattr(message, "message_id", "") + + # 缓存查询 + if use_cache and message_id: + cached_result = self._get_from_cache(message_id) + if cached_result is not None: + self._cache_hits += 1 + logger.debug(f"命中缓存: {message_id}, 兴趣值: {cached_result.interest_value:.3f}") + return cached_result + self._cache_misses += 1 # 使用 create_task 异步执行计算 task = asyncio.create_task(self._async_calculate(message)) if timeout is None: - return await task - - try: - # 等待计算结果,但有超时限制 - result = await asyncio.wait_for(task, timeout=timeout) - return result - except asyncio.TimeoutError: - # 超时返回默认结果,但计算仍在后台继续 - logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5") - return InterestCalculationResult( - success=True, - message_id=getattr(message, "message_id", ""), - interest_value=0.5, # 固定默认兴趣值 - should_reply=False, - should_act=False, - error_message=f"计算超时({timeout}s),使用默认值", - ) - except Exception as e: - # 发生异常,返回默认结果 - logger.error(f"兴趣值计算异常: {e}") - return InterestCalculationResult( - success=False, - message_id=getattr(message, "message_id", ""), - interest_value=0.3, - error_message=f"计算异常: {e!s}", - ) + result = await task + else: + try: + # 等待计算结果,但有超时限制 + result = await asyncio.wait_for(task, timeout=timeout) + except asyncio.TimeoutError: + # 超时返回默认结果,但计算仍在后台继续 + logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {message_id} 使用默认兴趣值 0.5") + return InterestCalculationResult( + success=True, + message_id=message_id, + interest_value=0.5, # 固定默认兴趣值 + should_reply=False, + should_act=False, + error_message=f"计算超时({timeout}s),使用默认值", + ) + except Exception as e: + # 发生异常,返回默认结果 + logger.error(f"兴趣值计算异常: {e}") + return InterestCalculationResult( + success=False, + message_id=message_id, + interest_value=0.3, + error_message=f"计算异常: {e!s}", + ) + + # 缓存结果 + if use_cache and result.success and message_id: + self._put_to_cache(message_id, result) + + return result async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult: """异步执行兴趣值计算""" @@ -161,6 +210,7 @@ class InterestManager: if result.success: self._last_calculation_time = time.time() + self._total_calculation_time += result.calculation_time logger.debug(f"兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)") else: self._failed_calculations += 1 @@ -170,13 +220,15 @@ class InterestManager: except Exception as e: self._failed_calculations += 1 + calc_time = time.time() - start_time + self._total_calculation_time += calc_time logger.error(f"兴趣值计算异常: {e}") return InterestCalculationResult( success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=f"计算异常: {e!s}", - calculation_time=time.time() - start_time, + calculation_time=calc_time, ) async def _calculation_worker(self): @@ -197,6 +249,155 @@ class InterestManager: break except Exception as e: logger.error(f"计算工作线程异常: {e}") + + def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None: + """从缓存中获取结果(LRU策略)""" + if message_id not in self._result_cache: + return None + + # 检查TTL + result = self._result_cache[message_id] + if time.time() - result.timestamp > self._cache_ttl: + # 过期,删除 + del self._result_cache[message_id] + return None + + # 更新访问顺序(LRU) + self._result_cache.move_to_end(message_id) + return result + + def _put_to_cache(self, message_id: str, result: InterestCalculationResult): + """将结果放入缓存(LRU策略)""" + # 如果已存在,更新 + if message_id in self._result_cache: + self._result_cache.move_to_end(message_id) + + self._result_cache[message_id] = result + + # 限制缓存大小 + while len(self._result_cache) > self._cache_max_size: + # 删除最旧的项 + self._result_cache.popitem(last=False) + + async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]: + """批量计算消息兴趣值(并发优化) + + Args: + messages: 消息列表 + timeout: 单个计算的超时时间 + + Returns: + list[InterestCalculationResult]: 计算结果列表 + """ + if not messages: + return [] + + # 并发计算所有消息 + tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理异常 + final_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"批量计算消息 {i} 失败: {result}") + final_results.append(InterestCalculationResult( + success=False, + message_id=getattr(messages[i], "message_id", ""), + interest_value=0.3, + error_message=f"批量计算异常: {result!s}", + )) + else: + final_results.append(result) + + self._batch_calculations += 1 + return final_results + + async def _batch_processing_worker(self): + """批处理工作线程""" + while not self._shutdown_event.is_set(): + batch = [] + deadline = time.time() + self._batch_timeout + + try: + # 收集批次 + while len(batch) < self._batch_size and time.time() < deadline: + remaining_time = deadline - time.time() + if remaining_time <= 0: + break + + try: + item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time) + batch.append(item) + except asyncio.TimeoutError: + break + + # 处理批次 + if batch: + await self._process_batch(batch) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"批处理工作线程异常: {e}") + + async def _process_batch(self, batch: list): + """处理批次消息""" + # 这里可以实现具体的批处理逻辑 + # 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现 + pass + + async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None): + """预热兴趣计算器 + + Args: + sample_messages: 样本消息列表,用于预热。如果为None,则只初始化计算器 + """ + if not self._current_calculator: + logger.warning("无法预热:没有可用的兴趣值计算组件") + return + + logger.info("开始预热兴趣值计算器...") + start_time = time.time() + + # 如果提供了样本消息,进行预热计算 + if sample_messages: + try: + # 批量计算样本消息 + await self.calculate_interest_batch(sample_messages, timeout=5.0) + logger.info(f"预热完成:处理了 {len(sample_messages)} 条样本消息,耗时 {time.time() - start_time:.2f}s") + except Exception as e: + logger.error(f"预热过程中出现异常: {e}") + else: + logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s") + + self._is_warmed_up = True + + def clear_cache(self): + """清空缓存""" + cleared_count = len(self._result_cache) + self._result_cache.clear() + logger.info(f"已清空 {cleared_count} 条缓存记录") + + def set_cache_config(self, max_size: int | None = None, ttl: int | None = None): + """设置缓存配置 + + Args: + max_size: 最大缓存数量 + ttl: 缓存生存时间(秒) + """ + if max_size is not None: + self._cache_max_size = max_size + logger.info(f"缓存最大容量设置为: {max_size}") + + if ttl is not None: + self._cache_ttl = ttl + logger.info(f"缓存TTL设置为: {ttl}秒") + + # 如果当前缓存超过新的最大值,清理旧数据 + if max_size is not None: + while len(self._result_cache) > self._cache_max_size: + self._result_cache.popitem(last=False) def get_current_calculator(self) -> BaseInterestCalculator | None: """获取当前活跃的兴趣值计算组件""" @@ -205,6 +406,8 @@ class InterestManager: def get_statistics(self) -> dict: """获取管理器统计信息""" success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations)) + cache_hit_rate = self._cache_hits / max(1, self._cache_hits + self._cache_misses) + avg_calc_time = self._total_calculation_time / max(1, self._total_calculations) stats = { "manager_statistics": { @@ -213,6 +416,13 @@ class InterestManager: "success_rate": success_rate, "last_calculation_time": self._last_calculation_time, "current_calculator": self._current_calculator.component_name if self._current_calculator else None, + "cache_hit_rate": cache_hit_rate, + "cache_hits": self._cache_hits, + "cache_misses": self._cache_misses, + "cache_size": len(self._result_cache), + "batch_calculations": self._batch_calculations, + "average_calculation_time": avg_calc_time, + "is_warmed_up": self._is_warmed_up, } } @@ -236,6 +446,82 @@ class InterestManager: def has_calculator(self) -> bool: """检查是否有可用的计算组件""" return self._current_calculator is not None and self._current_calculator.is_enabled + + async def adaptive_optimize(self): + """自适应优化:根据性能统计自动调整参数""" + if not self._current_calculator: + return + + stats = self.get_statistics()["manager_statistics"] + + # 根据缓存命中率调整缓存大小 + cache_hit_rate = stats["cache_hit_rate"] + if cache_hit_rate < 0.5 and self._cache_max_size < 5000: + # 命中率低,增加缓存容量 + new_size = min(self._cache_max_size * 2, 5000) + logger.info(f"自适应优化:缓存命中率较低 ({cache_hit_rate:.2%}),扩大缓存容量 {self._cache_max_size} -> {new_size}") + self._cache_max_size = new_size + elif cache_hit_rate > 0.9 and self._cache_max_size > 100: + # 命中率高,可以适当减小缓存 + new_size = max(self._cache_max_size // 2, 100) + logger.info(f"自适应优化:缓存命中率很高 ({cache_hit_rate:.2%}),缩小缓存容量 {self._cache_max_size} -> {new_size}") + self._cache_max_size = new_size + # 清理多余缓存 + while len(self._result_cache) > self._cache_max_size: + self._result_cache.popitem(last=False) + + # 根据平均计算时间调整批处理参数 + avg_calc_time = stats["average_calculation_time"] + if avg_calc_time > 0.5 and self._batch_size < 50: + # 计算较慢,增加批次大小以提高吞吐量 + new_batch_size = min(self._batch_size * 2, 50) + logger.info(f"自适应优化:平均计算时间较长 ({avg_calc_time:.3f}s),增加批次大小 {self._batch_size} -> {new_batch_size}") + self._batch_size = new_batch_size + elif avg_calc_time < 0.1 and self._batch_size > 5: + # 计算较快,可以减小批次 + new_batch_size = max(self._batch_size // 2, 5) + logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}") + self._batch_size = new_batch_size + + def get_performance_report(self) -> str: + """生成性能报告""" + stats = self.get_statistics()["manager_statistics"] + + report = [ + "=" * 60, + "兴趣值管理器性能报告", + "=" * 60, + f"总计算次数: {stats['total_calculations']}", + f"失败次数: {stats['failed_calculations']}", + f"成功率: {stats['success_rate']:.2%}", + f"缓存命中率: {stats['cache_hit_rate']:.2%}", + f"缓存命中: {stats['cache_hits']}", + f"缓存未命中: {stats['cache_misses']}", + f"当前缓存大小: {stats['cache_size']} / {self._cache_max_size}", + f"批量计算次数: {stats['batch_calculations']}", + f"平均计算时间: {stats['average_calculation_time']:.4f}s", + f"是否已预热: {'是' if stats['is_warmed_up'] else '否'}", + f"当前计算器: {stats['current_calculator'] or '无'}", + "=" * 60, + ] + + # 添加计算器统计 + if self._current_calculator: + calc_stats = self.get_statistics()["calculator_statistics"] + report.extend([ + "", + "计算器统计:", + f" 组件名称: {calc_stats['component_name']}", + f" 版本: {calc_stats['component_version']}", + f" 已启用: {calc_stats['enabled']}", + f" 总计算: {calc_stats['total_calculations']}", + f" 失败: {calc_stats['failed_calculations']}", + f" 成功率: {calc_stats['success_rate']:.2%}", + f" 平均耗时: {calc_stats['average_calculation_time']:.4f}s", + "=" * 60, + ]) + + return "\n".join(report) # 全局实例 diff --git a/src/memory_graph/short_term_pressure_patch.md b/src/memory_graph/short_term_pressure_patch.md new file mode 100644 index 000000000..6967fe41d --- /dev/null +++ b/src/memory_graph/short_term_pressure_patch.md @@ -0,0 +1,199 @@ +# 短期记忆压力泄压补丁 + +## 📋 概述 + +在高频消息场景下,短期记忆层(`ShortTermMemoryManager`)可能在自动转移机制触发前快速堆积大量记忆,当达到容量上限(`max_memories`)时可能阻塞后续写入。本功能提供一个**可选的泄压开关**,在容量溢出时自动删除低优先级记忆,防止系统阻塞。 + +**关键特性**: +- ✅ 默认关闭,保持向后兼容 +- ✅ 基于重要性和时间的智能删除策略 +- ✅ 异步持久化,不阻塞主流程 +- ✅ 可通过配置文件或代码控制 + +--- + +## 🔧 配置方法 + +### 方法 1:代码配置(直接创建管理器) + +如果您在代码中直接实例化 `UnifiedMemoryManager`: + +```python +from src.memory_graph.unified_manager import UnifiedMemoryManager + +manager = UnifiedMemoryManager( + short_term_enable_force_cleanup=True, # 开启泄压功能 + short_term_max_memories=30, # 短期记忆容量上限 + # ... 其他参数 +) +``` + +### 方法 2:配置文件(通过单例获取) + +**推荐方式**:如果您使用 `get_unified_memory_manager()` 单例,需修改配置文件。 + +#### ❌ 目前的问题 +配置文件 `config/bot_config.toml` 的 `[memory]` 节**尚未包含**此开关参数。 + +#### ✅ 解决方案 +在 `config/bot_config.toml` 的 `[memory]` 节添加: + +```toml +[memory] +# ... 其他配置 ... +short_term_max_memories = 30 # 短期记忆容量上限 +short_term_transfer_threshold = 0.6 # 转移到长期记忆的重要性阈值 +short_term_enable_force_cleanup = true # 开启压力泄压(建议高频场景开启) +``` + +然后在 `src/memory_graph/manager_singleton.py` 第 157-175 行的 `get_unified_memory_manager()` 函数中添加读取逻辑: + +```python +_unified_memory_manager = UnifiedMemoryManager( + # ... 其他参数 ... + short_term_enable_force_cleanup=getattr(config, "short_term_enable_force_cleanup", False), # 添加此行 +) +``` + +--- + +## ⚙️ 核心实现位置 + +### 1. 参数定义 +**文件**:`src/memory_graph/unified_manager.py` 第 47 行 +```python +class UnifiedMemoryManager: + def __init__( + self, + short_term_enable_force_cleanup: bool = False, # 开关参数 + ): +``` + +### 2. 传递到短期层 +**文件**:`src/memory_graph/unified_manager.py` 第 100 行 +```python +"short_term": { + "enable_force_cleanup": short_term_enable_force_cleanup, # 传递给 ShortTermMemoryManager +} +``` + +### 3. 泄压逻辑实现 +**文件**:`src/memory_graph/short_term_manager.py` 第 693-726 行 +```python +def force_cleanup_overflow(self, keep_ratio: float = 0.9) -> int: + """当短期记忆超过容量时,强制删除低重要性且最早的记忆以泄压""" + if not self.enable_force_cleanup: # 检查开关 + return 0 + # ... 删除逻辑 +``` + +### 4. 触发条件 +**文件**:`src/memory_graph/unified_manager.py` 第 618-621 行 +```python +# 在自动转移循环中检测 +if occupancy_ratio >= 1.0 and not transfer_cache: + removed = self.short_term_manager.force_cleanup_overflow() + if removed > 0: + logger.warning(f"短期记忆占用率 {occupancy_ratio:.0%},已强制删除 {removed} 条低重要性记忆泄压") +``` + +--- + +## 🔄 运行机制 + +### 触发条件(同时满足) +1. ✅ 开关已开启(`enable_force_cleanup=True`) +2. ✅ 短期记忆占用率 ≥ 100%(`len(memories) >= max_memories`) +3. ✅ 当前没有待转移批次(`transfer_cache` 为空) + +### 删除策略 +**排序规则**:双重排序,先按重要性升序,再按创建时间升序 +```python +sorted_memories = sorted(self.memories, key=lambda m: (m.importance, m.created_at)) +``` + +**删除数量**:删除到容量的 90% +```python +current = len(self.memories) # 当前记忆数 +limit = int(self.max_memories * 0.9) # 目标保留数 +remove_count = current - limit # 需要删除的数量 +``` + +**示例**: +- 容量上限 `max_memories=30` +- 当前记忆数 `35` → 删除 `35 - 27 = 8` 条最低优先级记忆 +- 优先删除:重要性 0.1 且创建于 10 分钟前的记忆 + +### 持久化 +- 使用 `asyncio.create_task(self._save_to_disk())` 异步保存 +- **不阻塞**消息处理主流程 + +--- + +## 📊 性能影响 + +| 场景 | 开关状态 | 行为 | 适用场景 | +|------|---------|------|---------| +| 高频消息 | ✅ 开启 | 自动泄压,防止阻塞 | 群聊、客服场景 | +| 低频消息 | ❌ 关闭 | 仅依赖自动转移 | 私聊、低活跃群 | +| 调试阶段 | ❌ 关闭 | 便于观察记忆堆积 | 开发测试 | + +**日志示例**(开启后): +``` +[WARNING] 短期记忆压力泄压: 移除 8 条 (当前 27/30) +[WARNING] 短期记忆占用率 100%,已强制删除 8 条低重要性记忆泄压 +``` + +--- + +## 🚨 注意事项 + +### ⚠️ 何时开启 +- ✅ **推荐开启**:高频群聊、客服机器人、24/7 运行场景 +- ❌ **不建议开启**:需要完整保留所有短期记忆、调试阶段 + +### ⚠️ 潜在影响 +- 低重要性记忆可能被删除,**不会转移到长期记忆** +- 如需保留所有记忆,应调大 `max_memories` 或关闭此功能 + +### ⚠️ 与自动转移的协同 +本功能是**兜底机制**,正常情况下: +1. 优先触发自动转移(占用率 ≥ 50%) +2. 高重要性记忆转移到长期层 +3. 仅当转移来不及时,泄压才会触发 + +--- + +## 🔙 回滚与禁用 + +### 临时禁用(无需重启) +```python +# 运行时修改(如果您能访问管理器实例) +unified_manager.short_term_manager.enable_force_cleanup = False +``` + +### 永久禁用 +**配置文件方式**: +```toml +[memory] +short_term_enable_force_cleanup = false # 或直接删除此行 +``` + +**代码方式**: +```python +manager = UnifiedMemoryManager( + short_term_enable_force_cleanup=False, # 显式关闭 +) +``` + +--- + +## 📚 相关文档 + +- [三层记忆系统用户指南](../../docs/three_tier_memory_user_guide.md) +- [记忆图谱架构](../../docs/memory_graph_guide.md) +- [统一调度器指南](../../docs/unified_scheduler_guide.md) + +--- + +**最后更新**:2025年12月16日 diff --git a/src/plugin_system/base/base_interest_calculator.py b/src/plugin_system/base/base_interest_calculator.py index 17ce66c0c..c8192a74f 100644 --- a/src/plugin_system/base/base_interest_calculator.py +++ b/src/plugin_system/base/base_interest_calculator.py @@ -117,10 +117,17 @@ class BaseInterestCalculator(ABC): """ try: self._enabled = True + # 子类可以重写此方法执行自定义初始化 + await self.on_initialize() return True - except Exception: + except Exception as e: + logger.error(f"初始化兴趣计算器失败: {e}") self._enabled = False return False + + async def on_initialize(self): + """子类可重写的初始化钩子""" + pass async def cleanup(self) -> bool: """清理组件资源 @@ -129,10 +136,17 @@ class BaseInterestCalculator(ABC): bool: 清理是否成功 """ try: + # 子类可以重写此方法执行自定义清理 + await self.on_cleanup() self._enabled = False return True - except Exception: + except Exception as e: + logger.error(f"清理兴趣计算器失败: {e}") return False + + async def on_cleanup(self): + """子类可重写的清理钩子""" + pass @property def is_enabled(self) -> bool: