rufffffff
This commit is contained in:
@@ -263,8 +263,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
@@ -291,8 +291,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
|
||||
@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from src.common.database.core.models import AntiInjectionStats
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import AntiInjectionStats
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.core.models import BanUser
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import BanUser
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
@@ -15,9 +15,9 @@ from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Emoji, Images
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -215,7 +215,7 @@ class MaiEmoji:
|
||||
else:
|
||||
await crud.delete(will_delete_emoji.id)
|
||||
result = 1 # Successfully deleted one record
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
try:
|
||||
# 使用CRUD进行查询
|
||||
crud = CRUDBase(Emoji)
|
||||
|
||||
|
||||
if emoji_hash:
|
||||
# 查询特定hash的表情包
|
||||
emoji_record = await crud.get_by(emoji_hash=emoji_hash)
|
||||
|
||||
@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("energy_system")
|
||||
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
|
||||
# 从数据库获取聊天流兴趣分数
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.core.models import ChatStreams
|
||||
|
||||
|
||||
@@ -236,12 +236,12 @@ class ExpressionLearner:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式(带10分钟缓存)
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
|
||||
|
||||
优化: 使用CRUD和缓存,减少数据库访问
|
||||
"""
|
||||
# 使用静态方法以正确处理缓存键
|
||||
return await self._get_expressions_by_chat_id_cached(self.chat_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@cached(ttl=600, key_prefix="chat_expressions")
|
||||
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
@@ -278,7 +278,7 @@ class ExpressionLearner:
|
||||
async def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
|
||||
|
||||
优化: 使用CRUD批量处理所有更改,最后统一提交
|
||||
"""
|
||||
try:
|
||||
@@ -288,7 +288,7 @@ class ExpressionLearner:
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
|
||||
# 需要手动操作的情况下使用session
|
||||
async with get_db_session() as session:
|
||||
# 批量处理所有修改
|
||||
@@ -391,7 +391,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
crud = CRUDBase(Expression)
|
||||
CRUDBase(Expression)
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
async with get_db_session() as session:
|
||||
for new_expr in expr_list:
|
||||
@@ -437,10 +437,10 @@ class ExpressionLearner:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
await session.delete(expr)
|
||||
|
||||
|
||||
# 提交后清除相关缓存
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 清除该chat_id的表达方式缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
|
||||
@@ -9,10 +9,8 @@ from json_repair import repair_json
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Expression
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -152,7 +150,7 @@ class ExpressionSelector:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
|
||||
# 使用CRUD查询(由于需要IN条件,使用session)
|
||||
async with get_db_session() as session:
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
@@ -224,7 +222,7 @@ class ExpressionSelector:
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
affected_chat_ids.add(source_id)
|
||||
|
||||
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
@@ -247,7 +245,7 @@ class ExpressionSelector:
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 清除所有受影响的chat_id的缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
|
||||
@@ -728,7 +728,6 @@ class MemorySystem:
|
||||
context = context or {}
|
||||
|
||||
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
|
||||
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
||||
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
start_time = time.time()
|
||||
|
||||
@@ -4,15 +4,14 @@ import time
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams # 新增导入
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
@@ -708,7 +707,7 @@ class ChatManager:
|
||||
# 使用CRUD批量查询
|
||||
crud = CRUDBase(ChatStreams)
|
||||
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
|
||||
|
||||
|
||||
for model_instance in all_streams:
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
|
||||
@@ -22,14 +22,14 @@ logger = get_logger("message_storage")
|
||||
class MessageStorageBatcher:
|
||||
"""
|
||||
消息存储批处理器
|
||||
|
||||
|
||||
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
|
||||
"""
|
||||
初始化批处理器
|
||||
|
||||
|
||||
Args:
|
||||
batch_size: 批量大小,达到此数量立即写入
|
||||
flush_interval: 自动刷新间隔(秒)
|
||||
@@ -51,7 +51,7 @@ class MessageStorageBatcher:
|
||||
async def stop(self):
|
||||
"""停止批处理器"""
|
||||
self._running = False
|
||||
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
@@ -67,7 +67,7 @@ class MessageStorageBatcher:
|
||||
async def add_message(self, message_data: dict):
|
||||
"""
|
||||
添加消息到批处理队列
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 包含消息对象和chat_stream的字典
|
||||
{
|
||||
@@ -97,23 +97,23 @@ class MessageStorageBatcher:
|
||||
|
||||
start_time = time.time()
|
||||
success_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
|
||||
messages_dicts = []
|
||||
|
||||
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_dict = await self._prepare_message_dict(
|
||||
msg_data['message'],
|
||||
msg_data['chat_stream']
|
||||
msg_data["message"],
|
||||
msg_data["chat_stream"]
|
||||
)
|
||||
if message_dict:
|
||||
messages_dicts.append(message_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息数据失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 批量写入数据库 - 使用高效的批量INSERT
|
||||
if messages_dicts:
|
||||
from sqlalchemy import insert
|
||||
@@ -122,7 +122,7 @@ class MessageStorageBatcher:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
success_count = len(messages_dicts)
|
||||
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
||||
@@ -134,18 +134,18 @@ class MessageStorageBatcher:
|
||||
|
||||
async def _prepare_message_dict(self, message, chat_stream):
|
||||
"""准备消息字典数据(用于批量INSERT)
|
||||
|
||||
|
||||
这个方法准备字典而不是ORM对象,性能更高
|
||||
"""
|
||||
message_obj = await self._prepare_message_object(message, chat_stream)
|
||||
if message_obj is None:
|
||||
return None
|
||||
|
||||
|
||||
# 将ORM对象转换为字典(只包含列字段)
|
||||
message_dict = {}
|
||||
for column in Messages.__table__.columns:
|
||||
message_dict[column.name] = getattr(message_obj, column.name)
|
||||
|
||||
|
||||
return message_dict
|
||||
|
||||
async def _prepare_message_object(self, message, chat_stream):
|
||||
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
is_public_notice = getattr(message, 'is_public_notice', False)
|
||||
notice_type = getattr(message, 'notice_type', None)
|
||||
actions = getattr(message, 'actions', None)
|
||||
should_reply = getattr(message, 'should_reply', None)
|
||||
should_act = getattr(message, 'should_act', None)
|
||||
additional_config = getattr(message, 'additional_config', None)
|
||||
is_public_notice = getattr(message, "is_public_notice", False)
|
||||
notice_type = getattr(message, "notice_type", None)
|
||||
actions = getattr(message, "actions", None)
|
||||
should_reply = getattr(message, "should_reply", None)
|
||||
should_act = getattr(message, "should_act", None)
|
||||
additional_config = getattr(message, "additional_config", None)
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
|
||||
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_storage_batcher: Optional[MessageStorageBatcher] = None
|
||||
_message_storage_batcher: MessageStorageBatcher | None = None
|
||||
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher:
|
||||
class MessageUpdateBatcher:
|
||||
"""
|
||||
消息更新批处理器
|
||||
|
||||
|
||||
优化: 将多个消息ID更新操作批量处理,减少数据库连接次数
|
||||
"""
|
||||
|
||||
@@ -478,7 +478,7 @@ class MessageStorage:
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
|
||||
"""
|
||||
存储消息到数据库
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
chat_stream: 聊天流对象
|
||||
@@ -488,11 +488,11 @@ class MessageStorage:
|
||||
if use_batch:
|
||||
batcher = get_message_storage_batcher()
|
||||
await batcher.add_message({
|
||||
'message': message,
|
||||
'chat_stream': chat_stream
|
||||
"message": message,
|
||||
"chat_stream": chat_stream
|
||||
})
|
||||
return
|
||||
|
||||
|
||||
# 直接写入模式(保留用于特殊场景)
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -675,9 +675,9 @@ class MessageStorage:
|
||||
async def update_message(message_data: dict, use_batch: bool = True):
|
||||
"""
|
||||
更新消息ID(从消息字典)
|
||||
|
||||
|
||||
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 消息数据字典
|
||||
use_batch: 是否使用批处理(默认True)
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.compatibility import db_get, db_query, db_save
|
||||
from src.common.database.compatibility import db_get, db_query
|
||||
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
@@ -12,8 +12,8 @@ from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.common.database.core.models import ImageDescriptions, Images
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import ImageDescriptions, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -25,8 +25,8 @@ from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from src.common.database.core.models import Videos
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Videos
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
Reference in New Issue
Block a user