Merge pull request #56 from MoFox-Studio/feature/database-refactoring
重构数据库系统,优化数据库性能
This commit is contained in:
@@ -263,7 +263,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
@@ -290,7 +291,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
|
||||
@@ -9,7 +9,8 @@ from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||
from src.common.database.core.models import AntiInjectionStats
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_models import BanUser, get_db_session
|
||||
from src.common.database.core.models import BanUser
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
@@ -15,8 +15,10 @@ from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Emoji, Images
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -204,16 +206,23 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
|
||||
will_delete_emoji = result.scalar_one_or_none()
|
||||
if will_delete_emoji is None:
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
else:
|
||||
await session.delete(will_delete_emoji)
|
||||
result = 1 # Successfully deleted one record
|
||||
await session.commit()
|
||||
# 使用CRUD进行删除
|
||||
crud = CRUDBase(Emoji)
|
||||
will_delete_emoji = await crud.get_by(emoji_hash=self.hash)
|
||||
if will_delete_emoji is None:
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
result = 0 # Indicate no DB record was deleted
|
||||
else:
|
||||
await crud.delete(will_delete_emoji.id)
|
||||
result = 1 # Successfully deleted one record
|
||||
|
||||
# 使缓存失效
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
await cache.delete(generate_cache_key("emoji_by_hash", self.hash))
|
||||
await cache.delete(generate_cache_key("emoji_description", self.hash))
|
||||
await cache.delete(generate_cache_key("emoji_tag", self.hash))
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {e!s}")
|
||||
result = 0
|
||||
@@ -697,23 +706,27 @@ class EmojiManager:
|
||||
list[MaiEmoji]: 表情包对象列表
|
||||
"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
if emoji_hash:
|
||||
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
||||
query = result.scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
result = await session.execute(select(Emoji))
|
||||
query = result.scalars().all()
|
||||
# 使用CRUD进行查询
|
||||
crud = CRUDBase(Emoji)
|
||||
|
||||
if emoji_hash:
|
||||
# 查询特定hash的表情包
|
||||
emoji_record = await crud.get_by(emoji_hash=emoji_hash)
|
||||
emoji_instances = [emoji_record] if emoji_record else []
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
# 查询所有表情包
|
||||
from src.common.database.api.query import QueryBuilder
|
||||
query = QueryBuilder(Emoji)
|
||||
emoji_instances = await query.all()
|
||||
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
return emoji_objects
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
return emoji_objects
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}")
|
||||
@@ -734,8 +747,9 @@ class EmojiManager:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
@cached(ttl=1800, key_prefix="emoji_tag") # 缓存30分钟
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
"""根据哈希值获取已注册表情包的描述(带30分钟缓存)
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包的哈希值
|
||||
@@ -765,8 +779,9 @@ class EmojiManager:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
|
||||
return None
|
||||
|
||||
@cached(ttl=1800, key_prefix="emoji_description") # 缓存30分钟
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
"""根据哈希值获取已注册表情包的描述(带30分钟缓存)
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包的哈希值
|
||||
|
||||
@@ -10,6 +10,8 @@ from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("energy_system")
|
||||
@@ -203,21 +205,19 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.database.core.models import ChatStreams
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
# 使用CRUD进行查询(已有缓存)
|
||||
crud = CRUDBase(ChatStreams)
|
||||
stream = await crud.get_by(stream_id=stream_id)
|
||||
|
||||
if stream and stream.stream_interest_score is not None:
|
||||
interest_score = float(stream.stream_interest_score)
|
||||
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
||||
return interest_score
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
|
||||
return 0.3
|
||||
if stream and stream.stream_interest_score is not None:
|
||||
interest_score = float(stream.stream_interest_score)
|
||||
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
||||
return interest_score
|
||||
else:
|
||||
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
|
||||
return 0.3
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取聊天流兴趣度失败,使用默认值: {e}")
|
||||
|
||||
@@ -10,8 +10,10 @@ from sqlalchemy import select
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Expression
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -232,21 +234,26 @@ class ExpressionLearner:
|
||||
|
||||
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
获取指定chat_id的style和grammar表达方式(带10分钟缓存)
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
|
||||
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
|
||||
优化: 使用CRUD和缓存,减少数据库访问
|
||||
"""
|
||||
# 使用静态方法以正确处理缓存键
|
||||
return await self._get_expressions_by_chat_id_cached(self.chat_id)
|
||||
|
||||
@staticmethod
|
||||
@cached(ttl=600, key_prefix="chat_expressions")
|
||||
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
"""内部方法:从数据库获取表达方式(带缓存)"""
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 优化: 一次查询获取所有表达方式
|
||||
async with get_db_session() as session:
|
||||
all_expressions = await session.execute(
|
||||
select(Expression).where(Expression.chat_id == self.chat_id)
|
||||
)
|
||||
# 使用CRUD查询
|
||||
crud = CRUDBase(Expression)
|
||||
all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000)
|
||||
|
||||
for expr in all_expressions.scalars():
|
||||
for expr in all_expressions:
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
|
||||
@@ -255,7 +262,7 @@ class ExpressionLearner:
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"source_id": chat_id,
|
||||
"type": expr.type,
|
||||
"create_date": create_date,
|
||||
}
|
||||
@@ -272,18 +279,19 @@ class ExpressionLearner:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
|
||||
优化: 批量处理所有更改,最后统一提交,避免逐条提交
|
||||
优化: 使用CRUD批量处理所有更改,最后统一提交
|
||||
"""
|
||||
try:
|
||||
# 使用CRUD查询所有表达方式
|
||||
crud = CRUDBase(Expression)
|
||||
all_expressions = await crud.get_multi(limit=100000) # 获取所有表达方式
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
# 需要手动操作的情况下使用session
|
||||
async with get_db_session() as session:
|
||||
# 获取所有表达方式
|
||||
all_expressions = await session.execute(select(Expression))
|
||||
all_expressions = all_expressions.scalars().all()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
# 优化: 批量处理所有修改
|
||||
# 批量处理所有修改
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
@@ -383,10 +391,12 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
crud = CRUDBase(Expression)
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
async with get_db_session() as session:
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
# 注意: get_all_by 不支持复杂条件,这里仍需使用 session
|
||||
query = await session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
@@ -416,7 +426,7 @@ class ExpressionLearner:
|
||||
)
|
||||
session.add(new_expression)
|
||||
|
||||
# 限制最大数量
|
||||
# 限制最大数量 - 使用 get_all_by_sorted 获取排序结果
|
||||
exprs_result = await session.execute(
|
||||
select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
@@ -427,6 +437,15 @@ class ExpressionLearner:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
await session.delete(expr)
|
||||
|
||||
# 提交后清除相关缓存
|
||||
await session.commit()
|
||||
|
||||
# 清除该chat_id的表达方式缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
|
||||
# 🔥 训练 StyleLearner
|
||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||
|
||||
@@ -9,8 +9,10 @@ from json_repair import repair_json
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Expression
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -150,6 +152,8 @@ class ExpressionSelector:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 使用CRUD查询(由于需要IN条件,使用session)
|
||||
async with get_db_session() as session:
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = await session.execute(
|
||||
@@ -207,6 +211,7 @@ class ExpressionSelector:
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
affected_chat_ids = set()
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
@@ -218,6 +223,8 @@ class ExpressionSelector:
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
affected_chat_ids.add(source_id)
|
||||
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
@@ -240,6 +247,13 @@ class ExpressionSelector:
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 清除所有受影响的chat_id的缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
for chat_id in affected_chat_ids:
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
|
||||
@@ -649,8 +649,8 @@ class BotInterestManager:
|
||||
# 导入SQLAlchemy相关模块
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
@@ -731,8 +731,8 @@ class BotInterestManager:
|
||||
# 导入SQLAlchemy相关模块
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
# 将兴趣标签转换为JSON格式
|
||||
tags_data = []
|
||||
|
||||
@@ -9,8 +9,8 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -9,8 +9,10 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams # 新增导入
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
@@ -441,16 +443,20 @@ class ChatManager:
|
||||
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
async def _db_find_stream_async(s_id: str):
|
||||
async with get_db_session() as session:
|
||||
return (
|
||||
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
model_instance = await _db_find_stream_async(stream_id)
|
||||
# 使用优化后的API查询(带缓存)
|
||||
model_instance, _ = await get_or_create_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
defaults={
|
||||
"user_platform": user_info.platform if user_info else platform,
|
||||
"user_id": user_info.user_id if user_info else "",
|
||||
"user_nickname": user_info.user_nickname if user_info else "",
|
||||
"user_cardname": user_info.user_cardname if user_info else "",
|
||||
"group_platform": group_info.platform if group_info else None,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
"group_name": group_info.group_name if group_info else None,
|
||||
}
|
||||
)
|
||||
|
||||
if model_instance:
|
||||
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
|
||||
@@ -696,9 +702,11 @@ class ChatManager:
|
||||
|
||||
async def _db_load_all_streams_async():
|
||||
loaded_streams_data = []
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(ChatStreams))
|
||||
for model_instance in result.scalars().all():
|
||||
# 使用CRUD批量查询
|
||||
crud = CRUDBase(ChatStreams)
|
||||
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
|
||||
|
||||
for model_instance in all_streams:
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
@@ -734,7 +742,6 @@ class ChatManager:
|
||||
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
await session.commit()
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
|
||||
@@ -3,13 +3,14 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
@@ -18,6 +19,309 @@ from .message import MessageSending
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorageBatcher:
|
||||
"""
|
||||
消息存储批处理器
|
||||
|
||||
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
|
||||
"""
|
||||
初始化批处理器
|
||||
|
||||
Args:
|
||||
batch_size: 批量大小,达到此数量立即写入
|
||||
flush_interval: 自动刷新间隔(秒)
|
||||
"""
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval
|
||||
self.pending_messages: deque = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task = None
|
||||
self._running = False
|
||||
|
||||
async def start(self):
|
||||
"""启动自动刷新任务"""
|
||||
if self._flush_task is None and not self._running:
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._auto_flush_loop())
|
||||
logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")
|
||||
|
||||
async def stop(self):
|
||||
"""停止批处理器"""
|
||||
self._running = False
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
|
||||
# 刷新剩余的消息
|
||||
await self.flush()
|
||||
logger.info("消息存储批处理器已停止")
|
||||
|
||||
async def add_message(self, message_data: dict):
|
||||
"""
|
||||
添加消息到批处理队列
|
||||
|
||||
Args:
|
||||
message_data: 包含消息对象和chat_stream的字典
|
||||
{
|
||||
'message': DatabaseMessages | MessageSending,
|
||||
'chat_stream': ChatStream
|
||||
}
|
||||
"""
|
||||
async with self._lock:
|
||||
self.pending_messages.append(message_data)
|
||||
|
||||
# 如果达到批量大小,立即刷新
|
||||
if len(self.pending_messages) >= self.batch_size:
|
||||
logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
|
||||
await self.flush()
|
||||
|
||||
async def flush(self):
|
||||
"""执行批量写入"""
|
||||
async with self._lock:
|
||||
if not self.pending_messages:
|
||||
return
|
||||
|
||||
messages_to_store = list(self.pending_messages)
|
||||
self.pending_messages.clear()
|
||||
|
||||
if not messages_to_store:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
success_count = 0
|
||||
|
||||
try:
|
||||
# 准备所有消息对象
|
||||
messages_objects = []
|
||||
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_obj = await self._prepare_message_object(
|
||||
msg_data['message'],
|
||||
msg_data['chat_stream']
|
||||
)
|
||||
if message_obj:
|
||||
messages_objects.append(message_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息对象失败: {e}")
|
||||
continue
|
||||
|
||||
# 批量写入数据库
|
||||
if messages_objects:
|
||||
async with get_db_session() as session:
|
||||
session.add_all(messages_objects)
|
||||
await session.commit()
|
||||
success_count = len(messages_objects)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
||||
f"(耗时: {elapsed:.3f}秒)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量存储消息失败: {e}", exc_info=True)
|
||||
|
||||
async def _prepare_message_object(self, message, chat_stream):
|
||||
"""准备消息对象(从原 store_message 逻辑提取)"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# 如果是 DatabaseMessages,直接使用它的字段
|
||||
if isinstance(message, DatabaseMessages):
|
||||
processed_plain_text = message.processed_plain_text
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
safe_processed_plain_text = processed_plain_text or ""
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
display_message = message.display_message or message.processed_plain_text or ""
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
|
||||
msg_id = message.message_id
|
||||
msg_time = message.time
|
||||
chat_id = message.chat_id
|
||||
reply_to = ""
|
||||
is_mentioned = message.is_mentioned
|
||||
interest_value = message.interest_value or 0.0
|
||||
priority_mode = ""
|
||||
priority_info_json = None
|
||||
is_emoji = message.is_emoji or False
|
||||
is_picid = message.is_picid or False
|
||||
is_notify = message.is_notify or False
|
||||
is_command = message.is_command or False
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
memorized_times = 0
|
||||
|
||||
user_platform = message.user_info.platform if message.user_info else ""
|
||||
user_id = message.user_info.user_id if message.user_info else ""
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else ""
|
||||
user_cardname = message.user_info.user_cardname if message.user_info else None
|
||||
|
||||
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
|
||||
chat_info_platform = message.chat_info.platform if message.chat_info else ""
|
||||
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
|
||||
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
|
||||
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
|
||||
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
|
||||
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||
|
||||
else:
|
||||
# MessageSending 处理逻辑
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
safe_processed_plain_text = processed_plain_text or ""
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_display_message = re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
||||
interest_value = 0
|
||||
is_mentioned = False
|
||||
reply_to = message.reply_to
|
||||
priority_mode = ""
|
||||
priority_info = {}
|
||||
is_emoji = False
|
||||
is_picid = False
|
||||
is_notify = False
|
||||
is_command = False
|
||||
key_words = ""
|
||||
key_words_lite = ""
|
||||
else:
|
||||
filtered_display_message = ""
|
||||
interest_value = message.interest_value
|
||||
is_mentioned = message.is_mentioned
|
||||
reply_to = ""
|
||||
priority_mode = message.priority_mode
|
||||
priority_info = message.priority_info
|
||||
is_emoji = message.is_emoji
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict()
|
||||
|
||||
msg_id = message.message_info.message_id
|
||||
msg_time = float(message.message_info.time or time.time())
|
||||
chat_id = chat_stream.stream_id
|
||||
memorized_times = message.memorized_times
|
||||
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
|
||||
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
||||
|
||||
user_platform = user_info_dict.get("platform")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
|
||||
chat_info_stream_id = chat_info_dict.get("stream_id")
|
||||
chat_info_platform = chat_info_dict.get("platform")
|
||||
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
|
||||
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
|
||||
chat_info_user_platform = user_info_from_chat.get("platform")
|
||||
chat_info_user_id = user_info_from_chat.get("user_id")
|
||||
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
|
||||
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
|
||||
chat_info_group_platform = group_info_from_chat.get("platform")
|
||||
chat_info_group_id = group_info_from_chat.get("group_id")
|
||||
chat_info_group_name = group_info_from_chat.get("group_name")
|
||||
|
||||
# 创建消息对象
|
||||
return Messages(
|
||||
message_id=msg_id,
|
||||
time=msg_time,
|
||||
chat_id=chat_id,
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_stream_id,
|
||||
chat_info_platform=chat_info_platform,
|
||||
chat_info_user_platform=chat_info_user_platform,
|
||||
chat_info_user_id=chat_info_user_id,
|
||||
chat_info_user_nickname=chat_info_user_nickname,
|
||||
chat_info_user_cardname=chat_info_user_cardname,
|
||||
chat_info_group_platform=chat_info_group_platform,
|
||||
chat_info_group_id=chat_info_group_id,
|
||||
chat_info_group_name=chat_info_group_name,
|
||||
chat_info_create_time=chat_info_create_time,
|
||||
chat_info_last_active_time=chat_info_last_active_time,
|
||||
user_platform=user_platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info_json,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息对象失败: {e}")
|
||||
return None
|
||||
|
||||
async def _auto_flush_loop(self):
|
||||
"""自动刷新循环"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self.flush()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"自动刷新失败: {e}")
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_storage_batcher: Optional[MessageStorageBatcher] = None
|
||||
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
||||
|
||||
|
||||
def get_message_storage_batcher() -> MessageStorageBatcher:
|
||||
"""获取消息存储批处理器单例"""
|
||||
global _message_storage_batcher
|
||||
if _message_storage_batcher is None:
|
||||
_message_storage_batcher = MessageStorageBatcher(
|
||||
batch_size=50, # 批量大小:50条消息
|
||||
flush_interval=5.0 # 刷新间隔:5秒
|
||||
)
|
||||
return _message_storage_batcher
|
||||
|
||||
|
||||
class MessageUpdateBatcher:
|
||||
"""
|
||||
消息更新批处理器
|
||||
@@ -102,10 +406,6 @@ class MessageUpdateBatcher:
|
||||
logger.error(f"自动刷新出错: {e}")
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_update_batcher = None
|
||||
|
||||
|
||||
def get_message_update_batcher() -> MessageUpdateBatcher:
|
||||
"""获取全局消息更新批处理器"""
|
||||
global _message_update_batcher
|
||||
@@ -133,8 +433,25 @@ class MessageStorage:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
|
||||
"""
|
||||
存储消息到数据库
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
chat_stream: 聊天流对象
|
||||
use_batch: 是否使用批处理(默认True,推荐)。设为False时立即写入数据库。
|
||||
"""
|
||||
# 使用批处理器(推荐)
|
||||
if use_batch:
|
||||
batcher = get_message_storage_batcher()
|
||||
await batcher.add_message({
|
||||
'message': message,
|
||||
'chat_stream': chat_stream
|
||||
})
|
||||
return
|
||||
|
||||
# 直接写入模式(保留用于特殊场景)
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
@@ -367,7 +684,7 @@ class MessageStorage:
|
||||
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
|
||||
else:
|
||||
# 直接更新(保留原有逻辑用于特殊情况)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
from src.common.database.core import get_db_session
|
||||
|
||||
async with get_db_session() as session:
|
||||
matched_message = (
|
||||
@@ -510,7 +827,7 @@ class MessageStorage:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.core.models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
query = (
|
||||
|
||||
@@ -8,8 +8,8 @@ from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ActionRecords, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
@@ -990,7 +990,7 @@ async def build_readable_messages(
|
||||
# 从第一条消息中获取chat_id
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.compatibility import get_db_session
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
|
||||
@@ -3,8 +3,8 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.database.compatibility import db_get, db_query, db_save
|
||||
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -102,8 +102,9 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
)
|
||||
else:
|
||||
# 创建新记录
|
||||
new_record = await db_save(
|
||||
new_record = await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="create",
|
||||
data={
|
||||
"timestamp": str(current_time),
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
|
||||
@@ -12,7 +12,8 @@ from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session
|
||||
from src.common.database.core.models import ImageDescriptions, Images
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -25,7 +25,8 @@ from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
||||
from src.common.database.core.models import Videos
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
Reference in New Issue
Block a user