rufffffff

This commit is contained in:
明天好像没什么
2025-11-01 21:10:01 +08:00
parent 08a9a2c2e8
commit cb97b2d8d3
50 changed files with 742 additions and 759 deletions

View File

@@ -263,8 +263,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import delete
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id")
if not message_id:
@@ -291,8 +291,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import update
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id")
if not message_id:

View File

@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
from sqlalchemy import delete, select
from src.common.database.core.models import AntiInjectionStats
from src.common.database.core import get_db_session
from src.common.database.core.models import AntiInjectionStats
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -8,8 +8,8 @@ import datetime
from sqlalchemy import select
from src.common.database.core.models import BanUser
from src.common.database.core import get_db_session
from src.common.database.core.models import BanUser
from src.common.logger import get_logger
from ..types import DetectionResult

View File

@@ -15,9 +15,9 @@ from rich.traceback import install
from sqlalchemy import select
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Emoji, Images
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -215,7 +215,7 @@ class MaiEmoji:
else:
await crud.delete(will_delete_emoji.id)
result = 1 # Successfully deleted one record
# 使缓存失效
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
@@ -708,7 +708,7 @@ class EmojiManager:
try:
# 使用CRUD进行查询
crud = CRUDBase(Emoji)
if emoji_hash:
# 查询特定hash的表情包
emoji_record = await crud.get_by(emoji_hash=emoji_hash)

View File

@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("energy_system")
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# 从数据库获取聊天流兴趣分数
try:
from sqlalchemy import select
from src.common.database.core.models import ChatStreams

View File

@@ -236,12 +236,12 @@ class ExpressionLearner:
"""
获取指定chat_id的style和grammar表达方式带10分钟缓存
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
优化: 使用CRUD和缓存减少数据库访问
"""
# 使用静态方法以正确处理缓存键
return await self._get_expressions_by_chat_id_cached(self.chat_id)
@staticmethod
@cached(ttl=600, key_prefix="chat_expressions")
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
@@ -278,7 +278,7 @@ class ExpressionLearner:
async def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
优化: 使用CRUD批量处理所有更改最后统一提交
"""
try:
@@ -288,7 +288,7 @@ class ExpressionLearner:
updated_count = 0
deleted_count = 0
# 需要手动操作的情况下使用session
async with get_db_session() as session:
# 批量处理所有修改
@@ -391,7 +391,7 @@ class ExpressionLearner:
current_time = time.time()
# 存储到数据库 Expression 表
crud = CRUDBase(Expression)
CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
@@ -437,10 +437,10 @@ class ExpressionLearner:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# 提交后清除相关缓存
await session.commit()
# 清除该chat_id的表达方式缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key

View File

@@ -9,10 +9,8 @@ from json_repair import repair_json
from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -152,7 +150,7 @@ class ExpressionSelector:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 使用CRUD查询由于需要IN条件使用session
async with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
@@ -224,7 +222,7 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
affected_chat_ids.add(source_id)
for chat_id, expr_type, situation, style in updates_by_key:
async with get_db_session() as session:
query = await session.execute(
@@ -247,7 +245,7 @@ class ExpressionSelector:
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
# 清除所有受影响的chat_id的缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key

View File

@@ -728,7 +728,6 @@ class MemorySystem:
context = context or {}
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
resolved_user_id = GLOBAL_MEMORY_SCOPE
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()

View File

@@ -4,15 +4,14 @@ import time
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.api.crud import CRUDBase
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # 新增导入
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
@@ -708,7 +707,7 @@ class ChatManager:
# 使用CRUD批量查询
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
for model_instance in all_streams:
user_info_data = {
"platform": model_instance.user_platform,

View File

@@ -22,14 +22,14 @@ logger = get_logger("message_storage")
class MessageStorageBatcher:
"""
消息存储批处理器
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
"""
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
"""
初始化批处理器
Args:
batch_size: 批量大小,达到此数量立即写入
flush_interval: 自动刷新间隔(秒)
@@ -51,7 +51,7 @@ class MessageStorageBatcher:
async def stop(self):
"""停止批处理器"""
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
@@ -67,7 +67,7 @@ class MessageStorageBatcher:
async def add_message(self, message_data: dict):
"""
添加消息到批处理队列
Args:
message_data: 包含消息对象和chat_stream的字典
{
@@ -97,23 +97,23 @@ class MessageStorageBatcher:
start_time = time.time()
success_count = 0
try:
# 🔧 优化准备字典数据而不是ORM对象使用批量INSERT
messages_dicts = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data['message'],
msg_data['chat_stream']
msg_data["message"],
msg_data["chat_stream"]
)
if message_dict:
messages_dicts.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
continue
# 批量写入数据库 - 使用高效的批量INSERT
if messages_dicts:
from sqlalchemy import insert
@@ -122,7 +122,7 @@ class MessageStorageBatcher:
await session.execute(stmt)
await session.commit()
success_count = len(messages_dicts)
elapsed = time.time() - start_time
logger.info(
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
@@ -134,18 +134,18 @@ class MessageStorageBatcher:
async def _prepare_message_dict(self, message, chat_stream):
"""准备消息字典数据用于批量INSERT
这个方法准备字典而不是ORM对象性能更高
"""
message_obj = await self._prepare_message_object(message, chat_stream)
if message_obj is None:
return None
# 将ORM对象转换为字典只包含列字段
message_dict = {}
for column in Messages.__table__.columns:
message_dict[column.name] = getattr(message_obj, column.name)
return message_dict
async def _prepare_message_object(self, message, chat_stream):
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
is_public_notice = getattr(message, 'is_public_notice', False)
notice_type = getattr(message, 'notice_type', None)
actions = getattr(message, 'actions', None)
should_reply = getattr(message, 'should_reply', None)
should_act = getattr(message, 'should_act', None)
additional_config = getattr(message, 'additional_config', None)
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
actions = getattr(message, "actions", None)
should_reply = getattr(message, "should_reply", None)
should_act = getattr(message, "should_act", None)
additional_config = getattr(message, "additional_config", None)
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
# 全局批处理器实例
_message_storage_batcher: Optional[MessageStorageBatcher] = None
_message_storage_batcher: MessageStorageBatcher | None = None
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
@@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher:
class MessageUpdateBatcher:
"""
消息更新批处理器
优化: 将多个消息ID更新操作批量处理减少数据库连接次数
"""
@@ -478,7 +478,7 @@ class MessageStorage:
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
"""
存储消息到数据库
Args:
message: 消息对象
chat_stream: 聊天流对象
@@ -488,11 +488,11 @@ class MessageStorage:
if use_batch:
batcher = get_message_storage_batcher()
await batcher.add_message({
'message': message,
'chat_stream': chat_stream
"message": message,
"chat_stream": chat_stream
})
return
# 直接写入模式(保留用于特殊场景)
try:
# 过滤敏感信息的正则模式
@@ -675,9 +675,9 @@ class MessageStorage:
async def update_message(message_data: dict, use_batch: bool = True):
"""
更新消息ID从消息字典
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
Args:
message_data: 消息数据字典
use_batch: 是否使用批处理默认True

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any
from src.common.database.compatibility import db_get, db_query, db_save
from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask

View File

@@ -12,8 +12,8 @@ from PIL import Image
from rich.traceback import install
from sqlalchemy import and_, select
from src.common.database.core.models import ImageDescriptions, Images
from src.common.database.core import get_db_session
from src.common.database.core.models import ImageDescriptions, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

View File

@@ -25,8 +25,8 @@ from typing import Any
from PIL import Image
from src.common.database.core.models import Videos
from src.common.database.core import get_db_session
from src.common.database.core.models import Videos
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest