refactor(expression_learner): 优化表达方式查询逻辑,减少数据库查询次数并批量处理更新操作
refactor(storage): 添加消息更新批处理器,优化消息ID更新逻辑以减少数据库连接次数
This commit is contained in:
@@ -234,51 +234,45 @@ class ExpressionLearner:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
|
||||
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
|
||||
"""
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 直接从数据库查询
|
||||
# 优化: 一次查询获取所有表达方式
|
||||
async with get_db_session() as session:
|
||||
style_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||
all_expressions = await session.execute(
|
||||
select(Expression).where(Expression.chat_id == self.chat_id)
|
||||
)
|
||||
for expr in style_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
|
||||
for expr in all_expressions.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
|
||||
expr_data = {
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "style",
|
||||
"type": expr.type,
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar"))
|
||||
)
|
||||
for expr in grammar_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
|
||||
# 根据类型分类
|
||||
if expr.type == "style":
|
||||
learnt_style_expressions.append(expr_data)
|
||||
elif expr.type == "grammar":
|
||||
learnt_grammar_expressions.append(expr_data)
|
||||
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
async def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
|
||||
优化: 批量处理所有更改,最后统一提交,避免逐条提交
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -286,30 +280,32 @@ class ExpressionLearner:
|
||||
all_expressions = await session.execute(select(Expression))
|
||||
all_expressions = all_expressions.scalars().all()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
# 优化: 批量处理所有修改
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
await session.delete(expr)
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
await session.delete(expr)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
updated_count += 1
|
||||
|
||||
# 优化: 统一提交所有更改(从N次提交减少到1次)
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
await session.commit()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
@@ -16,6 +18,102 @@ from .message import MessageSending
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageUpdateBatcher:
    """
    Batch message-ID rewrites so that many updates share one database session.

    Updates are queued with ``add_update`` and written by ``flush``. A flush
    happens when the queue reaches ``batch_size``, on every ``flush_interval``
    tick once ``start`` has been awaited, and one final time from ``stop``.

    NOTE(review): the periodic flush task is only created by ``start()``;
    confirm the application awaits it, otherwise a short queue waits until
    ``batch_size`` entries have accumulated.
    """

    def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
        """
        Args:
            batch_size: Queue length that triggers an immediate flush.
            flush_interval: Seconds between automatic flushes.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        # Pending (mmc_message_id, qq_message_id) pairs, oldest first.
        self.pending_updates: deque = deque()
        # Guards pending_updates only. asyncio.Lock is not reentrant, so no
        # method may await another lock-taking method while holding it.
        self._lock = asyncio.Lock()
        self._flush_task = None

    async def start(self):
        """Start the background auto-flush task if it is not already running."""
        if self._flush_task is None:
            self._flush_task = asyncio.create_task(self._auto_flush_loop())
            logger.debug("消息更新批处理器已启动")

    async def stop(self):
        """Cancel the auto-flush task, then write whatever is still queued."""
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Flush the remaining updates.
        await self.flush()
        logger.debug("消息更新批处理器已停止")

    async def add_update(self, mmc_message_id: str, qq_message_id: str):
        """
        Queue one message-ID rewrite and flush once the batch is full.

        Args:
            mmc_message_id: Message ID currently stored in the database.
            qq_message_id: Message ID that should replace it.
        """
        async with self._lock:
            self.pending_updates.append((mmc_message_id, qq_message_id))
            batch_is_full = len(self.pending_updates) >= self.batch_size

        # flush() acquires the same lock, so it must be awaited only after the
        # lock has been released; awaiting it inside the block would hang.
        if batch_is_full:
            await self.flush()

    async def flush(self):
        """
        Write every queued update to the database.

        The queue is drained before the write is attempted and a failure is
        only logged, so a failed batch is dropped rather than retried.
        """
        async with self._lock:
            if not self.pending_updates:
                return
            updates = list(self.pending_updates)
            self.pending_updates.clear()

        # The lock is released here: it protects the queue hand-off, not the
        # database round trip, so add_update() is not blocked by slow I/O.
        try:
            updated_count = await self._write_batch(updates)
            if updated_count > 0:
                logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")
        except Exception as e:
            logger.error(f"批量更新消息ID失败: {e}")

    async def _write_batch(self, updates) -> int:
        """Apply the updates in one session and commit; return how many matched a row."""
        # Imported locally, the same way update_message() in this file does,
        # so the name is guaranteed to be in scope here.
        from src.common.database.sqlalchemy_models import get_db_session

        async with get_db_session() as session:
            updated_count = 0
            for mmc_id, qq_id in updates:
                result = await session.execute(
                    update(Messages).where(Messages.message_id == mmc_id).values(message_id=qq_id)
                )
                if result.rowcount > 0:
                    updated_count += 1

            await session.commit()
        return updated_count

    async def _auto_flush_loop(self):
        """Flush every ``flush_interval`` seconds until the task is cancelled."""
        while True:
            try:
                await asyncio.sleep(self.flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"自动刷新出错: {e}")
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_update_batcher = None
|
||||
|
||||
|
||||
def get_message_update_batcher() -> MessageUpdateBatcher:
    """Return the module-wide message update batcher, creating it on first use."""
    global _message_update_batcher
    batcher = _message_update_batcher
    if batcher is None:
        # First call: build the shared instance and remember it.
        batcher = MessageUpdateBatcher()
        _message_update_batcher = batcher
    return batcher
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords) -> str:
|
||||
@@ -216,8 +314,16 @@ class MessageStorage:
|
||||
traceback.print_exc()
|
||||
|
||||
@staticmethod
|
||||
async def update_message(message_data: dict):
|
||||
"""更新消息ID(从消息字典)"""
|
||||
async def update_message(message_data: dict, use_batch: bool = True):
|
||||
"""
|
||||
更新消息ID(从消息字典)
|
||||
|
||||
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
|
||||
|
||||
Args:
|
||||
message_data: 消息数据字典
|
||||
use_batch: 是否使用批处理(默认True)
|
||||
"""
|
||||
try:
|
||||
# 从字典中提取信息
|
||||
message_info = message_data.get("message_info", {})
|
||||
@@ -255,23 +361,29 @@ class MessageStorage:
|
||||
logger.debug(f"消息段数据: {segment_data}")
|
||||
return
|
||||
|
||||
# 使用上下文管理器确保session正确管理
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
# 优化: 使用批处理器减少数据库连接
|
||||
if use_batch:
|
||||
batcher = get_message_update_batcher()
|
||||
await batcher.add_update(mmc_message_id, qq_message_id)
|
||||
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
|
||||
else:
|
||||
# 直接更新(保留原有逻辑用于特殊情况)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
async with get_db_session() as session:
|
||||
matched_message = (
|
||||
await session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
)
|
||||
).scalar()
|
||||
async with get_db_session() as session:
|
||||
matched_message = (
|
||||
await session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if matched_message:
|
||||
await session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
|
||||
if matched_message:
|
||||
await session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
|
||||
Reference in New Issue
Block a user