Merge pull request #56 from MoFox-Studio/feature/database-refactoring

重构数据库系统,优化数据库性能
This commit is contained in:
拾风
2025-11-01 17:38:18 +08:00
committed by GitHub
73 changed files with 8853 additions and 2612 deletions

View File

@@ -263,7 +263,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import delete
from src.common.database.sqlalchemy_models import Messages, get_db_session
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
message_id = message_data.get("message_id")
if not message_id:
@@ -290,7 +291,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import update
from src.common.database.sqlalchemy_models import Messages, get_db_session
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
message_id = message_data.get("message_id")
if not message_id:

View File

@@ -9,7 +9,8 @@ from typing import Any, TypeVar, cast
from sqlalchemy import delete, select
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.database.core.models import AntiInjectionStats
from src.common.database.core import get_db_session
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -8,7 +8,8 @@ import datetime
from sqlalchemy import select
from src.common.database.sqlalchemy_models import BanUser, get_db_session
from src.common.database.core.models import BanUser
from src.common.database.core import get_db_session
from src.common.logger import get_logger
from ..types import DetectionResult

View File

@@ -15,8 +15,10 @@ from rich.traceback import install
from sqlalchemy import select
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Emoji, Images
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -204,16 +206,23 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
async with get_db_session() as session:
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
will_delete_emoji = result.scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
else:
await session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record
await session.commit()
# 使用CRUD进行删除
crud = CRUDBase(Emoji)
will_delete_emoji = await crud.get_by(emoji_hash=self.hash)
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
else:
await crud.delete(will_delete_emoji.id)
result = 1 # Successfully deleted one record
# 使缓存失效
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("emoji_by_hash", self.hash))
await cache.delete(generate_cache_key("emoji_description", self.hash))
await cache.delete(generate_cache_key("emoji_tag", self.hash))
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {e!s}")
result = 0
@@ -697,23 +706,27 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表
"""
try:
async with get_db_session() as session:
if emoji_hash:
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
result = await session.execute(select(Emoji))
query = result.scalars().all()
# 使用CRUD进行查询
crud = CRUDBase(Emoji)
if emoji_hash:
# 查询特定hash的表情包
emoji_record = await crud.get_by(emoji_hash=emoji_hash)
emoji_instances = [emoji_record] if emoji_record else []
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
# 查询所有表情包
from src.common.database.api.query import QueryBuilder
query = QueryBuilder(Emoji)
emoji_instances = await query.all()
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
return emoji_objects
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
return emoji_objects
except Exception as e:
logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}")
@@ -734,8 +747,9 @@ class EmojiManager:
return emoji
return None # 如果循环结束还没找到,则返回 None
@cached(ttl=1800, key_prefix="emoji_tag") # 缓存30分钟
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None:
"""根据哈希值获取已注册表情包的描述
"""根据哈希值获取已注册表情包的描述带30分钟缓存
Args:
emoji_hash: 表情包的哈希值
@@ -765,8 +779,9 @@ class EmojiManager:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
return None
@cached(ttl=1800, key_prefix="emoji_description") # 缓存30分钟
async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None:
"""根据哈希值获取已注册表情包的描述
"""根据哈希值获取已注册表情包的描述带30分钟缓存
Args:
emoji_hash: 表情包的哈希值

View File

@@ -10,6 +10,8 @@ from enum import Enum
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.config.config import global_config
logger = get_logger("energy_system")
@@ -203,21 +205,19 @@ class RelationshipEnergyCalculator(EnergyCalculator):
try:
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.database.core.models import ChatStreams
async with get_db_session() as session:
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
result = await session.execute(stmt)
stream = result.scalar_one_or_none()
# 使用CRUD进行查询已有缓存
crud = CRUDBase(ChatStreams)
stream = await crud.get_by(stream_id=stream_id)
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
return interest_score
else:
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
return 0.3
if stream and stream.stream_interest_score is not None:
interest_score = float(stream.stream_interest_score)
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
return interest_score
else:
logger.debug(f"聊天流 {stream_id} 无兴趣分数,使用默认值")
return 0.3
except Exception as e:
logger.warning(f"获取聊天流兴趣度失败,使用默认值: {e}")

View File

@@ -10,8 +10,10 @@ from sqlalchemy import select
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -232,21 +234,26 @@ class ExpressionLearner:
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
获取指定chat_id的style和grammar表达方式带10分钟缓存
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
优化: 一次查询获取所有类型的表达方式,避免多次数据库查询
优化: 使用CRUD和缓存减少数据库访问
"""
# 使用静态方法以正确处理缓存键
return await self._get_expressions_by_chat_id_cached(self.chat_id)
@staticmethod
@cached(ttl=600, key_prefix="chat_expressions")
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""内部方法:从数据库获取表达方式(带缓存)"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 优化: 一次查询获取所有表达方式
async with get_db_session() as session:
all_expressions = await session.execute(
select(Expression).where(Expression.chat_id == self.chat_id)
)
# 使用CRUD查询
crud = CRUDBase(Expression)
all_expressions = await crud.get_multi(chat_id=chat_id, limit=10000)
for expr in all_expressions.scalars():
for expr in all_expressions:
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
@@ -255,7 +262,7 @@ class ExpressionLearner:
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"source_id": chat_id,
"type": expr.type,
"create_date": create_date,
}
@@ -272,18 +279,19 @@ class ExpressionLearner:
"""
对数据库中的所有表达方式应用全局衰减
优化: 批量处理所有更改,最后统一提交,避免逐条提交
优化: 使用CRUD批量处理所有更改,最后统一提交
"""
try:
# 使用CRUD查询所有表达方式
crud = CRUDBase(Expression)
all_expressions = await crud.get_multi(limit=100000) # 获取所有表达方式
updated_count = 0
deleted_count = 0
# 需要手动操作的情况下使用session
async with get_db_session() as session:
# 获取所有表达方式
all_expressions = await session.execute(select(Expression))
all_expressions = all_expressions.scalars().all()
updated_count = 0
deleted_count = 0
# 优化: 批量处理所有修改
# 批量处理所有修改
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
@@ -383,10 +391,12 @@ class ExpressionLearner:
current_time = time.time()
# 存储到数据库 Expression 表
crud = CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
# 注意: get_all_by 不支持复杂条件,这里仍需使用 session
query = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
@@ -416,7 +426,7 @@ class ExpressionLearner:
)
session.add(new_expression)
# 限制最大数量
# 限制最大数量 - 使用 get_all_by_sorted 获取排序结果
exprs_result = await session.execute(
select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == type))
@@ -427,6 +437,15 @@ class ExpressionLearner:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# 提交后清除相关缓存
await session.commit()
# 清除该chat_id的表达方式缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", chat_id))
# 🔥 训练 StyleLearner
# 只对 style 类型的表达方式进行训练grammar 不需要训练到模型)

View File

@@ -9,8 +9,10 @@ from json_repair import repair_json
from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -150,6 +152,8 @@ class ExpressionSelector:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 使用CRUD查询由于需要IN条件使用session
async with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
style_query = await session.execute(
@@ -207,6 +211,7 @@ class ExpressionSelector:
if not expressions_to_update:
return
updates_by_key = {}
affected_chat_ids = set()
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
@@ -218,6 +223,8 @@ class ExpressionSelector:
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
affected_chat_ids.add(source_id)
for chat_id, expr_type, situation, style in updates_by_key:
async with get_db_session() as session:
query = await session.execute(
@@ -240,6 +247,13 @@ class ExpressionSelector:
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
# 清除所有受影响的chat_id的缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
for chat_id in affected_chat_ids:
await cache.delete(generate_cache_key("chat_expressions", chat_id))
async def select_suitable_expressions(
self,

View File

@@ -649,8 +649,8 @@ class BotInterestManager:
# 导入SQLAlchemy相关模块
import orjson
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests
async with get_db_session() as session:
# 查询最新的兴趣标签配置
@@ -731,8 +731,8 @@ class BotInterestManager:
# 导入SQLAlchemy相关模块
import orjson
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests
# 将兴趣标签转换为JSON格式
tags_data = []

View File

@@ -9,8 +9,8 @@ from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -9,8 +9,10 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # 新增导入
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
@@ -441,16 +443,20 @@ class ChatManager:
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 检查数据库中是否存在
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (
(await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)))
.scalars()
.first()
)
model_instance = await _db_find_stream_async(stream_id)
# 使用优化后的API查询带缓存
model_instance, _ = await get_or_create_chat_stream(
stream_id=stream_id,
platform=platform,
defaults={
"user_platform": user_info.platform if user_info else platform,
"user_id": user_info.user_id if user_info else "",
"user_nickname": user_info.user_nickname if user_info else "",
"user_cardname": user_info.user_cardname if user_info else "",
"group_platform": group_info.platform if group_info else None,
"group_id": group_info.group_id if group_info else None,
"group_name": group_info.group_name if group_info else None,
}
)
if model_instance:
# 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
@@ -696,9 +702,11 @@ class ChatManager:
async def _db_load_all_streams_async():
loaded_streams_data = []
async with get_db_session() as session:
result = await session.execute(select(ChatStreams))
for model_instance in result.scalars().all():
# 使用CRUD批量查询
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
for model_instance in all_streams:
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -734,7 +742,6 @@ class ChatManager:
"interruption_count": getattr(model_instance, "interruption_count", 0),
}
loaded_streams_data.append(data_for_from_dict)
await session.commit()
return loaded_streams_data
try:

View File

@@ -3,13 +3,14 @@ import re
import time
import traceback
from collections import deque
from typing import Optional
import orjson
from sqlalchemy import desc, select, update
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.database.core import get_db_session
from src.common.database.core.models import Images, Messages
from src.common.logger import get_logger
from .chat_stream import ChatStream
@@ -18,6 +19,309 @@ from .message import MessageSending
logger = get_logger("message_storage")
class MessageStorageBatcher:
    """
    Buffer outgoing/incoming messages and write them to the database in batches.

    Messages are queued by ``add_message`` and persisted by ``flush``. A flush
    happens when the queue reaches ``batch_size`` and on every tick of the
    background task created by ``start``. Batching is meant to reduce pressure
    on the database connection pool.
    """

    def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
        """
        Initialise the batcher.

        Args:
            batch_size: Queue length at which a flush is triggered immediately.
            flush_interval: Seconds between automatic flushes.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.pending_messages: deque = deque()
        # Guards pending_messages. asyncio.Lock is NOT reentrant, so no method
        # may await another lock-taking method while holding it.
        self._lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self):
        """Start the periodic auto-flush task; a no-op when already started."""
        if self._flush_task is None and not self._running:
            self._running = True
            self._flush_task = asyncio.create_task(self._auto_flush_loop())
            logger.info(f"消息存储批处理器已启动 (批量大小: {self.batch_size}, 刷新间隔: {self.flush_interval}秒)")

    async def stop(self):
        """Stop the auto-flush task, then flush whatever is still queued."""
        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        # Persist the messages that were still waiting in the queue.
        await self.flush()
        logger.info("消息存储批处理器已停止")

    async def add_message(self, message_data: dict):
        """
        Queue one message for batched storage.

        Args:
            message_data: Dict holding the message and its chat stream::

                {
                    'message': DatabaseMessages | MessageSending,
                    'chat_stream': ChatStream
                }
        """
        async with self._lock:
            self.pending_messages.append(message_data)
            batch_is_full = len(self.pending_messages) >= self.batch_size

        # flush() acquires self._lock itself; calling it while still holding
        # the (non-reentrant) lock would block this coroutine forever, so the
        # decision is made under the lock and the flush runs after release.
        if batch_is_full:
            logger.debug(f"达到批量大小 {self.batch_size},立即刷新")
            await self.flush()

    async def flush(self):
        """
        Write every queued message to the database in a single session.

        The queue is drained under the lock and the write happens outside it,
        so new messages can be queued while the batch is being stored. If
        preparing or committing the batch raises, the error is logged and the
        drained messages are not re-queued.
        """
        async with self._lock:
            if not self.pending_messages:
                return

            messages_to_store = list(self.pending_messages)
            self.pending_messages.clear()

        start_time = time.time()
        success_count = 0

        try:
            # Convert every queued entry into an ORM object; entries that
            # cannot be converted are skipped instead of failing the batch.
            messages_objects = []
            for msg_data in messages_to_store:
                try:
                    message_obj = await self._prepare_message_object(
                        msg_data['message'],
                        msg_data['chat_stream']
                    )
                    if message_obj:
                        messages_objects.append(message_obj)
                except Exception as e:
                    logger.error(f"准备消息对象失败: {e}")
                    continue

            # One session / one commit for the whole batch.
            if messages_objects:
                async with get_db_session() as session:
                    session.add_all(messages_objects)
                    await session.commit()
                    success_count = len(messages_objects)

            elapsed = time.time() - start_time
            logger.info(
                f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
                f"(耗时: {elapsed:.3f}秒)"
            )

        except Exception as e:
            logger.error(f"批量存储消息失败: {e}", exc_info=True)

    async def _prepare_message_object(self, message, chat_stream):
        """
        Build a ``Messages`` ORM object from a message (logic taken from store_message).

        Args:
            message: A ``DatabaseMessages`` instance, a ``MessageSending``
                instance, or another message object exposing ``message_info``.
            chat_stream: Chat stream the message belongs to; only read on the
                non-``DatabaseMessages`` path.

        Returns:
            The populated ``Messages`` instance, or ``None`` when any error
            occurs while building it (the error is logged).
        """
        try:
            # Regex for the internal prompt sections that must not be stored.
            pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"

            # DatabaseMessages already carries flat fields: use them directly.
            if isinstance(message, DatabaseMessages):
                processed_plain_text = message.processed_plain_text
                if processed_plain_text:
                    processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
                    safe_processed_plain_text = processed_plain_text or ""
                    filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
                else:
                    filtered_processed_plain_text = ""

                display_message = message.display_message or message.processed_plain_text or ""
                filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)

                msg_id = message.message_id
                msg_time = message.time
                chat_id = message.chat_id
                reply_to = ""
                is_mentioned = message.is_mentioned
                interest_value = message.interest_value or 0.0
                priority_mode = ""
                priority_info_json = None
                is_emoji = message.is_emoji or False
                is_picid = message.is_picid or False
                is_notify = message.is_notify or False
                is_command = message.is_command or False
                key_words = ""
                key_words_lite = ""
                memorized_times = 0

                user_platform = message.user_info.platform if message.user_info else ""
                user_id = message.user_info.user_id if message.user_info else ""
                user_nickname = message.user_info.user_nickname if message.user_info else ""
                user_cardname = message.user_info.user_cardname if message.user_info else None

                chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
                chat_info_platform = message.chat_info.platform if message.chat_info else ""
                chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
                chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
                chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
                chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
                chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
                chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
                chat_info_group_platform = message.group_info.group_platform if message.group_info else None
                chat_info_group_id = message.group_info.group_id if message.group_info else None
                chat_info_group_name = message.group_info.group_name if message.group_info else None

            else:
                # MessageSending (and other message_info-based objects).
                processed_plain_text = message.processed_plain_text

                if processed_plain_text:
                    processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
                    safe_processed_plain_text = processed_plain_text or ""
                    filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
                else:
                    filtered_processed_plain_text = ""

                if isinstance(message, MessageSending):
                    display_message = message.display_message
                    if display_message:
                        filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
                    else:
                        # No display text: fall back to the processed text.
                        filtered_display_message = re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
                    interest_value = 0
                    is_mentioned = False
                    reply_to = message.reply_to
                    priority_mode = ""
                    priority_info = {}
                    is_emoji = False
                    is_picid = False
                    is_notify = False
                    is_command = False
                    key_words = ""
                    key_words_lite = ""
                else:
                    filtered_display_message = ""
                    interest_value = message.interest_value
                    is_mentioned = message.is_mentioned
                    reply_to = ""
                    priority_mode = message.priority_mode
                    priority_info = message.priority_info
                    is_emoji = message.is_emoji
                    is_picid = message.is_picid
                    is_notify = message.is_notify
                    is_command = message.is_command
                    key_words = MessageStorage._serialize_keywords(message.key_words)
                    key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)

                chat_info_dict = chat_stream.to_dict()
                user_info_dict = message.message_info.user_info.to_dict()

                msg_id = message.message_info.message_id
                # Fall back to "now" when the message carries no timestamp.
                msg_time = float(message.message_info.time or time.time())
                chat_id = chat_stream.stream_id
                memorized_times = message.memorized_times

                group_info_from_chat = chat_info_dict.get("group_info") or {}
                user_info_from_chat = chat_info_dict.get("user_info") or {}

                priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None

                user_platform = user_info_dict.get("platform")
                user_id = user_info_dict.get("user_id")
                user_nickname = user_info_dict.get("user_nickname")
                user_cardname = user_info_dict.get("user_cardname")

                chat_info_stream_id = chat_info_dict.get("stream_id")
                chat_info_platform = chat_info_dict.get("platform")
                chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
                chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
                chat_info_user_platform = user_info_from_chat.get("platform")
                chat_info_user_id = user_info_from_chat.get("user_id")
                chat_info_user_nickname = user_info_from_chat.get("user_nickname")
                chat_info_user_cardname = user_info_from_chat.get("user_cardname")
                chat_info_group_platform = group_info_from_chat.get("platform")
                chat_info_group_id = group_info_from_chat.get("group_id")
                chat_info_group_name = group_info_from_chat.get("group_name")

            # Build the ORM row from the values gathered above.
            return Messages(
                message_id=msg_id,
                time=msg_time,
                chat_id=chat_id,
                reply_to=reply_to,
                is_mentioned=is_mentioned,
                chat_info_stream_id=chat_info_stream_id,
                chat_info_platform=chat_info_platform,
                chat_info_user_platform=chat_info_user_platform,
                chat_info_user_id=chat_info_user_id,
                chat_info_user_nickname=chat_info_user_nickname,
                chat_info_user_cardname=chat_info_user_cardname,
                chat_info_group_platform=chat_info_group_platform,
                chat_info_group_id=chat_info_group_id,
                chat_info_group_name=chat_info_group_name,
                chat_info_create_time=chat_info_create_time,
                chat_info_last_active_time=chat_info_last_active_time,
                user_platform=user_platform,
                user_id=user_id,
                user_nickname=user_nickname,
                user_cardname=user_cardname,
                processed_plain_text=filtered_processed_plain_text,
                display_message=filtered_display_message,
                memorized_times=memorized_times,
                interest_value=interest_value,
                priority_mode=priority_mode,
                priority_info=priority_info_json,
                is_emoji=is_emoji,
                is_picid=is_picid,
                is_notify=is_notify,
                is_command=is_command,
                key_words=key_words,
                key_words_lite=key_words_lite,
            )

        except Exception as e:
            logger.error(f"准备消息对象失败: {e}")
            return None

    async def _auto_flush_loop(self):
        """Flush every ``flush_interval`` seconds until stopped or cancelled."""
        while self._running:
            try:
                await asyncio.sleep(self.flush_interval)
                await self.flush()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive: one failed flush must not end auto-flushing.
                logger.error(f"自动刷新失败: {e}")
# 全局批处理器实例
_message_storage_batcher: Optional[MessageStorageBatcher] = None
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
def get_message_storage_batcher() -> MessageStorageBatcher:
    """Return the shared MessageStorageBatcher, creating it on first call."""
    global _message_storage_batcher
    if _message_storage_batcher is not None:
        return _message_storage_batcher

    # First use: build the singleton with a batch size of 50 messages and a
    # flush interval of 5 seconds.
    _message_storage_batcher = MessageStorageBatcher(batch_size=50, flush_interval=5.0)
    return _message_storage_batcher
class MessageUpdateBatcher:
"""
消息更新批处理器
@@ -102,10 +406,6 @@ class MessageUpdateBatcher:
logger.error(f"自动刷新出错: {e}")
# 全局批处理器实例
_message_update_batcher = None
def get_message_update_batcher() -> MessageUpdateBatcher:
"""获取全局消息更新批处理器"""
global _message_update_batcher
@@ -133,8 +433,25 @@ class MessageStorage:
return []
@staticmethod
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
"""
存储消息到数据库
Args:
message: 消息对象
chat_stream: 聊天流对象
use_batch: 是否使用批处理默认True推荐。设为False时立即写入数据库。
"""
# 使用批处理器(推荐)
if use_batch:
batcher = get_message_storage_batcher()
await batcher.add_message({
'message': message,
'chat_stream': chat_stream
})
return
# 直接写入模式(保留用于特殊场景)
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
@@ -367,7 +684,7 @@ class MessageStorage:
logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}")
else:
# 直接更新(保留原有逻辑用于特殊情况)
from src.common.database.sqlalchemy_models import get_db_session
from src.common.database.core import get_db_session
async with get_db_session() as session:
matched_message = (
@@ -510,7 +827,7 @@ class MessageStorage:
async with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
from src.common.database.core.models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值
query = (

View File

@@ -8,8 +8,8 @@ from rich.traceback import install
from sqlalchemy import and_, select
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ActionRecords, Images
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config
@@ -990,7 +990,7 @@ async def build_readable_messages(
# 从第一条消息中获取chat_id
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.compatibility import get_db_session
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id

View File

@@ -3,8 +3,8 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
from src.common.database.compatibility import db_get, db_query, db_save
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
@@ -102,8 +102,9 @@ class OnlineTimeRecordTask(AsyncTask):
)
else:
# 创建新记录
new_record = await db_save(
new_record = await db_query(
model_class=OnlineTime,
query_type="create",
data={
"timestamp": str(current_time),
"duration": 5, # 初始时长为5分钟

View File

@@ -12,7 +12,8 @@ from PIL import Image
from rich.traceback import install
from sqlalchemy import and_, select
from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session
from src.common.database.core.models import ImageDescriptions, Images
from src.common.database.core import get_db_session
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

View File

@@ -25,7 +25,8 @@ from typing import Any
from PIL import Image
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
from src.common.database.core.models import Videos
from src.common.database.core import get_db_session
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest