refactor(expression_learner): 优化表达方式查询逻辑,减少数据库查询次数并批量处理更新操作

refactor(storage): 添加消息更新批处理器,优化消息ID更新逻辑以减少数据库连接次数
This commit is contained in:
Windpicker-owo
2025-11-01 01:07:37 +08:00
parent 66d0375d45
commit 0148f1e533
2 changed files with 172 additions and 64 deletions

View File

@@ -234,51 +234,45 @@ class ExpressionLearner:
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 直接从数据库查询
# 优化: 一次查询获取所有表达方式
async with get_db_session() as session:
style_query = await session.execute(
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
all_expressions = await session.execute(
select(Expression).where(Expression.chat_id == self.chat_id)
)
for expr in style_query.scalars():
for expr in all_expressions.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_style_expressions.append(
{
expr_data = {
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "style",
"type": expr.type,
"create_date": create_date,
}
)
grammar_query = await session.execute(
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar"))
)
for expr in grammar_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
# 根据类型分类
if expr.type == "style":
learnt_style_expressions.append(expr_data)
elif expr.type == "grammar":
learnt_grammar_expressions.append(expr_data)
return learnt_style_expressions, learnt_grammar_expressions
async def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
优化: 批量处理所有更改,最后统一提交,避免逐条提交
"""
try:
async with get_db_session() as session:
@@ -289,6 +283,7 @@ class ExpressionLearner:
updated_count = 0
deleted_count = 0
# 优化: 批量处理所有修改
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
@@ -301,14 +296,15 @@ class ExpressionLearner:
if new_count <= 0.01:
# 如果count太小删除这个表达方式
await session.delete(expr)
await session.commit()
deleted_count += 1
else:
# 更新count
expr.count = new_count
updated_count += 1
# 优化: 统一提交所有更改从N次提交减少到1次
if updated_count > 0 or deleted_count > 0:
await session.commit()
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:

View File

@@ -1,6 +1,8 @@
import asyncio
import re
import time
import traceback
from collections import deque
import orjson
from sqlalchemy import desc, select, update
@@ -16,6 +18,102 @@ from .message import MessageSending
logger = get_logger("message_storage")
class MessageUpdateBatcher:
"""
消息更新批处理器
优化: 将多个消息ID更新操作批量处理减少数据库连接次数
"""
def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
self.batch_size = batch_size
self.flush_interval = flush_interval
self.pending_updates: deque = deque()
self._lock = asyncio.Lock()
self._flush_task = None
async def start(self):
"""启动自动刷新任务"""
if self._flush_task is None:
self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.debug("消息更新批处理器已启动")
async def stop(self):
"""停止批处理器"""
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
self._flush_task = None
# 刷新剩余的更新
await self.flush()
logger.debug("消息更新批处理器已停止")
async def add_update(self, mmc_message_id: str, qq_message_id: str):
"""添加消息ID更新到批处理队列"""
async with self._lock:
self.pending_updates.append((mmc_message_id, qq_message_id))
# 如果达到批量大小,立即刷新
if len(self.pending_updates) >= self.batch_size:
await self.flush()
async def flush(self):
"""执行批量更新"""
async with self._lock:
if not self.pending_updates:
return
updates = list(self.pending_updates)
self.pending_updates.clear()
try:
async with get_db_session() as session:
updated_count = 0
for mmc_id, qq_id in updates:
result = await session.execute(
update(Messages)
.where(Messages.message_id == mmc_id)
.values(message_id=qq_id)
)
if result.rowcount > 0:
updated_count += 1
await session.commit()
if updated_count > 0:
logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")
except Exception as e:
logger.error(f"批量更新消息ID失败: {e}")
async def _auto_flush_loop(self):
"""自动刷新循环"""
while True:
try:
await asyncio.sleep(self.flush_interval)
await self.flush()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动刷新出错: {e}")
# 全局批处理器实例
_message_update_batcher = None
def get_message_update_batcher() -> MessageUpdateBatcher:
"""获取全局消息更新批处理器"""
global _message_update_batcher
if _message_update_batcher is None:
_message_update_batcher = MessageUpdateBatcher()
return _message_update_batcher
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
@@ -216,8 +314,16 @@ class MessageStorage:
traceback.print_exc()
@staticmethod
async def update_message(message_data: dict):
"""更新消息ID从消息字典"""
async def update_message(message_data: dict, use_batch: bool = True):
"""
更新消息ID从消息字典
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
Args:
message_data: 消息数据字典
use_batch: 是否使用批处理默认True
"""
try:
# 从字典中提取信息
message_info = message_data.get("message_info", {})
@@ -255,7 +361,13 @@ class MessageStorage:
logger.debug(f"消息段数据: {segment_data}")
return
# 使用上下文管理器确保session正确管理
# 优化: 使用批处理器减少数据库连接
if use_batch:
batcher = get_message_update_batcher()
await batcher.add_update(mmc_message_id, qq_message_id)
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
else:
# 直接更新(保留原有逻辑用于特殊情况)
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session: