refactor(expression_learner): 优化表达方式查询逻辑,减少数据库查询次数并批量处理更新操作

refactor(storage): 添加消息更新批处理器,优化消息ID更新逻辑以减少数据库连接次数
This commit is contained in:
Windpicker-owo
2025-11-01 01:07:37 +08:00
parent 66d0375d45
commit 0148f1e533
2 changed files with 172 additions and 64 deletions

View File

@@ -234,51 +234,45 @@ class ExpressionLearner:
""" """
获取指定chat_id的style和grammar表达方式 获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
""" """
learnt_style_expressions = [] learnt_style_expressions = []
learnt_grammar_expressions = [] learnt_grammar_expressions = []
# 直接从数据库查询 # 优化: 一次查询获取所有表达方式
async with get_db_session() as session: async with get_db_session() as session:
style_query = await session.execute( all_expressions = await session.execute(
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")) select(Expression).where(Expression.chat_id == self.chat_id)
) )
for expr in style_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time for expr in all_expressions.scalars():
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time # 确保create_date存在如果不存在则使用last_active_time
learnt_style_expressions.append( create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
{
expr_data = {
"situation": expr.situation, "situation": expr.situation,
"style": expr.style, "style": expr.style,
"count": expr.count, "count": expr.count,
"last_active_time": expr.last_active_time, "last_active_time": expr.last_active_time,
"source_id": self.chat_id, "source_id": self.chat_id,
"type": "style", "type": expr.type,
"create_date": create_date, "create_date": create_date,
} }
)
grammar_query = await session.execute( # 根据类型分类
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")) if expr.type == "style":
) learnt_style_expressions.append(expr_data)
for expr in grammar_query.scalars(): elif expr.type == "grammar":
# 确保create_date存在如果不存在则使用last_active_time learnt_grammar_expressions.append(expr_data)
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
return learnt_style_expressions, learnt_grammar_expressions return learnt_style_expressions, learnt_grammar_expressions
async def _apply_global_decay_to_database(self, current_time: float) -> None: async def _apply_global_decay_to_database(self, current_time: float) -> None:
""" """
对数据库中的所有表达方式应用全局衰减 对数据库中的所有表达方式应用全局衰减
优化: 批量处理所有更改,最后统一提交,避免逐条提交
""" """
try: try:
async with get_db_session() as session: async with get_db_session() as session:
@@ -286,30 +280,32 @@ class ExpressionLearner:
all_expressions = await session.execute(select(Expression)) all_expressions = await session.execute(select(Expression))
all_expressions = all_expressions.scalars().all() all_expressions = all_expressions.scalars().all()
updated_count = 0 updated_count = 0
deleted_count = 0 deleted_count = 0
for expr in all_expressions: # 优化: 批量处理所有修改
# 计算时间差 for expr in all_expressions:
last_active = expr.last_active_time # 计算时间差
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值 # 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days) decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value) new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01: if new_count <= 0.01:
# 如果count太小删除这个表达方式 # 如果count太小删除这个表达方式
await session.delete(expr) await session.delete(expr)
deleted_count += 1
else:
# 更新count
expr.count = new_count
updated_count += 1
# 优化: 统一提交所有更改从N次提交减少到1次
if updated_count > 0 or deleted_count > 0:
await session.commit() await session.commit()
deleted_count += 1 logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
else:
# 更新count
expr.count = new_count
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e: except Exception as e:
logger.error(f"数据库全局衰减失败: {e}") logger.error(f"数据库全局衰减失败: {e}")

View File

@@ -1,6 +1,8 @@
import asyncio
import re import re
import time import time
import traceback import traceback
from collections import deque
import orjson import orjson
from sqlalchemy import desc, select, update from sqlalchemy import desc, select, update
@@ -16,6 +18,102 @@ from .message import MessageSending
logger = get_logger("message_storage") logger = get_logger("message_storage")
class MessageUpdateBatcher:
"""
消息更新批处理器
优化: 将多个消息ID更新操作批量处理减少数据库连接次数
"""
def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
self.batch_size = batch_size
self.flush_interval = flush_interval
self.pending_updates: deque = deque()
self._lock = asyncio.Lock()
self._flush_task = None
async def start(self):
"""启动自动刷新任务"""
if self._flush_task is None:
self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.debug("消息更新批处理器已启动")
async def stop(self):
"""停止批处理器"""
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
# 刷新剩余的更新
await self.flush()
logger.debug("消息更新批处理器已停止")
async def add_update(self, mmc_message_id: str, qq_message_id: str):
"""添加消息ID更新到批处理队列"""
async with self._lock:
self.pending_updates.append((mmc_message_id, qq_message_id))
# 如果达到批量大小,立即刷新
if len(self.pending_updates) >= self.batch_size:
await self.flush()
async def flush(self):
"""执行批量更新"""
async with self._lock:
if not self.pending_updates:
return
updates = list(self.pending_updates)
self.pending_updates.clear()
try:
async with get_db_session() as session:
updated_count = 0
for mmc_id, qq_id in updates:
result = await session.execute(
update(Messages)
.where(Messages.message_id == mmc_id)
.values(message_id=qq_id)
)
if result.rowcount > 0:
updated_count += 1
await session.commit()
if updated_count > 0:
logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")
except Exception as e:
logger.error(f"批量更新消息ID失败: {e}")
async def _auto_flush_loop(self):
"""自动刷新循环"""
while True:
try:
await asyncio.sleep(self.flush_interval)
await self.flush()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动刷新出错: {e}")
# 全局批处理器实例
_message_update_batcher = None
def get_message_update_batcher() -> MessageUpdateBatcher:
"""获取全局消息更新批处理器"""
global _message_update_batcher
if _message_update_batcher is None:
_message_update_batcher = MessageUpdateBatcher()
return _message_update_batcher
class MessageStorage: class MessageStorage:
@staticmethod @staticmethod
def _serialize_keywords(keywords) -> str: def _serialize_keywords(keywords) -> str:
@@ -216,8 +314,16 @@ class MessageStorage:
traceback.print_exc() traceback.print_exc()
@staticmethod @staticmethod
async def update_message(message_data: dict): async def update_message(message_data: dict, use_batch: bool = True):
"""更新消息ID从消息字典""" """
更新消息ID从消息字典
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
Args:
message_data: 消息数据字典
use_batch: 是否使用批处理默认True
"""
try: try:
# 从字典中提取信息 # 从字典中提取信息
message_info = message_data.get("message_info", {}) message_info = message_data.get("message_info", {})
@@ -255,23 +361,29 @@ class MessageStorage:
logger.debug(f"消息段数据: {segment_data}") logger.debug(f"消息段数据: {segment_data}")
return return
# 使用上下文管理器确保session正确管理 # 优化: 使用批处理器减少数据库连接
from src.common.database.sqlalchemy_models import get_db_session if use_batch:
batcher = get_message_update_batcher()
await batcher.add_update(mmc_message_id, qq_message_id)
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
else:
# 直接更新(保留原有逻辑用于特殊情况)
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session: async with get_db_session() as session:
matched_message = ( matched_message = (
await session.execute( await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
) )
).scalar() ).scalar()
if matched_message: if matched_message:
await session.execute( await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
) )
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else: else:
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
except Exception as e: except Exception as e:
logger.error(f"更新消息ID失败: {e}") logger.error(f"更新消息ID失败: {e}")