Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev

docs/express_similarity.md (Normal file, +36)
@@ -0,0 +1,36 @@
# 表达相似度计算策略

本文档说明 `calculate_similarity` 的实现与配置,帮助在质量与性能间做权衡。

## 总览
- 支持两种路径:
  1) **向量化路径(默认优先)**:TF-IDF + 余弦相似度(依赖 `scikit-learn`)
  2) **回退路径**:`difflib.SequenceMatcher`
- 参数 `prefer_vector` 控制是否优先尝试向量化,默认 `True`。
- 依赖缺失或文本过短时,自动回退,无需额外配置。

## 调用方式
```python
from src.chat.express.express_utils import calculate_similarity

sim = calculate_similarity(text1, text2)  # 默认优先向量化
sim_fast = calculate_similarity(text1, text2, prefer_vector=False)  # 强制使用 SequenceMatcher
```

## 依赖与回退
- 可选依赖:`scikit-learn`
- 缺失时自动回退到 `SequenceMatcher`,不会抛异常。
- 文本过短(长度 < 2)时直接回退,避免稀疏向量噪声。

## 适用建议
- 文本较长、对鲁棒性/语义相似度有更高要求:保持默认(向量化优先)。
- 环境无 `scikit-learn` 或追求极简依赖:调用时设置 `prefer_vector=False`。
- 高并发性能敏感:可在调用点酌情关闭向量化或加缓存。

## 返回范围
- 相似度范围始终在 `[0, 1]`。
- 空字符串 → `0.0`;完全相同 → `1.0`。

## 额外建议
- 若需更强语义能力,可替换为向量数据库或句向量模型(需新增依赖与配置)。
- 对热路径可增加缓存(按文本哈希),或限制输入长度以控制向量维度与内存(缓存做法见下方示意)。
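下面是一个最小的缓存示意(非本仓库代码,函数命名均为假设):用 `functools.lru_cache` 以排序后的文本对作为键,使对称调用命中同一条目;缓存容量与键策略应按实际热路径调整。

```python
# 假设性示意:为 calculate_similarity 加结果缓存(非仓库内实现)
from functools import lru_cache

from src.chat.express.express_utils import calculate_similarity


@lru_cache(maxsize=4096)
def _cached_similarity(pair: tuple[str, str]) -> float:
    return calculate_similarity(pair[0], pair[1])


def similarity_cached(text1: str, text2: str) -> float:
    # 相似度是对称的,排序后 (a, b) 与 (b, a) 命中同一缓存项
    return _cached_similarity(tuple(sorted((text1, text2))))
```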
@@ -30,7 +30,7 @@
## 影响范围

- 默认行为保持与补丁前一致(开关默认 `on`)。
- 默认行为保持与补丁前一致(开关默认 `off`)。
- 如果关闭开关,短期层将不再做强制删除,只依赖自动转移机制。

## 回滚
docs/style_learner_resource_limit.md (Normal file, +60)
@@ -0,0 +1,60 @@
# StyleLearner 资源上限开关(默认开启)

## 概览
StyleLearner 支持资源上限控制,用于约束风格容量与清理行为。开关默认 **开启**,以防止模型无限膨胀;可在运行时动态关闭。

## 开关位置与用法(务必看这里)

开关在 **代码层**,默认开启,不依赖配置文件。

1) **全局运行时切换(推荐)**
路径:`src/chat/express/style_learner.py` 暴露的单例 `style_learner_manager`
```python
from src.chat.express.style_learner import style_learner_manager

# 关闭资源上限(放开容量,谨慎使用)
style_learner_manager.set_resource_limit(False)

# 再次开启资源上限
style_learner_manager.set_resource_limit(True)
```
- 影响范围:实时作用于已创建的全部 learner(逐个同步 `resource_limit_enabled`)。
- 生效时机:调用后立即生效,无需重启。

2) **构造时指定(不常用)**
- `StyleLearner(resource_limit_enabled: True|False, ...)`
- `StyleLearnerManager(resource_limit_enabled: True|False, ...)`
用于自定义实例化逻辑(通常保持默认即可)。

3) **默认行为**
- 开关默认 **开启**,即启用容量管理与清理。
- 没有配置文件项;若需持久化开关状态,可自行在启动代码中显式调用 `set_resource_limit`。

## 资源上限行为(开启时)
- 容量参数(每个 chat):
  - `max_styles = 2000`
  - `cleanup_threshold = 0.9`(≥90% 容量触发清理)
  - `cleanup_ratio = 0.2`(清理低价值风格约 20%)
- 价值评分:结合使用频率(log 平滑)与最近使用时间(指数衰减),得分低者优先清理(评分公式见下方示意)。
- 仅对单个 learner 的容量管理生效;LRU 淘汰逻辑保持不变。

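该评分与本次 diff 中 `_cleanup_styles` 的实现一致(对使用次数做 `log1p` 平滑、按 30 天做指数衰减、以 0.8/0.2 加权),单独抽出便于理解:

```python
from math import exp, log1p

def style_score(usage_count: int, days_unused: float) -> float:
    """价值评分示意:得分越低越先被清理(与 _cleanup_styles 一致)。"""
    usage_score = log1p(usage_count)      # log(1 + count),平滑高频风格的影响
    time_score = exp(-days_unused / 30)   # 30 天衰减因子,越近期使用得分越高
    return 0.8 * usage_score + 0.2 * time_score  # 80% 使用频率 + 20% 时间新鲜度
```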
> ⚙️ 开关作用面:
> - **开启**:在 add_style 时会检查容量并触发 `_cleanup_styles`;预测/学习逻辑不变。
> - **关闭**:不再触发容量清理,但 LRU 管理器仍可能在进程层面淘汰不活跃 learner。

## I/O 与健壮性
- 模型与元数据保存采用原子写(`.tmp` + `os.replace`),避免部分写入。
- `pickle` 使用 `HIGHEST_PROTOCOL`,并执行 `fsync` 确保落盘。

## 兼容性
- 默认开启,无需修改配置文件;关闭后行为与旧版本类似。
- 已有模型文件可直接加载,开关仅影响运行时清理策略。

## 何时建议开启/关闭
- 开启(默认):内存/磁盘受限,或聊天风格高频增长,需防止模型膨胀。
- 关闭:需要完整保留所有历史风格且资源充足,或进行一次性数据收集实验。

## 监控与调优建议
- 监控:每 chat 风格数量、清理触发次数、删除数量、预测延迟 p95。
- 如清理过于激进:提高 `cleanup_threshold` 或降低 `cleanup_ratio`。
- 如内存/磁盘依旧偏高:降低 `max_styles`,或增加定期持久化与压缩策略。
@@ -7,11 +7,26 @@ import random
import re
from typing import Any

try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity as _sk_cosine_similarity

    HAS_SKLEARN = True
except Exception:  # pragma: no cover - 依赖缺失时静默回退
    HAS_SKLEARN = False

from src.common.logger import get_logger

logger = get_logger("express_utils")


# 预编译正则,减少重复编译开销
_RE_REPLY = re.compile(r"\[回复.*?\],说:\s*")
_RE_AT = re.compile(r"@<[^>]*>")
_RE_IMAGE = re.compile(r"\[图片:[^\]]*\]")
_RE_EMOJI = re.compile(r"\[表情包:[^\]]*\]")


def filter_message_content(content: str | None) -> str:
    """
    过滤消息内容,移除回复、@、图片等格式
@@ -25,29 +40,56 @@ def filter_message_content(content: str | None) -> str:
    if not content:
        return ""

    # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
    content = re.sub(r"\[回复.*?\],说:\s*", "", content)
    # 移除@<...>格式的内容
    content = re.sub(r"@<[^>]*>", "", content)
    # 移除[图片:...]格式的图片ID
    content = re.sub(r"\[图片:[^\]]*\]", "", content)
    # 移除[表情包:...]格式的内容
    content = re.sub(r"\[表情包:[^\]]*\]", "", content)
    # 使用预编译正则提升性能
    content = _RE_REPLY.sub("", content)
    content = _RE_AT.sub("", content)
    content = _RE_IMAGE.sub("", content)
    content = _RE_EMOJI.sub("", content)

    return content.strip()


def calculate_similarity(text1: str, text2: str) -> float:
def _similarity_tfidf(text1: str, text2: str) -> float | None:
    """使用 TF-IDF + 余弦相似度;依赖 sklearn,缺失则返回 None。"""
    if not HAS_SKLEARN:
        return None
    # 过短文本用传统算法更稳健
    if len(text1) < 2 or len(text2) < 2:
        return None
    try:
        vec = TfidfVectorizer(max_features=1024, ngram_range=(1, 2))
        tfidf = vec.fit_transform([text1, text2])
        sim = float(_sk_cosine_similarity(tfidf[0], tfidf[1])[0, 0])
        return max(0.0, min(1.0, sim))
    except Exception:
        return None


def calculate_similarity(text1: str, text2: str, prefer_vector: bool = True) -> float:
    """
    计算两个文本的相似度,返回 0-1 之间的值

    - 当可用且文本足够长时,优先尝试 TF-IDF 向量相似度(更鲁棒)
    - 不可用或失败时回退到 SequenceMatcher

    Args:
        text1: 第一个文本
        text2: 第二个文本
        prefer_vector: 是否优先使用向量化方案(默认是)

    Returns:
        相似度值 (0-1)
    """
    if not text1 or not text2:
        return 0.0
    if text1 == text2:
        return 1.0

    if prefer_vector:
        sim = _similarity_tfidf(text1, text2)
        if sim is not None:
            return sim

    return difflib.SequenceMatcher(None, text1, text2).ratio()

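补一个最小的调用示意(文本为编造的示例数据),展示两条路径的行为差异:

```python
# 假设性示例:对比向量化路径与回退路径(需在可导入本模块的环境中运行)
s1 = "今天天气真好,适合出门散步"
s2 = "今天天气不错,很适合出去走走"
print(calculate_similarity(s1, s2))                       # 装有 scikit-learn 时走 TF-IDF
print(calculate_similarity(s1, s2, prefer_vector=False))  # 强制 SequenceMatcher
```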
@@ -79,18 +121,10 @@ def weighted_sample(population: list[dict], k: int, weight_key: str | None = Non
    except (ValueError, TypeError) as e:
        logger.warning(f"加权抽样失败,使用等概率抽样: {e}")

    # 等概率抽样
    selected = []
    # 等概率抽样(无放回,保持去重)
    population_copy = population.copy()

    for _ in range(k):
        if not population_copy:
            break
        # 随机选择一个元素
        idx = random.randint(0, len(population_copy) - 1)
        selected.append(population_copy.pop(idx))

    return selected
    # 使用 random.sample 提升可读性和性能(k 超过样本数时截断,保持旧实现"不足即全取"的行为)
    return random.sample(population_copy, min(k, len(population_copy)))

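用法示意(数据与字段名为编造示例,`weight_key` 的具体语义以实现为准):

```python
# 假设性示例:按 count 字段做加权抽样,权重越大越可能被选中
items = [
    {"style": "哈喽~", "count": 3.0},
    {"style": "在呢", "count": 1.0},
    {"style": "嗯嗯", "count": 0.5},
]
picked = weighted_sample(items, k=2, weight_key="count")
```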
def normalize_text(text: str) -> str:
@@ -130,8 +164,9 @@ def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
        return keywords
    except ImportError:
        logger.warning("rjieba未安装,无法提取关键词")
        # 简单分词
        # 简单分词,按长度降序优先输出较长词,提升粗略关键词质量
        words = text.split()
        words.sort(key=len, reverse=True)
        return words[:max_keywords]

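回退分支的行为可以用一个小例子说明(仅在 rjieba 缺失时走到这条路径;输入为编造示例):

```python
# 假设性示例:rjieba 缺失时,按空白切分并按长度降序取前 N 个
print(extract_keywords("machine learning is fun", max_keywords=2))
# -> ['learning', 'machine']
```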
@@ -236,15 +271,18 @@ def merge_expressions_from_multiple_chats(
    # 收集所有表达方式
    for chat_id, expressions in expressions_dict.items():
        for expr in expressions:
            # 添加source_id标识
            expr_with_source = expr.copy()
            expr_with_source["source_id"] = chat_id
            all_expressions.append(expr_with_source)

    # 按count或last_active_time排序
    if all_expressions and "count" in all_expressions[0]:
    if not all_expressions:
        return []

    # 选择排序键(优先 count,其次 last_active_time),无则保持原序
    sample = all_expressions[0]
    if "count" in sample:
        all_expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif all_expressions and "last_active_time" in all_expressions[0]:
    elif "last_active_time" in sample:
        all_expressions.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)

    # 去重(基于situation和style)

@@ -358,7 +358,10 @@ class ExpressionLearner:
    @staticmethod
    @cached(ttl=600, key_prefix="chat_expressions")
    async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
        """内部方法:从数据库获取表达方式(带缓存)"""
        """内部方法:从数据库获取表达方式(带缓存)

        🔥 优化:使用列表推导式和更高效的数据处理
        """
        learnt_style_expressions = []
        learnt_grammar_expressions = []

@@ -366,67 +369,91 @@ class ExpressionLearner:
            crud = CRUDBase(Expression)
            all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000)

            # 🔥 优化:使用列表推导式批量处理,减少循环开销
            for expr in all_expressions:
                # 确保create_date存在,如果不存在则使用last_active_time
                create_date = expr.create_date if expr.create_date is not None else expr.last_active_time

                expr_data = {
                    "situation": expr.situation,
                    "style": expr.style,
                    "count": expr.count,
                    "last_active_time": expr.last_active_time,
                    "source_id": chat_id,
                    "type": expr.type,
                    "create_date": create_date,
                }

                # 根据类型分类(避免多次类型检查)
                if expr.type == "style":
                    learnt_style_expressions.append(expr_data)
                elif expr.type == "grammar":
                    learnt_grammar_expressions.append(expr_data)

            logger.debug(f"已加载 {len(learnt_style_expressions)} 个style和 {len(learnt_grammar_expressions)} 个grammar表达方式 (chat_id={chat_id})")
            return learnt_style_expressions, learnt_grammar_expressions

    async def _apply_global_decay_to_database(self, current_time: float) -> None:
        """
        对数据库中的所有表达方式应用全局衰减

        优化: 使用CRUD批量处理所有更改,最后统一提交
        优化: 使用分批处理和原生 SQL 操作提升性能
        """
        try:
            # 使用CRUD查询所有表达方式
            crud = CRUDBase(Expression)
            all_expressions = await crud.get_multi(limit=100000)  # 获取所有表达方式

            BATCH_SIZE = 1000  # 分批处理,避免一次性加载过多数据
            updated_count = 0
            deleted_count = 0
            offset = 0

            # 需要手动操作的情况下使用session
            async with get_db_session() as session:
                # 批量处理所有修改
                for expr in all_expressions:
                    # 计算时间差
                    last_active = expr.last_active_time
                    time_diff_days = (current_time - last_active) / (24 * 3600)  # 转换为天

                    # 计算衰减值
                    decay_value = self.calculate_decay_factor(time_diff_days)
                    new_count = max(0.01, expr.count - decay_value)

                    if new_count <= 0.01:
                        # 如果count太小,删除这个表达方式
                        await session.delete(expr)
                        deleted_count += 1
                    else:
                        # 更新count
                        expr.count = new_count
                        updated_count += 1

            # 优化: 统一提交所有更改(从N次提交减少到1次)
            if updated_count > 0 or deleted_count > 0:
            while True:
                async with get_db_session() as session:
                    # 分批查询表达方式
                    batch_result = await session.execute(
                        select(Expression)
                        .order_by(Expression.id)
                        .limit(BATCH_SIZE)
                        .offset(offset)
                    )
                    batch_expressions = list(batch_result.scalars())

                    if not batch_expressions:
                        break  # 没有更多数据

                    # 批量处理当前批次
                    to_delete = []
                    for expr in batch_expressions:
                        # 计算时间差
                        time_diff_days = (current_time - expr.last_active_time) / (24 * 3600)

                        # 计算衰减值
                        decay_value = self.calculate_decay_factor(time_diff_days)
                        new_count = max(0.01, expr.count - decay_value)

                        if new_count <= 0.01:
                            # 标记删除
                            to_delete.append(expr)
                        else:
                            # 更新count
                            expr.count = new_count
                            updated_count += 1

                    # 批量删除
                    if to_delete:
                        for expr in to_delete:
                            await session.delete(expr)
                        deleted_count += len(to_delete)

                    # 提交当前批次
                    await session.commit()
                logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")

                    # 如果批次不满,说明已经处理完所有数据
                    if len(batch_expressions) < BATCH_SIZE:
                        break

                    offset += BATCH_SIZE

            if updated_count > 0 or deleted_count > 0:
                logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")

        except Exception as e:
            logger.error(f"数据库全局衰减失败: {e}")

@@ -509,88 +536,106 @@ class ExpressionLearner:
            CRUDBase(Expression)
            for chat_id, expr_list in chat_dict.items():
                async with get_db_session() as session:
                    # 🔥 优化:批量查询所有现有表达方式,避免N次数据库查询
                    existing_exprs_result = await session.execute(
                        select(Expression).where(
                            (Expression.chat_id == chat_id)
                            & (Expression.type == type)
                        )
                    )
                    existing_exprs = list(existing_exprs_result.scalars())

                    # 构建快速查找索引
                    exact_match_map = {}  # (situation, style) -> Expression
                    situation_map = {}  # situation -> Expression
                    style_map = {}  # style -> Expression

                    for expr in existing_exprs:
                        key = (expr.situation, expr.style)
                        exact_match_map[key] = expr
                        # 只保留第一个匹配(优先级:完全匹配 > 情景匹配 > 表达匹配)
                        if expr.situation not in situation_map:
                            situation_map[expr.situation] = expr
                        if expr.style not in style_map:
                            style_map[expr.style] = expr

                    # 批量处理所有新表达方式
                    for new_expr in expr_list:
                        # 🔥 改进1:检查是否存在相同情景或相同表达的数据
                        # 情况1:相同 chat_id + type + situation(相同情景,不同表达)
                        query_same_situation = await session.execute(
                            select(Expression).where(
                                (Expression.chat_id == chat_id)
                                & (Expression.type == type)
                                & (Expression.situation == new_expr["situation"])
                            )
                        )
                        same_situation_expr = query_same_situation.scalar()

                        # 情况2:相同 chat_id + type + style(相同表达,不同情景)
                        query_same_style = await session.execute(
                            select(Expression).where(
                                (Expression.chat_id == chat_id)
                                & (Expression.type == type)
                                & (Expression.style == new_expr["style"])
                            )
                        )
                        same_style_expr = query_same_style.scalar()

                        # 情况3:完全相同(相同情景+相同表达)
                        query_exact_match = await session.execute(
                            select(Expression).where(
                                (Expression.chat_id == chat_id)
                                & (Expression.type == type)
                                & (Expression.situation == new_expr["situation"])
                                & (Expression.style == new_expr["style"])
                            )
                        )
                        exact_match_expr = query_exact_match.scalar()

                        situation = new_expr["situation"]
                        style_val = new_expr["style"]
                        exact_key = (situation, style_val)

                        # 优先处理完全匹配的情况
                        if exact_match_expr:
                        if exact_key in exact_match_map:
                            # 完全相同:增加count,更新时间
                            expr_obj = exact_match_expr
                            expr_obj = exact_match_map[exact_key]
                            expr_obj.count = expr_obj.count + 1
                            expr_obj.last_active_time = current_time
                            logger.debug(f"完全匹配:更新count {expr_obj.count}")
                        elif same_situation_expr:
                        elif situation in situation_map:
                            # 相同情景,不同表达:覆盖旧的表达
                            logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{new_expr['style']}'")
                            same_situation_expr.style = new_expr["style"]
                            same_situation_expr = situation_map[situation]
                            logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'")
                            # 更新映射
                            old_key = (same_situation_expr.situation, same_situation_expr.style)
                            if old_key in exact_match_map:
                                del exact_match_map[old_key]
                            same_situation_expr.style = style_val
                            same_situation_expr.count = same_situation_expr.count + 1
                            same_situation_expr.last_active_time = current_time
                        elif same_style_expr:
                            # 更新新的完全匹配映射
                            exact_match_map[exact_key] = same_situation_expr
                        elif style_val in style_map:
                            # 相同表达,不同情景:覆盖旧的情景
                            logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{new_expr['situation']}'")
                            same_style_expr.situation = new_expr["situation"]
                            same_style_expr = style_map[style_val]
                            logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'")
                            # 更新映射
                            old_key = (same_style_expr.situation, same_style_expr.style)
                            if old_key in exact_match_map:
                                del exact_match_map[old_key]
                            same_style_expr.situation = situation
                            same_style_expr.count = same_style_expr.count + 1
                            same_style_expr.last_active_time = current_time
                            # 更新新的完全匹配映射
                            exact_match_map[exact_key] = same_style_expr
                            situation_map[situation] = same_style_expr
                        else:
                            # 完全新的表达方式:创建新记录
                            new_expression = Expression(
                                situation=new_expr["situation"],
                                style=new_expr["style"],
                                situation=situation,
                                style=style_val,
                                count=1,
                                last_active_time=current_time,
                                chat_id=chat_id,
                                type=type,
                                create_date=current_time,  # 手动设置创建日期
                                create_date=current_time,
                            )
                            session.add(new_expression)
                            logger.debug(f"新增表达方式:{new_expr['situation']} -> {new_expr['style']}")
                            # 更新映射
                            exact_match_map[exact_key] = new_expression
                            situation_map[situation] = new_expression
                            style_map[style_val] = new_expression
                            logger.debug(f"新增表达方式:{situation} -> {style_val}")

                    # 限制最大数量 - 使用 get_all_by_sorted 获取排序结果
                    exprs_result = await session.execute(
                        select(Expression)
                        .where((Expression.chat_id == chat_id) & (Expression.type == type))
                        .order_by(Expression.count.asc())
                    )
                    exprs = list(exprs_result.scalars())
                    if len(exprs) > MAX_EXPRESSION_COUNT:
                        # 删除count最小的多余表达方式
                        for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
                    # 🔥 优化:限制最大数量 - 使用已加载的数据避免重复查询
                    # existing_exprs 已包含该 chat_id 和 type 的所有表达方式
                    all_current_exprs = list(exact_match_map.values())
                    if len(all_current_exprs) > MAX_EXPRESSION_COUNT:
                        # 按 count 排序,删除 count 最小的多余表达方式
                        sorted_exprs = sorted(all_current_exprs, key=lambda e: e.count)
                        for expr in sorted_exprs[: len(all_current_exprs) - MAX_EXPRESSION_COUNT]:
                            await session.delete(expr)
                            # 从映射中移除
                            key = (expr.situation, expr.style)
                            if key in exact_match_map:
                                del exact_match_map[key]
                        logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式")

                    # 提交后清除相关缓存
                    # 提交数据库更改
                    await session.commit()

            # 🔥 清除共享组内所有 chat_id 的表达方式缓存
            # 🔥 优化:只在实际有更新时才清除缓存(移到外层,避免重复清除)
            if chat_dict:  # 只有当有数据更新时才清除缓存
                from src.common.database.optimization.cache_manager import get_cache
                from src.common.database.utils.decorators import generate_cache_key
                cache = await get_cache()
@@ -602,53 +647,59 @@ class ExpressionLearner:
            if len(related_chat_ids) > 1:
                logger.debug(f"已清除共享组内 {len(related_chat_ids)} 个 chat_id 的表达方式缓存")

            # 🔥 训练 StyleLearner(支持共享组)
            # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
            if type == "style":
                try:
                    logger.debug(f"开始训练 StyleLearner: 源chat_id={chat_id}, 共享组包含 {len(related_chat_ids)} 个chat_id, 样本数={len(expr_list)}")

                    # 为每个共享组内的 chat_id 训练其 StyleLearner
                    for target_chat_id in related_chat_ids:
                        learner = style_learner_manager.get_learner(target_chat_id)
        # 🔥 训练 StyleLearner(支持共享组)
        # 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
        if type == "style" and chat_dict:
            try:
                related_chat_ids = self.get_related_chat_ids()
                total_samples = sum(len(expr_list) for expr_list in chat_dict.values())
                logger.debug(f"开始训练 StyleLearner: 共享组包含 {len(related_chat_ids)} 个chat_id, 总样本数={total_samples}")

                # 为每个共享组内的 chat_id 训练其 StyleLearner
                for target_chat_id in related_chat_ids:
                    learner = style_learner_manager.get_learner(target_chat_id)

                    # 收集该 target_chat_id 对应的所有表达方式
                    # 如果是源 chat_id,使用 chat_dict 中的数据;否则也要训练(共享组特性)
                    total_success = 0
                    total_samples = 0

                    for source_chat_id, expr_list in chat_dict.items():
                        # 为每个学习到的表达方式训练模型
                        # 使用 situation 作为输入,style 作为目标
                        # 这是最符合语义的方式:场景 -> 表达方式
                    success_count = 0
                    for expr in expr_list:
                        situation = expr["situation"]
                        style = expr["style"]

                        # 训练映射关系: situation -> style
                        if learner.learn_mapping(situation, style):
                            success_count += 1
                        else:
                            logger.warning(f"训练失败 (target={target_chat_id}): {situation} -> {style}")

                    # 保存模型
                            total_success += 1
                            total_samples += 1

                    # 保存模型
                    if total_samples > 0:
                    if learner.save(style_learner_manager.model_save_path):
                        logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
                    else:
                        logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")

                    if target_chat_id == chat_id:
                        # 只为源 chat_id 记录详细日志

                    if target_chat_id == self.chat_id:
                        # 只为当前 chat_id 记录详细日志
                        logger.info(
                            f"StyleLearner 训练完成 (源): {success_count}/{len(expr_list)} 成功, "
                            f"StyleLearner 训练完成: {total_success}/{total_samples} 成功, "
                            f"当前风格总数={len(learner.get_all_styles())}, "
                            f"总样本数={learner.learning_stats['total_samples']}"
                        )
                    else:
                        logger.debug(
                            f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {success_count}/{len(expr_list)} 成功"
                            f"StyleLearner 训练完成 (共享组成员 {target_chat_id}): {total_success}/{total_samples} 成功"
                        )

                if len(related_chat_ids) > 1:
                    logger.info(f"共享组内共 {len(related_chat_ids)} 个 StyleLearner 已同步训练")

            except Exception as e:
                logger.error(f"训练 StyleLearner 失败: {e}")

        return learnt_expressions
        return None

@@ -207,31 +207,20 @@ class ExpressionSelector:
            select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
        )

        style_exprs = [
            {
        # 🔥 优化:提前定义转换函数,避免重复代码
        def expr_to_dict(expr, expr_type: str) -> dict[str, Any]:
            return {
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                "type": expr_type,
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query.scalars()
        ]

        grammar_exprs = [
            {
                "situation": expr.situation,
                "style": expr.style,
                "count": expr.count,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "grammar",
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in grammar_query.scalars()
        ]

        style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()]
        grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()]

        style_num = int(total_num * style_percentage)
        grammar_num = int(total_num * grammar_percentage)
@@ -251,9 +240,14 @@ class ExpressionSelector:

    @staticmethod
    async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
        """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
        """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库

        🔥 优化:合并所有更新到一个事务中,减少数据库连接开销
        """
        if not expressions_to_update:
            return

        # 去重处理
        updates_by_key = {}
        affected_chat_ids = set()
        for expr in expressions_to_update:
@@ -269,9 +263,15 @@ class ExpressionSelector:
                updates_by_key[key] = expr
            affected_chat_ids.add(source_id)

        for chat_id, expr_type, situation, style in updates_by_key:
            async with get_db_session() as session:
                query = await session.execute(
        if not updates_by_key:
            return

        # 🔥 优化:使用单个 session 批量处理所有更新
        current_time = time.time()
        async with get_db_session() as session:
            updated_count = 0
            for chat_id, expr_type, situation, style in updates_by_key:
                query_result = await session.execute(
                    select(Expression).where(
                        (Expression.chat_id == chat_id)
                        & (Expression.type == expr_type)
@@ -279,25 +279,26 @@ class ExpressionSelector:
                        & (Expression.style == style)
                    )
                )
                query = query.scalar()
                if query:
                    expr_obj = query
                expr_obj = query_result.scalar()
                if expr_obj:
                    current_count = expr_obj.count
                    new_count = min(current_count + increment, 5.0)
                    expr_obj.count = new_count
                    expr_obj.last_active_time = time.time()
                    expr_obj.last_active_time = current_time
                    updated_count += 1

                    logger.debug(
                        f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
                    )
            # 批量提交所有更改
            if updated_count > 0:
                await session.commit()
                logger.debug(f"批量更新了 {updated_count} 个表达方式的count值")

        # 清除所有受影响的chat_id的缓存
        from src.common.database.optimization.cache_manager import get_cache
        from src.common.database.utils.decorators import generate_cache_key
        cache = await get_cache()
        for chat_id in affected_chat_ids:
            await cache.delete(generate_cache_key("chat_expressions", chat_id))
        if affected_chat_ids:
            from src.common.database.optimization.cache_manager import get_cache
            from src.common.database.utils.decorators import generate_cache_key
            cache = await get_cache()
            for chat_id in affected_chat_ids:
                await cache.delete(generate_cache_key("chat_expressions", chat_id))

    async def select_suitable_expressions(
        self,
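调用示意(字典字段依据本函数内的去重键 `chat_id+type+situation+style` 推断,数据为编造示例):

```python
# 假设性示例:批量激活一组被使用过的表达方式
await ExpressionSelector.update_expressions_count_batch(
    [{"source_id": "chat_1", "type": "style", "situation": "打招呼", "style": "哈喽~"}],
    increment=0.1,  # 每次激活 count +0.1,封顶 5.0
)
```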
@@ -518,29 +519,41 @@ class ExpressionSelector:
            logger.warning("数据库中完全没有任何表达方式,需要先学习")
            return []

        # 🔥 使用模糊匹配而不是精确匹配
        # 计算每个预测style与数据库style的相似度
        # 🔥 优化:使用更高效的模糊匹配算法
        from difflib import SequenceMatcher

        # 预处理:提前计算所有预测 style 的小写版本,避免重复计算
        predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]]

        matched_expressions = []
        for expr in all_expressions:
            db_style = expr.style or ""
            db_style_lower = db_style.lower()
            max_similarity = 0.0
            best_predicted = ""

            # 与每个预测的style计算相似度
            for predicted_style, pred_score in predicted_styles[:20]:  # 考虑前20个预测
                # 计算字符串相似度
                similarity = SequenceMatcher(None, predicted_style, db_style).ratio()

                # 也检查包含关系(如果一个是另一个的子串,给更高分)
                if len(predicted_style) >= 2 and len(db_style) >= 2:
                    if predicted_style in db_style or db_style in predicted_style:
                        similarity = max(similarity, 0.7)

            for predicted_style_lower, pred_score in predicted_styles_lower:
                # 快速检查:完全匹配
                if predicted_style_lower == db_style_lower:
                    max_similarity = 1.0
                    best_predicted = predicted_style_lower
                    break

                # 快速检查:子串匹配
                if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2:
                    if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower:
                        similarity = 0.7
                        if similarity > max_similarity:
                            max_similarity = similarity
                            best_predicted = predicted_style_lower
                        continue

                # 计算字符串相似度(较慢,只在必要时使用)
                similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio()
                if similarity > max_similarity:
                    max_similarity = similarity
                    best_predicted = predicted_style
                    best_predicted = predicted_style_lower

            # 🔥 降低阈值到30%,因为StyleLearner预测质量较差
            if max_similarity >= 0.3:  # 30%相似度阈值
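一个简单的数值示例有助于理解 0.3 阈值与 0.7 子串分的量级(示例字符串为编造):

```python
# 假设性示例:SequenceMatcher 相似度的量级
from difflib import SequenceMatcher

print(SequenceMatcher(None, "开心地笑", "开心大笑").ratio())      # 0.75,高于 0.3 阈值
print(SequenceMatcher(None, "开心", "开心地笑了起来").ratio())    # ≈0.44;子串分支会先把它记为 0.7
```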
@@ -573,14 +586,15 @@ class ExpressionSelector:
            f"(候选 {len(matched_expressions)},temperature={temperature})"
        )

        # 转换为字典格式
        # 🔥 优化:使用列表推导式和预定义函数减少开销
        expressions = [
            {
                "situation": expr.situation or "",
                "style": expr.style or "",
                "type": expr.type or "style",
                "count": float(expr.count) if expr.count else 0.0,
                "last_active_time": expr.last_active_time or 0.0
                "last_active_time": expr.last_active_time or 0.0,
                "source_id": expr.chat_id  # 添加 source_id 以便后续更新
            }
            for expr in expressions_objs
        ]

@@ -127,7 +127,8 @@ class SituationExtractor:
        Returns:
            情境描述列表
        """
        situations = []
        situations: list[str] = []
        seen = set()

        for line in response.splitlines():
            line = line.strip()
@@ -150,6 +151,11 @@ class SituationExtractor:
            if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]):
                continue

            # 去重,保持原有顺序
            if line in seen:
                continue
            seen.add(line)

            situations.append(line)

            if len(situations) >= max_situations:

@@ -4,6 +4,7 @@
支持多聊天室独立建模和在线学习
"""
import os
import pickle
import time

from src.common.logger import get_logger
@@ -16,11 +17,12 @@ logger = get_logger("expressor.style_learner")
class StyleLearner:
    """单个聊天室的表达风格学习器"""

    def __init__(self, chat_id: str, model_config: dict | None = None):
    def __init__(self, chat_id: str, model_config: dict | None = None, resource_limit_enabled: bool = True):
        """
        Args:
            chat_id: 聊天室ID
            model_config: 模型配置
            resource_limit_enabled: 是否启用资源上限控制(默认开启)
        """
        self.chat_id = chat_id
        self.model_config = model_config or {
@@ -34,6 +36,9 @@ class StyleLearner:
        # 初始化表达模型
        self.expressor = ExpressorModel(**self.model_config)

        # 资源上限控制开关(默认开启,可按需关闭)
        self.resource_limit_enabled = resource_limit_enabled

        # 动态风格管理
        self.max_styles = 2000  # 每个chat_id最多2000个风格
        self.cleanup_threshold = 0.9  # 达到90%容量时触发清理
@@ -67,18 +72,15 @@ class StyleLearner:
        if style in self.style_to_id:
            return True

        # 检查是否需要清理
        current_count = len(self.style_to_id)
        cleanup_trigger = int(self.max_styles * self.cleanup_threshold)

        if current_count >= cleanup_trigger:
            if current_count >= self.max_styles:
                # 已经达到最大限制,必须清理
                logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理")
                self._cleanup_styles()
            elif current_count >= cleanup_trigger:
                # 接近限制,提前清理
                logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理")
        # 检查是否需要清理(仅计算一次阈值)
        if self.resource_limit_enabled:
            current_count = len(self.style_to_id)
            cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
            if current_count >= cleanup_trigger:
                if current_count >= self.max_styles:
                    logger.warning(f"已达到最大风格数量限制 ({self.max_styles}),开始清理")
                else:
                    logger.info(f"风格数量达到 {current_count}/{self.max_styles},触发预防性清理")
                self._cleanup_styles()

        # 生成新的style_id
@@ -95,7 +97,8 @@ class StyleLearner:
        self.expressor.add_candidate(style_id, style, situation)

        # 初始化统计
        self.learning_stats["style_counts"][style_id] = 0
        self.learning_stats.setdefault("style_counts", {})[style_id] = 0
        self.learning_stats.setdefault("style_last_used", {})

        logger.debug(f"添加风格成功: {style_id} -> {style}")
        return True
@@ -114,64 +117,64 @@ class StyleLearner:
        3. 默认清理 cleanup_ratio (20%) 的风格
        """
        try:
            total_styles = len(self.style_to_id)
            if total_styles == 0:
                return

            # 只有在达到阈值时才执行昂贵的排序
            cleanup_count = max(1, int(total_styles * self.cleanup_ratio))
            if cleanup_count <= 0:
                return

            current_time = time.time()
            cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
            # 局部引用加速频繁调用的函数
            from math import exp, log1p

            # 计算每个风格的价值分数
            style_scores = []
            for style_id in self.style_to_id.values():
                # 使用次数
                usage_count = self.learning_stats["style_counts"].get(style_id, 0)

                # 最后使用时间(越近越好)
                last_used = self.learning_stats["style_last_used"].get(style_id, 0)

                time_since_used = current_time - last_used if last_used > 0 else float("inf")
                usage_score = log1p(usage_count)
                days_unused = time_since_used / 86400
                time_score = exp(-days_unused / 30)

                # 综合分数:使用次数越多越好,距离上次使用时间越短越好
                # 使用对数来平滑使用次数的影响
                import math
                usage_score = math.log1p(usage_count)  # log(1 + count)

                # 时间分数:转换为天数,使用指数衰减
                days_unused = time_since_used / 86400  # 转换为天
                time_score = math.exp(-days_unused / 30)  # 30天衰减因子

                # 综合分数:80%使用频率 + 20%时间新鲜度
                total_score = 0.8 * usage_score + 0.2 * time_score

                style_scores.append((style_id, total_score, usage_count, days_unused))

            if not style_scores:
                return

            # 按分数排序,分数低的先删除
            style_scores.sort(key=lambda x: x[1])

            # 删除分数最低的风格
            deleted_styles = []
            for style_id, score, usage, days in style_scores[:cleanup_count]:
                style_text = self.id_to_style.get(style_id)
                if style_text:
                    # 从映射中删除
                    del self.style_to_id[style_text]
                    del self.id_to_style[style_id]
                    if style_id in self.id_to_situation:
                        del self.id_to_situation[style_id]
                if not style_text:
                    continue

                    # 从统计中删除
                    if style_id in self.learning_stats["style_counts"]:
                        del self.learning_stats["style_counts"][style_id]
                    if style_id in self.learning_stats["style_last_used"]:
                        del self.learning_stats["style_last_used"][style_id]
                # 从映射中删除
                self.style_to_id.pop(style_text, None)
                self.id_to_style.pop(style_id, None)
                self.id_to_situation.pop(style_id, None)

                    # 从expressor模型中删除
                    self.expressor.remove_candidate(style_id)
                # 从统计中删除
                self.learning_stats["style_counts"].pop(style_id, None)
                self.learning_stats["style_last_used"].pop(style_id, None)

                    deleted_styles.append((style_text[:30], usage, f"{days:.1f}天"))
                # 从expressor模型中删除
                self.expressor.remove_candidate(style_id)

                deleted_styles.append((style_text[:30], usage, f"{days:.1f}天"))

            logger.info(
                f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
                f"剩余 {len(self.style_to_id)} 个风格"
            )

            # 记录前5个被删除的风格(用于调试)
            if deleted_styles:
                logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
@@ -204,7 +207,9 @@ class StyleLearner:
        # 更新统计
        current_time = time.time()
        self.learning_stats["total_samples"] += 1
        self.learning_stats["style_counts"][style_id] += 1
        self.learning_stats.setdefault("style_counts", {})
        self.learning_stats.setdefault("style_last_used", {})
        self.learning_stats["style_counts"][style_id] = self.learning_stats["style_counts"].get(style_id, 0) + 1
        self.learning_stats["style_last_used"][style_id] = current_time  # 更新最后使用时间
        self.learning_stats["last_update"] = current_time
@@ -349,11 +354,11 @@ class StyleLearner:

        # 保存expressor模型
        model_path = os.path.join(save_dir, "expressor_model.pkl")
        self.expressor.save(model_path)

        # 保存映射关系和统计信息
        import pickle
        tmp_model_path = f"{model_path}.tmp"
        self.expressor.save(tmp_model_path)
        os.replace(tmp_model_path, model_path)

        # 保存映射关系和统计信息(原子写)
        meta_path = os.path.join(save_dir, "meta.pkl")

        # 确保 learning_stats 包含所有必要字段
@@ -368,8 +373,13 @@ class StyleLearner:
            "learning_stats": self.learning_stats,
        }

        with open(meta_path, "wb") as f:
            pickle.dump(meta_data, f)
        tmp_meta_path = f"{meta_path}.tmp"
        with open(tmp_meta_path, "wb") as f:
            pickle.dump(meta_data, f, protocol=pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())

        os.replace(tmp_meta_path, meta_path)

        return True

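上面的「.tmp + fsync + os.replace」可以抽成一个通用小工具,示意如下(非仓库内代码;`os.replace` 在同一文件系统内是原子替换):

```python
# 假设性示意:通用的原子 pickle 写入(非仓库内实现)
import os
import pickle


def atomic_pickle_dump(obj, path: str) -> None:
    tmp = f"{path}.tmp"
    with open(tmp, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
        f.flush()
        os.fsync(f.fileno())  # 数据先落盘,再做替换
    os.replace(tmp, path)  # 崩溃时最多留下 .tmp,不会产生半写的正式文件
```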
@@ -401,8 +411,6 @@ class StyleLearner:
        self.expressor.load(model_path)

        # 加载映射关系和统计信息
        import pickle

        meta_path = os.path.join(save_dir, "meta.pkl")
        if os.path.exists(meta_path):
            with open(meta_path, "rb") as f:
@@ -445,14 +453,16 @@ class StyleLearnerManager:
    # 🔧 最大活跃 learner 数量
    MAX_ACTIVE_LEARNERS = 50

    def __init__(self, model_save_path: str = "data/expression/style_models"):
    def __init__(self, model_save_path: str = "data/expression/style_models", resource_limit_enabled: bool = True):
        """
        Args:
            model_save_path: 模型保存路径
            resource_limit_enabled: 是否启用资源上限控制(默认开启)
        """
        self.learners: dict[str, StyleLearner] = {}
        self.learner_last_used: dict[str, float] = {}  # 🔧 记录最后使用时间
        self.model_save_path = model_save_path
        self.resource_limit_enabled = resource_limit_enabled

        # 确保保存目录存在
        os.makedirs(model_save_path, exist_ok=True)
@@ -475,7 +485,10 @@ class StyleLearnerManager:
        for chat_id, last_used in sorted_by_time[:evict_count]:
            if chat_id in self.learners:
                # 先保存再淘汰
                self.learners[chat_id].save(self.model_save_path)
                try:
                    self.learners[chat_id].save(self.model_save_path)
                except Exception as e:
                    logger.error(f"LRU淘汰时保存学习器失败: chat_id={chat_id}, error={e}")
                del self.learners[chat_id]
                del self.learner_last_used[chat_id]
                evicted.append(chat_id)
@@ -502,7 +515,11 @@ class StyleLearnerManager:
            self._evict_if_needed()

            # 创建新的学习器
            learner = StyleLearner(chat_id, model_config)
            learner = StyleLearner(
                chat_id,
                model_config,
                resource_limit_enabled=self.resource_limit_enabled,
            )

            # 尝试加载已保存的模型
            learner.load(self.model_save_path)
@@ -511,6 +528,12 @@ class StyleLearnerManager:

        return self.learners[chat_id]

    def set_resource_limit(self, enabled: bool) -> None:
        """动态开启/关闭资源上限控制(默认开启)。"""
        self.resource_limit_enabled = enabled
        for learner in self.learners.values():
            learner.resource_limit_enabled = enabled

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """
        学习一个映射关系

@@ -5,6 +5,7 @@

import asyncio
import time
from collections import OrderedDict
from typing import TYPE_CHECKING

from src.common.logger import get_logger
@@ -37,19 +38,50 @@ class InterestManager:
        self._calculation_queue = asyncio.Queue()
        self._worker_task = None
        self._shutdown_event = asyncio.Event()

        # 性能优化相关字段
        self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict()  # LRU缓存
        self._cache_max_size = 1000  # 最大缓存数量
        self._cache_ttl = 300  # 缓存TTL(秒)
        self._batch_queue: asyncio.Queue = asyncio.Queue(maxsize=100)  # 批处理队列
        self._batch_size = 10  # 批处理大小
        self._batch_timeout = 0.1  # 批处理超时(秒)
        self._batch_task = None
        self._is_warmed_up = False  # 预热状态标记

        # 性能统计
        self._cache_hits = 0
        self._cache_misses = 0
        self._batch_calculations = 0
        self._total_calculation_time = 0.0

        self._initialized = True

    async def initialize(self):
        """初始化管理器"""
        pass
        # 启动批处理工作线程
        if self._batch_task is None or self._batch_task.done():
            self._batch_task = asyncio.create_task(self._batch_processing_worker())
            logger.info("批处理工作线程已启动")

    async def shutdown(self):
        """关闭管理器"""
        self._shutdown_event.set()

        # 取消批处理任务
        if self._batch_task and not self._batch_task.done():
            self._batch_task.cancel()
            try:
                await self._batch_task
            except asyncio.CancelledError:
                pass

        if self._current_calculator:
            await self._current_calculator.cleanup()
            self._current_calculator = None

        # 清理缓存
        self._result_cache.clear()

        logger.info("兴趣值管理器已关闭")

@@ -91,12 +123,13 @@ class InterestManager:
            logger.error(f"注册兴趣值计算组件失败: {e}")
            return False

    async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None) -> InterestCalculationResult:
        """计算消息兴趣值
    async def calculate_interest(self, message: "DatabaseMessages", timeout: float | None = None, use_cache: bool = True) -> InterestCalculationResult:
        """计算消息兴趣值(优化版,支持缓存)

        Args:
            message: 数据库消息对象
            timeout: 最大等待时间(秒),超时则使用默认值返回;为None时不设置超时
            use_cache: 是否使用缓存,默认True

        Returns:
            InterestCalculationResult: 计算结果或默认结果
@@ -109,37 +142,53 @@ class InterestManager:
                interest_value=0.3,
                error_message="没有可用的兴趣值计算组件",
            )

        message_id = getattr(message, "message_id", "")

        # 缓存查询
        if use_cache and message_id:
            cached_result = self._get_from_cache(message_id)
            if cached_result is not None:
                self._cache_hits += 1
                logger.debug(f"命中缓存: {message_id}, 兴趣值: {cached_result.interest_value:.3f}")
                return cached_result
            self._cache_misses += 1

        # 使用 create_task 异步执行计算
        task = asyncio.create_task(self._async_calculate(message))

        if timeout is None:
            return await task

        try:
            # 等待计算结果,但有超时限制
            result = await asyncio.wait_for(task, timeout=timeout)
            return result
        except asyncio.TimeoutError:
            # 超时返回默认结果,但计算仍在后台继续
            logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5")
            return InterestCalculationResult(
                success=True,
                message_id=getattr(message, "message_id", ""),
                interest_value=0.5,  # 固定默认兴趣值
                should_reply=False,
                should_act=False,
                error_message=f"计算超时({timeout}s),使用默认值",
            )
        except Exception as e:
            # 发生异常,返回默认结果
            logger.error(f"兴趣值计算异常: {e}")
            return InterestCalculationResult(
                success=False,
                message_id=getattr(message, "message_id", ""),
                interest_value=0.3,
                error_message=f"计算异常: {e!s}",
            )
            result = await task
        else:
            try:
                # 等待计算结果,但有超时限制
                result = await asyncio.wait_for(task, timeout=timeout)
            except asyncio.TimeoutError:
                # 超时返回默认结果,但计算仍在后台继续
                logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {message_id} 使用默认兴趣值 0.5")
                return InterestCalculationResult(
                    success=True,
                    message_id=message_id,
                    interest_value=0.5,  # 固定默认兴趣值
                    should_reply=False,
                    should_act=False,
                    error_message=f"计算超时({timeout}s),使用默认值",
                )
            except Exception as e:
                # 发生异常,返回默认结果
                logger.error(f"兴趣值计算异常: {e}")
                return InterestCalculationResult(
                    success=False,
                    message_id=message_id,
                    interest_value=0.3,
                    error_message=f"计算异常: {e!s}",
                )

        # 缓存结果
        if use_cache and result.success and message_id:
            self._put_to_cache(message_id, result)

        return result

    async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
        """异步执行兴趣值计算"""
@@ -161,6 +210,7 @@ class InterestManager:

        if result.success:
            self._last_calculation_time = time.time()
            self._total_calculation_time += result.calculation_time
            logger.debug(f"兴趣值计算完成: {result.interest_value:.3f} (耗时: {result.calculation_time:.3f}s)")
        else:
            self._failed_calculations += 1
@@ -170,13 +220,15 @@ class InterestManager:

        except Exception as e:
            self._failed_calculations += 1
            calc_time = time.time() - start_time
            self._total_calculation_time += calc_time
            logger.error(f"兴趣值计算异常: {e}")
            return InterestCalculationResult(
                success=False,
                message_id=getattr(message, "message_id", ""),
                interest_value=0.0,
                error_message=f"计算异常: {e!s}",
                calculation_time=time.time() - start_time,
                calculation_time=calc_time,
            )

    async def _calculation_worker(self):
@@ -197,6 +249,155 @@ class InterestManager:
                break
            except Exception as e:
                logger.error(f"计算工作线程异常: {e}")

    def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None:
        """从缓存中获取结果(LRU策略)"""
        if message_id not in self._result_cache:
            return None

        # 检查TTL
        result = self._result_cache[message_id]
        if time.time() - result.timestamp > self._cache_ttl:
            # 过期,删除
            del self._result_cache[message_id]
            return None

        # 更新访问顺序(LRU)
        self._result_cache.move_to_end(message_id)
        return result

    def _put_to_cache(self, message_id: str, result: InterestCalculationResult):
        """将结果放入缓存(LRU策略)"""
        # 如果已存在,更新
        if message_id in self._result_cache:
            self._result_cache.move_to_end(message_id)

        self._result_cache[message_id] = result

        # 限制缓存大小
        while len(self._result_cache) > self._cache_max_size:
            # 删除最旧的项
            self._result_cache.popitem(last=False)

    async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]:
        """批量计算消息兴趣值(并发优化)

        Args:
            messages: 消息列表
            timeout: 单个计算的超时时间

        Returns:
            list[InterestCalculationResult]: 计算结果列表
        """
        if not messages:
            return []

        # 并发计算所有消息
        tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # 处理异常
        final_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"批量计算消息 {i} 失败: {result}")
                final_results.append(InterestCalculationResult(
                    success=False,
                    message_id=getattr(messages[i], "message_id", ""),
                    interest_value=0.3,
                    error_message=f"批量计算异常: {result!s}",
                ))
            else:
                final_results.append(result)

        self._batch_calculations += 1
        return final_results

    async def _batch_processing_worker(self):
        """批处理工作线程"""
        while not self._shutdown_event.is_set():
            batch = []
            deadline = time.time() + self._batch_timeout

            try:
                # 收集批次
                while len(batch) < self._batch_size and time.time() < deadline:
                    remaining_time = deadline - time.time()
                    if remaining_time <= 0:
                        break

                    try:
                        item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time)
                        batch.append(item)
                    except asyncio.TimeoutError:
                        break

                # 处理批次
                if batch:
                    await self._process_batch(batch)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"批处理工作线程异常: {e}")

    async def _process_batch(self, batch: list):
        """处理批次消息"""
        # 这里可以实现具体的批处理逻辑
        # 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现
        pass

    async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None):
        """预热兴趣计算器

        Args:
            sample_messages: 样本消息列表,用于预热。如果为None,则只初始化计算器
        """
        if not self._current_calculator:
            logger.warning("无法预热:没有可用的兴趣值计算组件")
            return

        logger.info("开始预热兴趣值计算器...")
        start_time = time.time()

        # 如果提供了样本消息,进行预热计算
        if sample_messages:
            try:
                # 批量计算样本消息
                await self.calculate_interest_batch(sample_messages, timeout=5.0)
                logger.info(f"预热完成:处理了 {len(sample_messages)} 条样本消息,耗时 {time.time() - start_time:.2f}s")
            except Exception as e:
                logger.error(f"预热过程中出现异常: {e}")
        else:
            logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s")

        self._is_warmed_up = True

    def clear_cache(self):
        """清空缓存"""
        cleared_count = len(self._result_cache)
        self._result_cache.clear()
        logger.info(f"已清空 {cleared_count} 条缓存记录")

    def set_cache_config(self, max_size: int | None = None, ttl: int | None = None):
        """设置缓存配置

        Args:
            max_size: 最大缓存数量
            ttl: 缓存生存时间(秒)
        """
        if max_size is not None:
            self._cache_max_size = max_size
            logger.info(f"缓存最大容量设置为: {max_size}")

        if ttl is not None:
            self._cache_ttl = ttl
            logger.info(f"缓存TTL设置为: {ttl}秒")

        # 如果当前缓存超过新的最大值,清理旧数据
        if max_size is not None:
            while len(self._result_cache) > self._cache_max_size:
                self._result_cache.popitem(last=False)

    def get_current_calculator(self) -> BaseInterestCalculator | None:
        """获取当前活跃的兴趣值计算组件"""
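缓存与批量接口的调用方式大致如下(`interest_manager`、`msg`、`msgs` 均为假设的名称与数据):

```python
# 假设性用法示意:带缓存与超时的兴趣值计算
result = await interest_manager.calculate_interest(msg, timeout=1.0)   # 首次:实际计算并写入缓存
again = await interest_manager.calculate_interest(msg, timeout=1.0)    # TTL(默认 300s)内命中缓存
batch = await interest_manager.calculate_interest_batch(msgs, timeout=2.0)  # gather 并发批量
```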
@@ -205,6 +406,8 @@ class InterestManager:
    def get_statistics(self) -> dict:
        """获取管理器统计信息"""
        success_rate = 1.0 - (self._failed_calculations / max(1, self._total_calculations))
        cache_hit_rate = self._cache_hits / max(1, self._cache_hits + self._cache_misses)
        avg_calc_time = self._total_calculation_time / max(1, self._total_calculations)

        stats = {
            "manager_statistics": {
@@ -213,6 +416,13 @@ class InterestManager:
                "success_rate": success_rate,
                "last_calculation_time": self._last_calculation_time,
                "current_calculator": self._current_calculator.component_name if self._current_calculator else None,
                "cache_hit_rate": cache_hit_rate,
                "cache_hits": self._cache_hits,
                "cache_misses": self._cache_misses,
                "cache_size": len(self._result_cache),
                "batch_calculations": self._batch_calculations,
                "average_calculation_time": avg_calc_time,
                "is_warmed_up": self._is_warmed_up,
            }
        }

@@ -236,6 +446,82 @@ class InterestManager:
    def has_calculator(self) -> bool:
        """检查是否有可用的计算组件"""
        return self._current_calculator is not None and self._current_calculator.is_enabled

    async def adaptive_optimize(self):
        """自适应优化:根据性能统计自动调整参数"""
        if not self._current_calculator:
            return

        stats = self.get_statistics()["manager_statistics"]

        # 根据缓存命中率调整缓存大小
        cache_hit_rate = stats["cache_hit_rate"]
        if cache_hit_rate < 0.5 and self._cache_max_size < 5000:
            # 命中率低,增加缓存容量
            new_size = min(self._cache_max_size * 2, 5000)
            logger.info(f"自适应优化:缓存命中率较低 ({cache_hit_rate:.2%}),扩大缓存容量 {self._cache_max_size} -> {new_size}")
            self._cache_max_size = new_size
        elif cache_hit_rate > 0.9 and self._cache_max_size > 100:
            # 命中率高,可以适当减小缓存
            new_size = max(self._cache_max_size // 2, 100)
            logger.info(f"自适应优化:缓存命中率很高 ({cache_hit_rate:.2%}),缩小缓存容量 {self._cache_max_size} -> {new_size}")
            self._cache_max_size = new_size
            # 清理多余缓存
            while len(self._result_cache) > self._cache_max_size:
                self._result_cache.popitem(last=False)

        # 根据平均计算时间调整批处理参数
        avg_calc_time = stats["average_calculation_time"]
        if avg_calc_time > 0.5 and self._batch_size < 50:
            # 计算较慢,增加批次大小以提高吞吐量
            new_batch_size = min(self._batch_size * 2, 50)
            logger.info(f"自适应优化:平均计算时间较长 ({avg_calc_time:.3f}s),增加批次大小 {self._batch_size} -> {new_batch_size}")
            self._batch_size = new_batch_size
        elif avg_calc_time < 0.1 and self._batch_size > 5:
            # 计算较快,可以减小批次
            new_batch_size = max(self._batch_size // 2, 5)
            logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}")
            self._batch_size = new_batch_size

    def get_performance_report(self) -> str:
        """生成性能报告"""
        stats = self.get_statistics()["manager_statistics"]

        report = [
            "=" * 60,
            "兴趣值管理器性能报告",
            "=" * 60,
            f"总计算次数: {stats['total_calculations']}",
            f"失败次数: {stats['failed_calculations']}",
            f"成功率: {stats['success_rate']:.2%}",
            f"缓存命中率: {stats['cache_hit_rate']:.2%}",
            f"缓存命中: {stats['cache_hits']}",
            f"缓存未命中: {stats['cache_misses']}",
            f"当前缓存大小: {stats['cache_size']} / {self._cache_max_size}",
            f"批量计算次数: {stats['batch_calculations']}",
            f"平均计算时间: {stats['average_calculation_time']:.4f}s",
            f"是否已预热: {'是' if stats['is_warmed_up'] else '否'}",
            f"当前计算器: {stats['current_calculator'] or '无'}",
            "=" * 60,
        ]

        # 添加计算器统计
        if self._current_calculator:
            calc_stats = self.get_statistics()["calculator_statistics"]
            report.extend([
                "",
                "计算器统计:",
                f"  组件名称: {calc_stats['component_name']}",
                f"  版本: {calc_stats['component_version']}",
                f"  已启用: {calc_stats['enabled']}",
                f"  总计算: {calc_stats['total_calculations']}",
                f"  失败: {calc_stats['failed_calculations']}",
                f"  成功率: {calc_stats['success_rate']:.2%}",
                f"  平均耗时: {calc_stats['average_calculation_time']:.4f}s",
                "=" * 60,
            ])

        return "\n".join(report)


# 全局实例

src/memory_graph/short_term_pressure_patch.md (Normal file, +199)
@@ -0,0 +1,199 @@
# 短期记忆压力泄压补丁

## 📋 概述

在高频消息场景下,短期记忆层(`ShortTermMemoryManager`)可能在自动转移机制触发前快速堆积大量记忆,当达到容量上限(`max_memories`)时可能阻塞后续写入。本功能提供一个**可选的泄压开关**,在容量溢出时自动删除低优先级记忆,防止系统阻塞。

**关键特性**:
- ✅ 默认关闭,保持向后兼容
- ✅ 基于重要性和时间的智能删除策略
- ✅ 异步持久化,不阻塞主流程
- ✅ 可通过配置文件或代码控制

---

## 🔧 配置方法

### 方法 1:代码配置(直接创建管理器)

如果您在代码中直接实例化 `UnifiedMemoryManager`:

```python
from src.memory_graph.unified_manager import UnifiedMemoryManager

manager = UnifiedMemoryManager(
    short_term_enable_force_cleanup=True,  # 开启泄压功能
    short_term_max_memories=30,  # 短期记忆容量上限
    # ... 其他参数
)
```

### 方法 2:配置文件(通过单例获取)

**推荐方式**:如果您使用 `get_unified_memory_manager()` 单例,需修改配置文件。

#### ❌ 目前的问题
配置文件 `config/bot_config.toml` 的 `[memory]` 节**尚未包含**此开关参数。

#### ✅ 解决方案
在 `config/bot_config.toml` 的 `[memory]` 节添加:

```toml
[memory]
# ... 其他配置 ...
short_term_max_memories = 30  # 短期记忆容量上限
short_term_transfer_threshold = 0.6  # 转移到长期记忆的重要性阈值
short_term_enable_force_cleanup = true  # 开启压力泄压(建议高频场景开启)
```

然后在 `src/memory_graph/manager_singleton.py` 第 157-175 行的 `get_unified_memory_manager()` 函数中添加读取逻辑:

```python
_unified_memory_manager = UnifiedMemoryManager(
    # ... 其他参数 ...
    short_term_enable_force_cleanup=getattr(config, "short_term_enable_force_cleanup", False),  # 添加此行
)
```

---

## ⚙️ 核心实现位置

### 1. 参数定义
**文件**:`src/memory_graph/unified_manager.py` 第 47 行
```python
class UnifiedMemoryManager:
    def __init__(
        self,
        short_term_enable_force_cleanup: bool = False,  # 开关参数
    ):
```

### 2. 传递到短期层
**文件**:`src/memory_graph/unified_manager.py` 第 100 行
```python
"short_term": {
    "enable_force_cleanup": short_term_enable_force_cleanup,  # 传递给 ShortTermMemoryManager
}
```

### 3. 泄压逻辑实现
**文件**:`src/memory_graph/short_term_manager.py` 第 693-726 行
```python
def force_cleanup_overflow(self, keep_ratio: float = 0.9) -> int:
    """当短期记忆超过容量时,强制删除低重要性且最早的记忆以泄压"""
    if not self.enable_force_cleanup:  # 检查开关
        return 0
    # ... 删除逻辑
```

### 4. 触发条件
**文件**:`src/memory_graph/unified_manager.py` 第 618-621 行
```python
# 在自动转移循环中检测
if occupancy_ratio >= 1.0 and not transfer_cache:
    removed = self.short_term_manager.force_cleanup_overflow()
    if removed > 0:
        logger.warning(f"短期记忆占用率 {occupancy_ratio:.0%},已强制删除 {removed} 条低重要性记忆泄压")
```

---

## 🔄 运行机制

### 触发条件(同时满足)
1. ✅ 开关已开启(`enable_force_cleanup=True`)
2. ✅ 短期记忆占用率 ≥ 100%(`len(memories) >= max_memories`)
3. ✅ 当前没有待转移批次(`transfer_cache` 为空)

### 删除策略
**排序规则**:双重排序,先按重要性升序,再按创建时间升序
```python
sorted_memories = sorted(self.memories, key=lambda m: (m.importance, m.created_at))
```

**删除数量**:删除到容量的 90%
```python
current = len(self.memories)  # 当前记忆数
limit = int(self.max_memories * 0.9)  # 目标保留数
remove_count = current - limit  # 需要删除的数量
```

**示例**:
- 容量上限 `max_memories=30`
- 当前记忆数 `35` → 删除 `35 - 27 = 8` 条最低优先级记忆
- 优先删除:重要性 0.1 且创建于 10 分钟前的记忆

### 持久化
- 使用 `asyncio.create_task(self._save_to_disk())` 异步保存
- **不阻塞**消息处理主流程

---

## 📊 性能影响

| 场景 | 开关状态 | 行为 | 适用场景 |
|------|---------|------|---------|
| 高频消息 | ✅ 开启 | 自动泄压,防止阻塞 | 群聊、客服场景 |
| 低频消息 | ❌ 关闭 | 仅依赖自动转移 | 私聊、低活跃群 |
| 调试阶段 | ❌ 关闭 | 便于观察记忆堆积 | 开发测试 |

**日志示例**(开启后):
```
[WARNING] 短期记忆压力泄压: 移除 8 条 (当前 27/30)
[WARNING] 短期记忆占用率 100%,已强制删除 8 条低重要性记忆泄压
```

---

## 🚨 注意事项

### ⚠️ 何时开启
- ✅ **推荐开启**:高频群聊、客服机器人、24/7 运行场景
- ❌ **不建议开启**:需要完整保留所有短期记忆、调试阶段

### ⚠️ 潜在影响
- 低重要性记忆可能被删除,**不会转移到长期记忆**
- 如需保留所有记忆,应调大 `max_memories` 或关闭此功能

### ⚠️ 与自动转移的协同
本功能是**兜底机制**,正常情况下:
1. 优先触发自动转移(占用率 ≥ 50%)
2. 高重要性记忆转移到长期层
3. 仅当转移来不及时,泄压才会触发

---

## 🔙 回滚与禁用

### 临时禁用(无需重启)
```python
# 运行时修改(如果您能访问管理器实例)
unified_manager.short_term_manager.enable_force_cleanup = False
```

### 永久禁用
**配置文件方式**:
```toml
[memory]
short_term_enable_force_cleanup = false  # 或直接删除此行
```

**代码方式**:
```python
manager = UnifiedMemoryManager(
    short_term_enable_force_cleanup=False,  # 显式关闭
)
```

---

## 📚 相关文档

- [三层记忆系统用户指南](../../docs/three_tier_memory_user_guide.md)
- [记忆图谱架构](../../docs/memory_graph_guide.md)
- [统一调度器指南](../../docs/unified_scheduler_guide.md)

---

**最后更新**:2025年12月16日
@@ -117,10 +117,17 @@ class BaseInterestCalculator(ABC):
        """
        try:
            self._enabled = True
            # 子类可以重写此方法执行自定义初始化
            await self.on_initialize()
            return True
        except Exception:
        except Exception as e:
            logger.error(f"初始化兴趣计算器失败: {e}")
            self._enabled = False
            return False

    async def on_initialize(self):
        """子类可重写的初始化钩子"""
        pass

    async def cleanup(self) -> bool:
        """清理组件资源
@@ -129,10 +136,17 @@ class BaseInterestCalculator(ABC):
            bool: 清理是否成功
        """
        try:
            # 子类可以重写此方法执行自定义清理
            await self.on_cleanup()
            self._enabled = False
            return True
        except Exception:
        except Exception as e:
            logger.error(f"清理兴趣计算器失败: {e}")
            return False

    async def on_cleanup(self):
        """子类可重写的清理钩子"""
        pass

    @property
    def is_enabled(self) -> bool:
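钩子的典型用法示意(假设性子类,省略了基类其余抽象方法的实现,`load_model` 亦为虚构):

```python
# 假设性示意:通过 on_initialize / on_cleanup 挂接自定义资源管理
class MyInterestCalculator(BaseInterestCalculator):
    async def on_initialize(self):
        self.model = await load_model()  # 虚构的模型加载函数

    async def on_cleanup(self):
        self.model = None  # 释放资源;失败会被 cleanup() 捕获并记录
```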