rufffffff

明天好像没什么
2025-11-01 21:10:01 +08:00
parent 08a9a2c2e8
commit cb97b2d8d3
50 changed files with 742 additions and 759 deletions

View File

@@ -16,7 +16,7 @@ models_file = os.path.join(
print(f"正在清理文件: {models_file}")
# 读取文件
with open(models_file, "r", encoding="utf-8") as f:
with open(models_file, encoding="utf-8") as f:
lines = f.readlines()
# Find where the last model class ends (the end of MonthlyPlan's __table_args__)
@@ -26,7 +26,7 @@ found_end = False
for i, line in enumerate(lines, 1):
keep_lines.append(line)
# Check whether we have reached the end of MonthlyPlan's __table_args__
if i > 580 and line.strip() == ")":
# Then check whether the previous line contains Index-related content
@@ -43,7 +43,7 @@ if not found_end:
with open(models_file, "w", encoding="utf-8") as f:
f.writelines(keep_lines)
print(f"✅ 文件清理完成")
print("✅ 文件清理完成")
print(f"保留行数: {len(keep_lines)}")
print(f"原始行数: {len(lines)}")
print(f"删除行数: {len(lines) - len(keep_lines)}")

View File

@@ -4,20 +4,20 @@
import re
# Read the original file
with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f:
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
content = f.read()
# Find the start and end of the get_string_field function
get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数')
get_string_field_end = content.find('\n\nclass ChatStreams(Base):')
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end]
# Find where the first class definition starts
first_class_pos = content.find('class ChatStreams(Base):')
first_class_pos = content.find("class ChatStreams(Base):")
# Find all class definitions, stopping at the first def that is not part of a class
# Simple strategy: find every class that starts with "class " and inherits from Base
classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)'
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
if matches:
@@ -53,14 +53,14 @@ Base = declarative_base()
'''
new_content = header + get_string_field + '\n\n' + models_content
new_content = header + get_string_field + "\n\n" + models_content
# Write the new file
with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f:
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
f.write(new_content)
print('✅ Models file rewritten successfully')
print(f'File size: {len(new_content)} characters')
print("✅ Models file rewritten successfully")
print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f'Number of model classes: {model_count}')
print(f"Number of model classes: {model_count}")

View File

@@ -8,54 +8,53 @@
import re
from pathlib import Path
from typing import Dict, List, Tuple
# Define the import mapping rules
IMPORT_MAPPINGS = {
# Model imports
r'from src\.common\.database\.sqlalchemy_models import (.+)':
r'from src.common.database.core.models import \1',
r"from src\.common\.database\.sqlalchemy_models import (.+)":
r"from src.common.database.core.models import \1",
# API imports - need special handling
r'from src\.common\.database\.sqlalchemy_database_api import (.+)':
r'from src.common.database.compatibility import \1',
r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
r"from src.common.database.compatibility import \1",
# get_db_session imported from sqlalchemy_database_api
r'from src\.common\.database\.sqlalchemy_database_api import get_db_session':
r'from src.common.database.core import get_db_session',
r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
r"from src.common.database.core import get_db_session",
# get_db_session imported from sqlalchemy_models
r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)':
lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}'
if 'get_db_session' in m.group(0) else m.group(0),
r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
if "get_db_session" in m.group(0) else m.group(0),
# get_engine imports
r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)':
lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}',
r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
# Base imports
r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)':
lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}',
r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
# initialize_database imports
r'from src\.common\.database\.sqlalchemy_models import initialize_database':
r'from src.common.database.core import check_and_migrate_database as initialize_database',
r"from src\.common\.database\.sqlalchemy_models import initialize_database":
r"from src.common.database.core import check_and_migrate_database as initialize_database",
# database.py imports
r'from src\.common\.database\.database import stop_database':
r'from src.common.database.core import close_engine as stop_database',
r'from src\.common\.database\.database import initialize_sql_database':
r'from src.common.database.core import check_and_migrate_database as initialize_sql_database',
r"from src\.common\.database\.database import stop_database":
r"from src.common.database.core import close_engine as stop_database",
r"from src\.common\.database\.database import initialize_sql_database":
r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
}
# Files to exclude
EXCLUDE_PATTERNS = [
'**/database_refactoring_plan.md', # documentation file
'**/old/**', # directory of old files
'**/sqlalchemy_*.py', # the old database files themselves
'**/database.py', # the old database module
'**/db_*.py', # old db files
"**/database_refactoring_plan.md", # documentation file
"**/old/**", # directory of old files
"**/sqlalchemy_*.py", # the old database files themselves
"**/database.py", # the old database module
"**/db_*.py", # old db files
]
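The body of `should_exclude` is elided from this hunk; a hypothetical sketch of such a filter, assuming simple `fnmatch`-style glob matching (the real implementation may differ):

```python
from fnmatch import fnmatch
from pathlib import Path

EXCLUDE = ["**/old/**", "**/sqlalchemy_*.py"]  # abbreviated copy of the list above

def should_exclude(file_path: Path) -> bool:
    # fnmatch's "*" crosses path separators, which is good enough for a sketch.
    path_str = file_path.as_posix()
    return any(fnmatch(path_str, pattern) for pattern in EXCLUDE)

print(should_exclude(Path("src/common/database/old/foo.py")))  # True
```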
@@ -67,47 +66,47 @@ def should_exclude(file_path: Path) -> bool:
return False
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]:
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
"""更新单个文件中的导入语句
Args:
file_path: 文件路径
dry_run: 是否只是预览而不实际修改
Returns:
(修改次数, 修改详情列表)
"""
try:
content = file_path.read_text(encoding='utf-8')
content = file_path.read_text(encoding="utf-8")
original_content = content
changes = []
# Apply each mapping rule
for pattern, replacement in IMPORT_MAPPINGS.items():
matches = list(re.finditer(pattern, content))
for match in matches:
old_line = match.group(0)
# Handle callable replacements
if callable(replacement):
new_line_result = replacement(match)
new_line = new_line_result if isinstance(new_line_result, str) else old_line
else:
new_line = re.sub(pattern, replacement, old_line)
if old_line != new_line and isinstance(new_line, str):
content = content.replace(old_line, new_line, 1)
changes.append(f" - {old_line}")
changes.append(f" + {new_line}")
# If anything changed and this is not a dry run, write back to the file
if content != original_content:
if not dry_run:
file_path.write_text(content, encoding='utf-8')
file_path.write_text(content, encoding="utf-8")
return len(changes) // 2, changes
return 0, []
except Exception as e:
print(f"❌ 处理文件 {file_path} 时出错: {e}")
return 0, []
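The loop above dispatches on the replacement's type: plain strings go through `re.sub`, callables receive the `Match` object. A self-contained demo of the same shape, with hypothetical module names:

```python
import re
from collections.abc import Callable

# Hypothetical mapping in the same shape as IMPORT_MAPPINGS: a replacement is
# either a plain string (used with re.sub) or a callable taking the Match.
MAPPINGS: dict[str, str | Callable] = {
    r"from old\.models import (.+)": r"from new.models import \1",
    r"from old\.api import (.*)get_db_session(.*)":
        lambda m: f"from new.core import {m.group(1)}get_db_session{m.group(2)}",
}

line = "from old.api import foo, get_db_session"
for pattern, replacement in MAPPINGS.items():
    if match := re.search(pattern, line):
        line = replacement(match) if callable(replacement) else re.sub(pattern, replacement, line)
print(line)  # from new.core import foo, get_db_session
```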
@@ -116,34 +115,34 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
def main():
"""主函数"""
print("🔍 搜索需要更新导入的文件...")
# 获取项目根目录
root_dir = Path(__file__).parent.parent
# Collect all Python files
all_python_files = list(root_dir.rglob("*.py"))
# Filter out the excluded files
target_files = [f for f in all_python_files if not should_exclude(f)]
print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
print("\n" + "="*80)
# 第一遍:预览模式
print("\n🔍 预览模式 - 检查需要更新的文件...\n")
files_to_update = []
for file_path in target_files:
count, changes = update_imports_in_file(file_path, dry_run=True)
if count > 0:
files_to_update.append((file_path, count, changes))
if not files_to_update:
print("✅ 没有文件需要更新!")
return
print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")
total_changes = 0
for file_path, count, changes in files_to_update:
rel_path = file_path.relative_to(root_dir)
@@ -153,23 +152,23 @@ def main():
if len(changes) > 10:
print(f" ... 还有 {len(changes) - 10}")
total_changes += count
print("\n" + "="*80)
print(f"\n📊 统计:")
print("\n📊 统计:")
print(f" - 需要更新的文件: {len(files_to_update)}")
print(f" - 总修改次数: {total_changes}")
# 询问是否继续
print("\n" + "="*80)
response = input("\nApply the updates? (yes/no): ").strip().lower()
if response != 'yes':
if response != "yes":
print("❌ Update cancelled")
return
# Second pass: apply the updates
print("\n✨ Updating files...\n")
success_count = 0
for file_path, _, _ in files_to_update:
count, _ = update_imports_in_file(file_path, dry_run=False)
@@ -177,7 +176,7 @@ def main():
rel_path = file_path.relative_to(root_dir)
print(f"{rel_path} ({count} 处修改)")
success_count += 1
print("\n" + "="*80)
print(f"\n🎉 完成!成功更新 {success_count} 个文件")

View File

@@ -263,8 +263,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import delete
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id")
if not message_id:
@@ -291,8 +291,8 @@ class AntiPromptInjector:
try:
from sqlalchemy import update
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id")
if not message_id:
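The swapped import pairs here (and in the next several files) are plain isort-style alphabetical reordering: `src.common.database.core` sorts before `src.common.database.core.models`. A one-liner confirms the ordering:

```python
mods = ["src.common.database.core.models", "src.common.database.core"]
print(sorted(mods))
# ['src.common.database.core', 'src.common.database.core.models']
```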

View File

@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
from sqlalchemy import delete, select
from src.common.database.core.models import AntiInjectionStats
from src.common.database.core import get_db_session
from src.common.database.core.models import AntiInjectionStats
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -8,8 +8,8 @@ import datetime
from sqlalchemy import select
from src.common.database.core.models import BanUser
from src.common.database.core import get_db_session
from src.common.database.core.models import BanUser
from src.common.logger import get_logger
from ..types import DetectionResult

View File

@@ -15,9 +15,9 @@ from rich.traceback import install
from sqlalchemy import select
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Emoji, Images
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -215,7 +215,7 @@ class MaiEmoji:
else:
await crud.delete(will_delete_emoji.id)
result = 1 # Successfully deleted one record
# Invalidate the cache
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
@@ -708,7 +708,7 @@ class EmojiManager:
try:
# Query via CRUD
crud = CRUDBase(Emoji)
if emoji_hash:
# Look up the emoji with this specific hash
emoji_record = await crud.get_by(emoji_hash=emoji_hash)

View File

@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("energy_system")
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# Fetch chat stream interest scores from the database
try:
from sqlalchemy import select
from src.common.database.core.models import ChatStreams

View File

@@ -236,12 +236,12 @@ class ExpressionLearner:
"""
Get the style and grammar expressions for the given chat_id (with a 10-minute cache)
Every returned expression dict includes source_id, used for subsequent update operations
Optimization: use CRUD plus caching to reduce database access
"""
# Use a static method so the cache key is built correctly
return await self._get_expressions_by_chat_id_cached(self.chat_id)
@staticmethod
@cached(ttl=600, key_prefix="chat_expressions")
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
@@ -278,7 +278,7 @@ class ExpressionLearner:
async def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
Apply global decay to every expression in the database
Optimization: batch all changes through CRUD and commit once at the end
"""
try:
@@ -288,7 +288,7 @@ class ExpressionLearner:
updated_count = 0
deleted_count = 0
# Use a session when manual operations are required
async with get_db_session() as session:
# Process all modifications in one batch
@@ -391,7 +391,7 @@ class ExpressionLearner:
current_time = time.time()
# Store into the Expression table in the database
crud = CRUDBase(Expression)
CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session:
for new_expr in expr_list:
@@ -437,10 +437,10 @@ class ExpressionLearner:
# Delete the surplus expressions with the smallest count
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
await session.delete(expr)
# Clear the related caches after committing
await session.commit()
# Clear the expression cache for this chat_id
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
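Throughout this commit, a `session.commit()` is followed by an explicit cache purge. A hedged sketch of that invalidation step, built from the project-internal helpers imported above (their signatures are inferred from call sites elsewhere in this diff):

```python
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key

async def invalidate_chat_expressions(chat_id: str) -> None:
    # Key prefix matches the @cached(key_prefix="chat_expressions") decorator above.
    cache = await get_cache()
    cache_key = generate_cache_key("chat_expressions", chat_id)
    await cache.delete(cache_key)
```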

View File

@@ -9,10 +9,8 @@ from json_repair import repair_json
from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@@ -152,7 +150,7 @@ class ExpressionSelector:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# Query via CRUD; a session is used because an IN condition is needed
async with get_db_session() as session:
# Optimization: fetch the expressions for all related chat_ids in a single query
@@ -224,7 +222,7 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
affected_chat_ids.add(source_id)
for chat_id, expr_type, situation, style in updates_by_key:
async with get_db_session() as session:
query = await session.execute(
@@ -247,7 +245,7 @@ class ExpressionSelector:
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
# 清除所有受影响的chat_id的缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key

View File

@@ -728,7 +728,6 @@ class MemorySystem:
context = context or {}
# All memories are fully shared: use the global scope uniformly, without distinguishing users
resolved_user_id = GLOBAL_MEMORY_SCOPE
self.status = MemorySystemStatus.RETRIEVING
start_time = time.time()

View File

@@ -4,15 +4,14 @@ import time
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.api.crud import CRUDBase
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # newly added import
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger
from src.config.config import global_config # newly added import
@@ -708,7 +707,7 @@ class ChatManager:
# Batch query via CRUD
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_multi(limit=100000) # fetch all chat streams
for model_instance in all_streams:
user_info_data = {
"platform": model_instance.user_platform,

View File

@@ -22,14 +22,14 @@ logger = get_logger("message_storage")
class MessageStorageBatcher:
"""
Message storage batcher
Optimization: buffer messages for a short period, then write them to the database in batches to reduce connection pool pressure
"""
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
"""
Initialize the batcher
Args:
batch_size: batch size; once this many messages are queued they are written immediately
flush_interval: automatic flush interval (seconds)
@@ -51,7 +51,7 @@ class MessageStorageBatcher:
async def stop(self):
"""停止批处理器"""
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
@@ -67,7 +67,7 @@ class MessageStorageBatcher:
async def add_message(self, message_data: dict):
"""
Add a message to the batch queue
Args:
message_data: dict containing the message object and the chat_stream
{
@@ -97,23 +97,23 @@ class MessageStorageBatcher:
start_time = time.time()
success_count = 0
try:
# 🔧 Optimization: prepare dict data instead of ORM objects and use bulk INSERT
messages_dicts = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data['message'],
msg_data['chat_stream']
msg_data["message"],
msg_data["chat_stream"]
)
if message_dict:
messages_dicts.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
continue
# 批量写入数据库 - 使用高效的批量INSERT
if messages_dicts:
from sqlalchemy import insert
@@ -122,7 +122,7 @@ class MessageStorageBatcher:
await session.execute(stmt)
await session.commit()
success_count = len(messages_dicts)
elapsed = time.time() - start_time
logger.info(
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
@@ -134,18 +134,18 @@ class MessageStorageBatcher:
async def _prepare_message_dict(self, message, chat_stream):
"""准备消息字典数据用于批量INSERT
这个方法准备字典而不是ORM对象性能更高
"""
message_obj = await self._prepare_message_object(message, chat_stream)
if message_obj is None:
return None
# Convert the ORM object into a dict (column fields only)
message_dict = {}
for column in Messages.__table__.columns:
message_dict[column.name] = getattr(message_obj, column.name)
return message_dict
async def _prepare_message_object(self, message, chat_stream):
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
is_public_notice = getattr(message, 'is_public_notice', False)
notice_type = getattr(message, 'notice_type', None)
actions = getattr(message, 'actions', None)
should_reply = getattr(message, 'should_reply', None)
should_act = getattr(message, 'should_act', None)
additional_config = getattr(message, 'additional_config', None)
is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, "notice_type", None)
actions = getattr(message, "actions", None)
should_reply = getattr(message, "should_reply", None)
should_act = getattr(message, "should_act", None)
additional_config = getattr(message, "additional_config", None)
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
# Global batcher instances
_message_storage_batcher: Optional[MessageStorageBatcher] = None
_message_storage_batcher: MessageStorageBatcher | None = None
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
@@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher:
class MessageUpdateBatcher:
"""
Message update batcher
Optimization: batch multiple message-ID update operations to reduce the number of database connections
"""
@@ -478,7 +478,7 @@ class MessageStorage:
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
"""
Store a message to the database
Args:
message: the message object
chat_stream: the chat stream object
@@ -488,11 +488,11 @@ class MessageStorage:
if use_batch:
batcher = get_message_storage_batcher()
await batcher.add_message({
'message': message,
'chat_stream': chat_stream
"message": message,
"chat_stream": chat_stream
})
return
# Direct write mode (kept for special scenarios)
try:
# Regex patterns for filtering out sensitive information
@@ -675,9 +675,9 @@ class MessageStorage:
async def update_message(message_data: dict, use_batch: bool = True):
"""
Update a message's ID from a message dict
Optimization: adds a batching option that merges multiple update operations to reduce database connections
Args:
message_data: message data dict
use_batch: whether to use batching (default True)
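The batcher's payoff is SQLAlchemy's "executemany" fast path: passing a list of plain dicts alongside a single `insert()` statement. A self-contained sketch with a stand-in model (assumes SQLAlchemy 2.x and the `aiosqlite` driver; this is not the project's `Messages` model):

```python
import asyncio
from sqlalchemy import Integer, String, insert
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Msg(Base):  # stand-in model for illustration only
    __tablename__ = "msgs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    text: Mapped[str] = mapped_column(String)

async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    rows = [{"text": "hi"}, {"text": "hello"}]
    async with async_sessionmaker(engine)() as session:
        # One bulk INSERT (executemany) instead of one ORM flush per row.
        await session.execute(insert(Msg), rows)
        await session.commit()

asyncio.run(main())
```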

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any
from src.common.database.compatibility import db_get, db_query, db_save
from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask

View File

@@ -12,8 +12,8 @@ from PIL import Image
from rich.traceback import install
from sqlalchemy import and_, select
from src.common.database.core.models import ImageDescriptions, Images
from src.common.database.core import get_db_session
from src.common.database.core.models import ImageDescriptions, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

View File

@@ -25,8 +25,8 @@ from typing import Any
from PIL import Image
from src.common.database.core.models import Videos
from src.common.database.core import get_db_session
from src.common.database.core.models import Videos
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

View File

@@ -9,6 +9,37 @@
"""
# ===== Core layer =====
# ===== API layer =====
from src.common.database.api import (
AggregateQuery,
CRUDBase,
QueryBuilder,
# ChatStreams API
get_active_streams,
# Messages API
get_chat_history,
get_message_count,
# PersonInfo API
get_or_create_person,
# ActionRecords API
get_recent_actions,
# LLMUsage API
get_usage_statistics,
record_llm_usage,
# Business APIs
save_message,
store_action_info,
update_person_affinity,
)
# ===== Compatibility layer (backward compatible with the old API) =====
from src.common.database.compatibility import (
MODEL_MAPPING,
build_filters,
db_get,
db_query,
db_save,
)
from src.common.database.core import (
Base,
check_and_migrate_database,
@@ -27,29 +58,6 @@ from src.common.database.optimization import (
get_preloader,
)
# ===== API layer =====
from src.common.database.api import (
AggregateQuery,
CRUDBase,
QueryBuilder,
# ActionRecords API
get_recent_actions,
# ChatStreams API
get_active_streams,
# Messages API
get_chat_history,
get_message_count,
# PersonInfo API
get_or_create_person,
# LLMUsage API
get_usage_statistics,
record_llm_usage,
# Business APIs
save_message,
store_action_info,
update_person_affinity,
)
# ===== Utils layer =====
from src.common.database.utils import (
cached,
@@ -66,61 +74,52 @@ from src.common.database.utils import (
transactional,
)
# ===== Compatibility layer (backward compatible with the old API) =====
from src.common.database.compatibility import (
MODEL_MAPPING,
build_filters,
db_get,
db_query,
db_save,
)
__all__ = [
# Core layer
"Base",
"get_engine",
"get_session_factory",
"get_db_session",
"check_and_migrate_database",
# Optimization layer
"MultiLevelCache",
"DataPreloader",
"AdaptiveBatchScheduler",
"get_cache",
"get_preloader",
"get_batch_scheduler",
# API layer - base classes
"CRUDBase",
"QueryBuilder",
"AggregateQuery",
# API layer - business APIs
"store_action_info",
"get_recent_actions",
"get_chat_history",
"get_message_count",
"save_message",
"get_or_create_person",
"update_person_affinity",
"get_active_streams",
"record_llm_usage",
"get_usage_statistics",
# Utils layer
"retry",
"timeout",
"cached",
"measure_time",
"transactional",
"db_operation",
"get_monitor",
"record_operation",
"record_cache_hit",
"record_cache_miss",
"print_stats",
"reset_stats",
# Compatibility layer
"MODEL_MAPPING",
"AdaptiveBatchScheduler",
"AggregateQuery",
# Core layer
"Base",
# API layer - base classes
"CRUDBase",
"DataPreloader",
# Optimization layer
"MultiLevelCache",
"QueryBuilder",
"build_filters",
"cached",
"check_and_migrate_database",
"db_get",
"db_operation",
"db_query",
"db_save",
"db_get",
"get_active_streams",
"get_batch_scheduler",
"get_cache",
"get_chat_history",
"get_db_session",
"get_engine",
"get_message_count",
"get_monitor",
"get_or_create_person",
"get_preloader",
"get_recent_actions",
"get_session_factory",
"get_usage_statistics",
"measure_time",
"print_stats",
"record_cache_hit",
"record_cache_miss",
"record_llm_usage",
"record_operation",
"reset_stats",
# Utils layer
"retry",
"save_message",
# API layer - business APIs
"store_action_info",
"timeout",
"transactional",
"update_person_affinity",
]

View File

@@ -11,49 +11,49 @@ from src.common.database.api.query import AggregateQuery, QueryBuilder
# Business-specific APIs
from src.common.database.api.specialized import (
# ActionRecords
get_recent_actions,
store_action_info,
# ChatStreams
get_active_streams,
get_or_create_chat_stream,
# LLMUsage
get_usage_statistics,
record_llm_usage,
# Messages
get_chat_history,
get_message_count,
save_message,
get_or_create_chat_stream,
# PersonInfo
get_or_create_person,
update_person_affinity,
# ActionRecords
get_recent_actions,
# LLMUsage
get_usage_statistics,
# UserRelationships
get_user_relationship,
record_llm_usage,
save_message,
store_action_info,
update_person_affinity,
update_relationship_affinity,
)
__all__ = [
"AggregateQuery",
# Base classes
"CRUDBase",
"QueryBuilder",
"AggregateQuery",
# ActionRecords API
"store_action_info",
"get_recent_actions",
"get_active_streams",
# Messages API
"get_chat_history",
"get_message_count",
"save_message",
# PersonInfo API
"get_or_create_person",
"update_person_affinity",
# ChatStreams API
"get_or_create_chat_stream",
"get_active_streams",
# LLMUsage API
"record_llm_usage",
# PersonInfo API
"get_or_create_person",
"get_recent_actions",
"get_usage_statistics",
# UserRelationships API
"get_user_relationship",
# LLMUsage API
"record_llm_usage",
"save_message",
# ActionRecords API
"store_action_info",
"update_person_affinity",
"update_relationship_affinity",
]

View File

@@ -110,12 +110,12 @@ class CRUDBase:
if instance is not None:
# ✅ Convert to a dict inside the session, while every field is still safely accessible
instance_dict = _model_to_dict(instance)
# Write to the cache
if use_cache:
cache = await get_cache()
await cache.set(cache_key, instance_dict)
# Rebuild the object from the dict and return it (detached, all fields loaded)
return _dict_to_model(self.model, instance_dict)
@@ -159,12 +159,12 @@ class CRUDBase:
if instance is not None:
# ✅ Convert to a dict inside the session, while every field is still safely accessible
instance_dict = _model_to_dict(instance)
# Write to the cache
if use_cache:
cache = await get_cache()
await cache.set(cache_key, instance_dict)
# Rebuild the object from the dict and return it (detached, all fields loaded)
return _dict_to_model(self.model, instance_dict)
@@ -206,7 +206,7 @@ class CRUDBase:
# Apply the filter conditions
for key, value in filters.items():
if hasattr(self.model, key):
if isinstance(value, (list, tuple, set)):
if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
@@ -219,12 +219,12 @@ class CRUDBase:
# ✅ Convert to a list of dicts inside the session, while every field is still safely accessible
instances_dicts = [_model_to_dict(inst) for inst in instances]
# Write to the cache
if use_cache:
cache = await get_cache()
await cache.set(cache_key, instances_dicts)
# Rebuild the objects from the dicts and return them (detached, all fields loaded)
return [_dict_to_model(self.model, d) for d in instances_dicts]
@@ -266,13 +266,13 @@ class CRUDBase:
await session.refresh(instance)
# Note: the commit runs automatically when the get_db_session context manager exits,
# so no explicit commit is needed here
# Note: create does not clear the cache, because:
# 1. A new record does not affect existing single-row query caches (get/get_by)
# 2. get_multi caches expire naturally (TTL mechanism)
# 3. Clearing the whole cache is too expensive and hurts performance
# If strong consistency is required, set use_cache=False at query time
return instance
async def update(
@@ -397,7 +397,7 @@ class CRUDBase:
# Apply the filter conditions
for key, value in filters.items():
if hasattr(self.model, key):
if isinstance(value, (list, tuple, set)):
if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value))
else:
stmt = stmt.where(getattr(self.model, key) == value)
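The `(list, tuple, set)` → `list | tuple | set` change repeated in these hunks is a pyupgrade-family fix (ruff's rule for non-PEP 604 `isinstance`; the exact rule code is assumed). On Python 3.10+ both forms are valid:

```python
value = [1, 2, 3]
assert isinstance(value, list | tuple | set)   # new form, PEP 604 union (3.10+)
assert isinstance(value, (list, tuple, set))   # old tuple form still works
```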
@@ -466,14 +466,14 @@ class CRUDBase:
for instance in instances:
await session.refresh(instance)
# Caching strategy for bulk creation:
# bulk_create is typically used for bulk imports, where clearing the cache is reasonable,
# since many records may be created and cached list queries would be clearly stale
cache = await get_cache()
await cache.clear()
logger.info(f"Cleared the cache after bulk-creating {len(instances)} {self.model_name} records")
return instances
async def bulk_update(

View File

@@ -207,12 +207,12 @@ class QueryBuilder(Generic[T]):
# ✅ Convert to a list of dicts inside the session, while every field is still safely accessible
instances_dicts = [_model_to_dict(inst) for inst in instances]
# Write to the cache
if self._use_cache:
cache = await get_cache()
await cache.set(cache_key, instances_dicts)
# Rebuild the objects from the dicts and return them (detached, all fields loaded)
return [_dict_to_model(self.model, d) for d in instances_dicts]
@@ -241,12 +241,12 @@ class QueryBuilder(Generic[T]):
if instance is not None:
# ✅ Convert to a dict inside the session, while every field is still safely accessible
instance_dict = _model_to_dict(instance)
# Write to the cache
if self._use_cache:
cache = await get_cache()
await cache.set(cache_key, instance_dict)
# Rebuild the object from the dict and return it (detached, all fields loaded)
return _dict_to_model(self.model, instance_dict)

View File

@@ -4,7 +4,7 @@
"""
import time
from typing import Any, Optional
from typing import Any
import orjson
@@ -42,11 +42,11 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: dict | None = None,
action_name: str = "",
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""存储动作信息到数据库
Args:
chat_stream: 聊天流对象
action_build_into_prompt: 是否将此动作构建到提示中
@@ -55,7 +55,7 @@ async def store_action_info(
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
保存的记录数据或None
"""
@@ -71,7 +71,7 @@ async def store_action_info(
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# Get chat info from the chat_stream
if chat_stream:
record_data.update(
@@ -89,20 +89,20 @@ async def store_action_info(
"chat_info_platform": "",
}
)
# Save the record via get_or_create
saved_record, created = await _action_records_crud.get_or_create(
defaults=record_data,
action_id=action_id,
)
if saved_record:
logger.debug(f"Stored action info: {action_name} (ID: {action_id})")
return {col.name: getattr(saved_record, col.name) for col in saved_record.__table__.columns}
else:
logger.error(f"Failed to store action info: {action_name}")
return None
except Exception as e:
logger.error(f"Error while storing action info: {e}", exc_info=True)
return None
@@ -113,11 +113,11 @@ async def get_recent_actions(
limit: int = 10,
) -> list[ActionRecords]:
"""获取最近的动作记录
Args:
chat_id: 聊天ID
limit: 限制数量
Returns:
动作记录列表
"""
@@ -132,12 +132,12 @@ async def get_chat_history(
offset: int = 0,
) -> list[Messages]:
"""获取聊天历史
Args:
stream_id: 流ID
limit: 限制数量
offset: 偏移量
Returns:
消息列表
"""
@@ -153,10 +153,10 @@ async def get_chat_history(
async def get_message_count(stream_id: str) -> int:
"""获取消息数量
Args:
stream_id: 流ID
Returns:
消息数量
"""
@@ -167,13 +167,13 @@ async def get_message_count(stream_id: str) -> int:
async def save_message(
message_data: dict[str, Any],
use_batch: bool = True,
) -> Optional[Messages]:
) -> Messages | None:
"""保存消息
Args:
message_data: 消息数据
use_batch: 是否使用批处理
Returns:
保存的消息实例
"""
@@ -185,15 +185,15 @@ async def save_message(
async def get_or_create_person(
platform: str,
person_id: str,
defaults: Optional[dict[str, Any]] = None,
) -> tuple[Optional[PersonInfo], bool]:
defaults: dict[str, Any] | None = None,
) -> tuple[PersonInfo | None, bool]:
"""获取或创建人员信息
Args:
platform: 平台
person_id: 人员ID
defaults: 默认值
Returns:
(人员信息实例, 是否新创建)
"""
@@ -210,12 +210,12 @@ async def update_person_affinity(
affinity_delta: float,
) -> bool:
"""更新人员好感度
Args:
platform: 平台
person_id: 人员ID
affinity_delta: 好感度变化值
Returns:
是否成功
"""
@@ -225,26 +225,26 @@ async def update_person_affinity(
platform=platform,
person_id=person_id,
)
if not person:
logger.warning(f"人员不存在: {platform}/{person_id}")
return False
# 更新好感度
new_affinity = (person.affinity or 0.0) + affinity_delta
await _person_info_crud.update(
person.id,
{"affinity": new_affinity},
)
# 使缓存失效
cache = await get_cache()
cache_key = generate_cache_key("person_info", platform, person_id)
await cache.delete(cache_key)
logger.debug(f"更新好感度: {platform}/{person_id} {affinity_delta:+.2f} -> {new_affinity:.2f}")
return True
except Exception as e:
logger.error(f"更新好感度失败: {e}", exc_info=True)
return False
@@ -255,15 +255,15 @@ async def update_person_affinity(
async def get_or_create_chat_stream(
stream_id: str,
platform: str,
defaults: Optional[dict[str, Any]] = None,
) -> tuple[Optional[ChatStreams], bool]:
defaults: dict[str, Any] | None = None,
) -> tuple[ChatStreams | None, bool]:
"""获取或创建聊天流
Args:
stream_id: 流ID
platform: 平台
defaults: 默认值
Returns:
(聊天流实例, 是否新创建)
"""
@@ -275,23 +275,23 @@ async def get_or_create_chat_stream(
async def get_active_streams(
platform: Optional[str] = None,
platform: str | None = None,
limit: int = 100,
) -> list[ChatStreams]:
"""获取活跃的聊天流
Args:
platform: 平台(可选)
limit: 限制数量
Returns:
聊天流列表
"""
query = QueryBuilder(ChatStreams)
if platform:
query = query.filter(platform=platform)
return await query.order_by("-last_message_time").limit(limit).all()
@@ -300,20 +300,20 @@ async def record_llm_usage(
model_name: str,
input_tokens: int,
output_tokens: int,
stream_id: Optional[str] = None,
platform: Optional[str] = None,
stream_id: str | None = None,
platform: str | None = None,
user_id: str = "system",
request_type: str = "chat",
model_assign_name: Optional[str] = None,
model_api_provider: Optional[str] = None,
model_assign_name: str | None = None,
model_api_provider: str | None = None,
endpoint: str = "/v1/chat/completions",
cost: float = 0.0,
status: str = "success",
time_cost: Optional[float] = None,
time_cost: float | None = None,
use_batch: bool = True,
) -> Optional[LLMUsage]:
) -> LLMUsage | None:
"""记录LLM使用情况
Args:
model_name: 模型名称
input_tokens: 输入token数
@@ -329,7 +329,7 @@ async def record_llm_usage(
status: 状态
time_cost: 时间成本
use_batch: 是否使用批处理
Returns:
LLM使用记录实例
"""
@@ -346,37 +346,36 @@ async def record_llm_usage(
"model_assign_name": model_assign_name or model_name,
"model_api_provider": model_api_provider or "unknown",
}
if time_cost is not None:
usage_data["time_cost"] = time_cost
return await _llm_usage_crud.create(usage_data, use_batch=use_batch)
async def get_usage_statistics(
start_time: Optional[float] = None,
end_time: Optional[float] = None,
model_name: Optional[str] = None,
start_time: float | None = None,
end_time: float | None = None,
model_name: str | None = None,
) -> dict[str, Any]:
"""获取使用统计
Args:
start_time: 开始时间戳
end_time: 结束时间戳
model_name: 模型名称
Returns:
统计数据字典
"""
from src.common.database.api.query import AggregateQuery
query = AggregateQuery(LLMUsage)
# 添加时间过滤
if start_time:
async with get_db_session() as session:
from sqlalchemy import and_
async with get_db_session():
conditions = []
if start_time:
conditions.append(LLMUsage.timestamp >= start_time)
@@ -384,15 +383,15 @@ async def get_usage_statistics(
conditions.append(LLMUsage.timestamp <= end_time)
if model_name:
conditions.append(LLMUsage.model_name == model_name)
if conditions:
query._conditions = conditions
# Aggregate statistics
total_input = await query.sum("input_tokens")
total_output = await query.sum("output_tokens")
total_count = await query.filter().count() if hasattr(query, "count") else 0
return {
"total_input_tokens": int(total_input),
"total_output_tokens": int(total_output),
@@ -407,14 +406,14 @@ async def get_user_relationship(
platform: str,
user_id: str,
target_id: str,
) -> Optional[UserRelationships]:
) -> UserRelationships | None:
"""获取用户关系
Args:
platform: 平台
user_id: 用户ID
target_id: 目标用户ID
Returns:
用户关系实例
"""
@@ -432,13 +431,13 @@ async def update_relationship_affinity(
affinity_delta: float,
) -> bool:
"""更新关系好感度
Args:
platform: 平台
user_id: 用户ID
target_id: 目标用户ID
affinity_delta: 好感度变化值
Returns:
是否成功
"""
@@ -450,15 +449,15 @@ async def update_relationship_affinity(
user_id=user_id,
target_id=target_id,
)
if not relationship:
logger.error(f"无法创建关系: {platform}/{user_id}->{target_id}")
return False
# 更新好感度和互动次数
new_affinity = (relationship.affinity or 0.0) + affinity_delta
new_count = (relationship.interaction_count or 0) + 1
await _user_relationships_crud.update(
relationship.id,
{
@@ -467,19 +466,19 @@ async def update_relationship_affinity(
"last_interaction_time": time.time(),
},
)
# Invalidate the cache
cache = await get_cache()
cache_key = generate_cache_key("user_relationship", platform, user_id, target_id)
await cache.delete(cache_key)
logger.debug(
f"Updated relationship: {platform}/{user_id}->{target_id} "
f"affinity {affinity_delta:+.2f} -> {new_affinity:.2f} "
f"interactions: {new_count}"
)
return True
except Exception as e:
logger.error(f"Failed to update relationship affinity: {e}", exc_info=True)
return False
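The bulk of this file's diff swaps `typing.Optional` for PEP 604 unions, which is why `Optional` drops out of the `typing` import at the top. The two spellings are equivalent on Python 3.10+; a minimal check:

```python
from typing import Any, Optional

def old_style(defaults: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]:
    return defaults

def new_style(defaults: dict[str, Any] | None = None) -> dict[str, Any] | None:
    return defaults

assert old_style() is None and new_style() is None
```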

View File

@@ -14,14 +14,14 @@ from .adapter import (
)
__all__ = [
# Functions re-exported from core
"get_db_session",
"get_engine",
# Compatibility-layer adapters
"MODEL_MAPPING",
"build_filters",
"db_get",
"db_query",
"db_save",
"db_get",
# Functions re-exported from core
"get_db_session",
"get_engine",
"store_action_info",
]

View File

@@ -4,15 +4,13 @@
Keeps the original function signatures and behavior unchanged
"""
import time
from typing import Any, Optional
import orjson
from sqlalchemy import and_, asc, desc, select
from typing import Any
from src.common.database.api import (
CRUDBase,
QueryBuilder,
)
from src.common.database.api import (
store_action_info as new_store_action_info,
)
from src.common.database.core.models import (
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
Messages,
MonthlyPlan,
OnlineTime,
PersonInfo,
PermissionNodes,
PersonInfo,
Schedule,
ThinkingLog,
UserPermissions,
UserRelationships,
Videos,
)
from src.common.database.core.session import get_db_session
from src.common.logger import get_logger
logger = get_logger("database.compatibility")
@@ -82,11 +79,11 @@ _crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items(
async def build_filters(model_class, filters: dict[str, Any]):
"""构建查询过滤条件兼容MongoDB风格操作符
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件字典
Returns:
条件列表
"""
@@ -127,16 +124,16 @@ async def build_filters(model_class, filters: dict[str, Any]):
def _model_to_dict(instance) -> dict[str, Any]:
"""将模型实例转换为字典
Args:
instance: 模型实例
Returns:
字典表示
"""
if instance is None:
return None
result = {}
for column in instance.__table__.columns:
result[column.name] = getattr(instance, column.name)
@@ -145,15 +142,15 @@ def _model_to_dict(instance) -> dict[str, Any]:
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: Optional[bool] = False,
data: dict[str, Any] | None = None,
query_type: str | None = "get",
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: list[str] | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""执行异步数据库查询操作兼容旧API
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
@@ -162,7 +159,7 @@ async def db_query(
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
@@ -179,7 +176,7 @@ async def db_query(
if query_type == "get":
# Use QueryBuilder
query_builder = QueryBuilder(model_class)
# Apply the filter conditions
if filters:
# Convert MongoDB-style filters into QueryBuilder format
@@ -202,15 +199,15 @@ async def db_query(
query_builder = query_builder.filter(**{f"{field_name}__nin": op_value})
else:
query_builder = query_builder.filter(**{field_name: value})
# Apply ordering
if order_by:
query_builder = query_builder.order_by(*order_by)
# Apply the limit
if limit:
query_builder = query_builder.limit(limit)
# Execute the query
if single_result:
result = await query_builder.first()
@@ -223,7 +220,7 @@ async def db_query(
if not data:
logger.error("创建操作需要提供data参数")
return None
instance = await crud.create(data)
return _model_to_dict(instance)
@@ -231,17 +228,17 @@ async def db_query(
if not filters or not data:
logger.error("更新操作需要提供filters和data参数")
return None
# 先查找记录
query_builder = QueryBuilder(model_class)
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
instance = await query_builder.first()
if not instance:
logger.warning(f"未找到匹配的记录: {filters}")
return None
# 更新记录
updated = await crud.update(instance.id, data)
return _model_to_dict(updated)
@@ -250,29 +247,29 @@ async def db_query(
if not filters:
logger.error("删除操作需要提供filters参数")
return None
# 先查找记录
query_builder = QueryBuilder(model_class)
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
instance = await query_builder.first()
if not instance:
logger.warning(f"未找到匹配的记录: {filters}")
return None
# 删除记录
success = await crud.delete(instance.id)
return {"deleted": success}
elif query_type == "count":
query_builder = QueryBuilder(model_class)
# 应用过滤条件
if filters:
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
count = await query_builder.count()
return {"count": count}
@@ -286,15 +283,15 @@ async def db_save(
data: dict[str, Any],
key_field: str,
key_value: Any,
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""保存或更新记录兼容旧API
Args:
model_class: SQLAlchemy模型类
data: 数据字典
key_field: 主键字段名
key_value: 主键值
Returns:
保存的记录数据或None
"""
@@ -303,15 +300,15 @@ async def db_save(
crud = _crud_instances.get(model_name)
if not crud:
crud = CRUDBase(model_class)
# Use get_or_create (returns tuple[T, bool])
instance, created = await crud.get_or_create(
defaults=data,
**{key_field: key_value},
)
return _model_to_dict(instance)
except Exception as e:
logger.error(f"Error saving database record: {e}", exc_info=True)
return None
@@ -319,20 +316,20 @@ async def db_save(
async def db_get(
model_class,
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: str | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""从数据库获取记录兼容旧API
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
@@ -353,11 +350,11 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: dict | None = None,
action_name: str = "",
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""存储动作信息到数据库兼容旧API
直接使用新的specialized API
"""
return await new_store_action_info(

View File

@@ -5,7 +5,7 @@
import os
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
from urllib.parse import quote_plus
from src.common.logger import get_logger
@@ -16,50 +16,50 @@ logger = get_logger("database_config")
@dataclass
class DatabaseConfig:
"""数据库配置"""
# 基础配置
db_type: str # "sqlite" 或 "mysql"
url: str # 数据库连接URL
# 引擎配置
engine_kwargs: dict[str, Any]
# SQLite特定配置
sqlite_path: Optional[str] = None
sqlite_path: str | None = None
# MySQL特定配置
mysql_host: Optional[str] = None
mysql_port: Optional[int] = None
mysql_user: Optional[str] = None
mysql_password: Optional[str] = None
mysql_database: Optional[str] = None
mysql_host: str | None = None
mysql_port: int | None = None
mysql_user: str | None = None
mysql_password: str | None = None
mysql_database: str | None = None
mysql_charset: str = "utf8mb4"
mysql_unix_socket: Optional[str] = None
mysql_unix_socket: str | None = None
_database_config: Optional[DatabaseConfig] = None
_database_config: DatabaseConfig | None = None
def get_database_config() -> DatabaseConfig:
"""获取数据库配置
从全局配置中读取数据库设置并构建配置对象
"""
global _database_config
if _database_config is not None:
return _database_config
from src.config.config import global_config
config = global_config.database
# Build the database URL
if config.database_type == "mysql":
# MySQL configuration
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket connection
encoded_socket = quote_plus(config.mysql_unix_socket)
@@ -75,7 +75,7 @@ def get_database_config() -> DatabaseConfig:
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
@@ -90,7 +90,7 @@ def get_database_config() -> DatabaseConfig:
"connect_timeout": config.connection_timeout,
},
}
_database_config = DatabaseConfig(
db_type="mysql",
url=url,
@@ -103,12 +103,12 @@ def get_database_config() -> DatabaseConfig:
mysql_charset=config.mysql_charset,
mysql_unix_socket=config.mysql_unix_socket,
)
logger.info(
f"MySQL配置已加载: "
f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
else:
# SQLite配置
if not os.path.isabs(config.sqlite_path):
@@ -116,12 +116,12 @@ def get_database_config() -> DatabaseConfig:
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = f"sqlite+aiosqlite:///{db_path}"
engine_kwargs = {
"echo": False,
"future": True,
@@ -130,16 +130,16 @@ def get_database_config() -> DatabaseConfig:
"timeout": 60,
},
}
_database_config = DatabaseConfig(
db_type="sqlite",
url=url,
engine_kwargs=engine_kwargs,
sqlite_path=db_path,
)
logger.info(f"SQLite配置已加载: {db_path}")
return _database_config
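`quote_plus` matters here because characters such as `@`, `:` or `/` in the credentials would otherwise corrupt the URL. A small illustration (the `mysql+asyncmy://` prefix is illustrative; the actual driver prefix is elided from the hunk above):

```python
from urllib.parse import quote_plus

user, password = "bot", "p@ss:word/1"
url = (
    f"mysql+asyncmy://{quote_plus(user)}:{quote_plus(password)}"
    f"@localhost:3306/maibot?charset=utf8mb4"
)
print(url)  # mysql+asyncmy://bot:p%40ss%3Aword%2F1@localhost:3306/maibot?charset=utf8mb4
```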

View File

@@ -19,7 +19,6 @@ from .models import (
ChatStreams,
Emoji,
Expression,
get_string_field,
GraphEdges,
GraphNodes,
ImageDescriptions,
@@ -37,30 +36,17 @@ from .models import (
UserPermissions,
UserRelationships,
Videos,
get_string_field,
)
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
__all__ = [
# Engine
"get_engine",
"close_engine",
"get_engine_info",
# Session
"get_db_session",
"get_db_session_direct",
"get_session_factory",
"reset_session_factory",
# Migration
"check_and_migrate_database",
"create_all_tables",
"drop_all_tables",
# Models - Base
"Base",
"get_string_field",
# Models - Tables (alphabetical order)
"ActionRecords",
"AntiInjectionStats",
"BanUser",
# Models - Base
"Base",
"BotPersonalityInterests",
"CacheEntries",
"ChatStreams",
@@ -83,4 +69,18 @@ __all__ = [
"UserPermissions",
"UserRelationships",
"Videos",
# Migration
"check_and_migrate_database",
"close_engine",
"create_all_tables",
"drop_all_tables",
# Session
"get_db_session",
"get_db_session_direct",
# Engine
"get_engine",
"get_engine_info",
"get_session_factory",
"get_string_field",
"reset_session_factory",
]

View File

@@ -5,7 +5,6 @@
import asyncio
import os
from typing import Optional
from urllib.parse import quote_plus
from sqlalchemy import text
@@ -18,49 +17,49 @@ from ..utils.exceptions import DatabaseInitializationError
logger = get_logger("database.engine")
# Global engine instance
_engine: Optional[AsyncEngine] = None
_engine_lock: Optional[asyncio.Lock] = None
_engine: AsyncEngine | None = None
_engine_lock: asyncio.Lock | None = None
async def get_engine() -> AsyncEngine:
"""获取全局数据库引擎(单例模式)
Returns:
AsyncEngine: SQLAlchemy异步引擎
Raises:
DatabaseInitializationError: 引擎初始化失败
"""
global _engine, _engine_lock
# Fast path: the engine is already initialized
if _engine is not None:
return _engine
# Create the lock lazily (avoid creating it at import time)
if _engine_lock is None:
_engine_lock = asyncio.Lock()
# Protect the initialization with the lock
async with _engine_lock:
# Double-checked locking pattern
if _engine is not None:
return _engine
try:
from src.config.config import global_config
config = global_config.database
db_type = config.database_type
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
# 构建数据库URL和引擎参数
if db_type == "mysql":
# MySQL配置
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
@@ -76,7 +75,7 @@ async def get_engine() -> AsyncEngine:
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
@@ -91,11 +90,11 @@ async def get_engine() -> AsyncEngine:
"connect_timeout": config.connection_timeout,
},
}
logger.info(
f"MySQL configuration: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
else:
# SQLite configuration
if not os.path.isabs(config.sqlite_path):
@@ -103,12 +102,12 @@ async def get_engine() -> AsyncEngine:
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# Make sure the database directory exists
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = f"sqlite+aiosqlite:///{db_path}"
engine_kwargs = {
"echo": False,
"future": True,
@@ -117,19 +116,19 @@ async def get_engine() -> AsyncEngine:
"timeout": 60,
},
}
logger.info(f"SQLite配置: {db_path}")
# 创建异步引擎
_engine = create_async_engine(url, **engine_kwargs)
# SQLite特定优化
if db_type == "sqlite":
await _enable_sqlite_optimizations(_engine)
logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
return _engine
except Exception as e:
logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True)
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
@@ -137,11 +136,11 @@ async def get_engine() -> AsyncEngine:
async def close_engine():
"""关闭数据库引擎
释放所有连接池资源
"""
global _engine
if _engine is not None:
logger.info("正在关闭数据库引擎...")
await _engine.dispose()
@@ -151,13 +150,13 @@ async def close_engine():
async def _enable_sqlite_optimizations(engine: AsyncEngine):
"""启用SQLite性能优化
优化项:
- WAL模式提高并发性能
- NORMAL同步平衡性能和安全性
- 启用外键约束
- 设置busy_timeout避免锁定错误
Args:
engine: SQLAlchemy异步引擎
"""
@@ -175,22 +174,22 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
await conn.execute(text("PRAGMA cache_size = -10000"))
# Use memory for temporary storage
await conn.execute(text("PRAGMA temp_store = MEMORY"))
logger.info("✅ SQLite performance optimizations enabled (WAL mode + concurrency tuning)")
except Exception as e:
logger.warning(f"⚠️ SQLite performance optimization failed: {e}; falling back to the default configuration")
async def get_engine_info() -> dict:
"""Get engine info (for monitoring and debugging)
Returns:
dict: engine info dict
"""
try:
engine = await get_engine()
info = {
"name": engine.name,
"driver": engine.driver,
@@ -199,9 +198,9 @@ async def get_engine_info() -> dict:
"pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(),
"pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(),
}
return info
except Exception as e:
logger.error(f"获取引擎信息失败: {e}")
return {}
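These hunks keep `get_engine`'s concurrency scheme intact: a lock-free fast path, lazy lock creation, and a second check inside the lock. Distilled to its skeleton:

```python
import asyncio

_engine: object | None = None
_engine_lock: asyncio.Lock | None = None

async def get_singleton() -> object:
    global _engine, _engine_lock
    if _engine is not None:        # fast path: no lock once initialized
        return _engine
    if _engine_lock is None:       # lazy lock: nothing is created at import time
        _engine_lock = asyncio.Lock()
    async with _engine_lock:
        if _engine is None:        # double check: another task may have won the race
            _engine = object()     # stand-in for create_async_engine(...)
        return _engine
```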

View File

@@ -20,15 +20,15 @@ logger = get_logger("db_migration")
async def check_and_migrate_database(existing_engine=None):
"""异步检查数据库结构并自动迁移
自动执行以下操作:
- 创建不存在的表
- 为现有表添加缺失的列
- 为现有表创建缺失的索引
Args:
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎
Note:
此函数是幂等的,可以安全地多次调用
"""
@@ -65,7 +65,7 @@ async def check_and_migrate_database(existing_engine=None):
for table in tables_to_create:
logger.info(f"'{table.name}' 创建成功。")
db_table_names.add(table.name) # 将新创建的表添加到集合中
# 提交表创建事务
await connection.commit()
except Exception as e:
@@ -191,40 +191,40 @@ async def check_and_migrate_database(existing_engine=None):
async def create_all_tables(existing_engine=None):
"""创建所有表(不进行迁移检查)
直接创建所有在 Base.metadata 中定义的表。
如果表已存在,将被跳过。
Args:
existing_engine: 可选的已存在的数据库引擎
Note:
生产环境建议使用 check_and_migrate_database()
"""
logger.info("正在创建所有数据库表...")
engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all)
logger.info("数据库表创建完成。")
async def drop_all_tables(existing_engine=None):
"""删除所有表(危险操作!)
删除所有在 Base.metadata 中定义的表。
Args:
existing_engine: 可选的已存在的数据库引擎
Warning:
此操作将删除所有数据,不可恢复!仅用于测试环境!
"""
logger.warning("⚠️ 正在删除所有数据库表...")
engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)
logger.warning("所有数据库表已删除。")

View File

@@ -6,7 +6,6 @@
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -18,38 +17,38 @@ from .engine import get_engine
logger = get_logger("database.session")
# Global session factory
_session_factory: Optional[async_sessionmaker] = None
_factory_lock: Optional[asyncio.Lock] = None
_session_factory: async_sessionmaker | None = None
_factory_lock: asyncio.Lock | None = None
async def get_session_factory() -> async_sessionmaker:
"""获取会话工厂(单例模式)
Returns:
async_sessionmaker: SQLAlchemy异步会话工厂
"""
global _session_factory, _factory_lock
# Fast path
if _session_factory is not None:
return _session_factory
# Create the lock lazily
if _factory_lock is None:
_factory_lock = asyncio.Lock()
async with _factory_lock:
# Double check
if _session_factory is not None:
return _session_factory
engine = await get_engine()
_session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False, # avoid re-querying when attributes are accessed after commit
)
logger.debug("Session factory created")
return _session_factory
@@ -57,28 +56,28 @@ async def get_session_factory() -> async_sessionmaker:
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话上下文管理器
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
使用示例:
async with get_db_session() as session:
result = await session.execute(select(User))
users = result.scalars().all()
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
# Lazy import to avoid a circular dependency
from ..optimization.connection_pool import get_connection_pool_manager
session_factory = await get_session_factory()
pool_manager = get_connection_pool_manager()
# Use the connection pool manager (transparent connection reuse)
async with pool_manager.get_session(session_factory) as session:
# Set SQLite-specific PRAGMAs
from src.config.config import global_config
if global_config.database.database_type == "sqlite":
try:
await session.execute(text("PRAGMA busy_timeout = 60000"))
@@ -86,22 +85,22 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
except Exception:
# The PRAGMA may already be set on a reused connection; ignore the error
pass
yield session
@asynccontextmanager
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话(直接模式,不使用连接池)
用于特殊场景,如需要完全独立的连接时。
一般情况下应使用 get_db_session()。
Yields:
AsyncSession: SQLAlchemy异步会话对象
"""
session_factory = await get_session_factory()
async with session_factory() as session:
try:
yield session

View File

@@ -11,17 +11,17 @@ from .batch_scheduler import (
AdaptiveBatchScheduler,
BatchOperation,
BatchStats,
Priority,
close_batch_scheduler,
get_batch_scheduler,
Priority,
)
from .cache_manager import (
CacheEntry,
CacheStats,
close_cache,
get_cache,
LRUCache,
MultiLevelCache,
close_cache,
get_cache,
)
from .connection_pool import (
ConnectionPoolManager,
@@ -31,36 +31,36 @@ from .connection_pool import (
)
from .preloader import (
AccessPattern,
close_preloader,
CommonDataPreloader,
DataPreloader,
close_preloader,
get_preloader,
)
__all__ = [
# Connection Pool
"ConnectionPoolManager",
"get_connection_pool_manager",
"start_connection_pool",
"stop_connection_pool",
# Cache
"MultiLevelCache",
"LRUCache",
"CacheEntry",
"CacheStats",
"get_cache",
"close_cache",
# Preloader
"DataPreloader",
"CommonDataPreloader",
"AccessPattern",
"get_preloader",
"close_preloader",
# Batch Scheduler
"AdaptiveBatchScheduler",
"BatchOperation",
"BatchStats",
"CacheEntry",
"CacheStats",
"CommonDataPreloader",
# Connection Pool
"ConnectionPoolManager",
# Preloader
"DataPreloader",
"LRUCache",
# Cache
"MultiLevelCache",
"Priority",
"get_batch_scheduler",
"close_batch_scheduler",
"close_cache",
"close_preloader",
"get_batch_scheduler",
"get_cache",
"get_connection_pool_manager",
"get_preloader",
"start_connection_pool",
"stop_connection_pool",
]

View File

@@ -10,12 +10,12 @@
import asyncio
import time
from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Callable, Optional, TypeVar
from typing import Any, TypeVar
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.core.session import get_db_session
from src.common.logger import get_logger
@@ -36,22 +36,22 @@ class Priority(IntEnum):
@dataclass
class BatchOperation:
"""批量操作"""
operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: type
conditions: dict[str, Any] = field(default_factory=dict)
data: Optional[dict[str, Any]] = None
callback: Optional[Callable] = None
future: Optional[asyncio.Future] = None
data: dict[str, Any] | None = None
callback: Callable | None = None
future: asyncio.Future | None = None
timestamp: float = field(default_factory=time.time)
priority: Priority = Priority.NORMAL
timeout: Optional[float] = None # timeout (seconds)
timeout: float | None = None # timeout (seconds)
@dataclass
class BatchStats:
"""Batch-processing statistics"""
total_operations: int = 0
batched_operations: int = 0
cache_hits: int = 0
@@ -60,7 +60,7 @@ class BatchStats:
avg_wait_time: float = 0.0
timeout_count: int = 0
error_count: int = 0
# Adaptive statistics
last_batch_duration: float = 0.0
last_batch_size: int = 0
@@ -69,7 +69,7 @@ class BatchStats:
class AdaptiveBatchScheduler:
"""自适应批量调度器
特性:
- 动态批次大小:根据负载自动调整
- 优先级队列:高优先级操作优先执行
@@ -87,7 +87,7 @@ class AdaptiveBatchScheduler:
cache_ttl: float = 5.0,
):
"""初始化调度器
Args:
min_batch_size: 最小批次大小
max_batch_size: 最大批次大小
@@ -104,23 +104,23 @@ class AdaptiveBatchScheduler:
self.current_wait_time = base_wait_time
self.max_queue_size = max_queue_size
self.cache_ttl = cache_ttl
# Operation queues, grouped by priority
self.operation_queues: dict[Priority, deque[BatchOperation]] = {
priority: deque() for priority in Priority
}
# Scheduling control
self._scheduler_task: Optional[asyncio.Task] = None
self._scheduler_task: asyncio.Task | None = None
self._is_running = False
self._lock = asyncio.Lock()
# Statistics
self.stats = BatchStats()
# Simple result cache
self._result_cache: dict[str, tuple[Any, float]] = {}
logger.info(
f"Adaptive batch scheduler initialized: "
f"batch size {min_batch_size}-{max_batch_size}, "
@@ -132,7 +132,7 @@ class AdaptiveBatchScheduler:
if self._is_running:
logger.warning("调度器已在运行")
return
self._is_running = True
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
logger.info("批量调度器已启动")
@@ -141,16 +141,16 @@ class AdaptiveBatchScheduler:
"""停止调度器"""
if not self._is_running:
return
self._is_running = False
if self._scheduler_task:
self._scheduler_task.cancel()
try:
await self._scheduler_task
except asyncio.CancelledError:
pass
# 处理剩余操作
await self._flush_all_queues()
logger.info("批量调度器已停止")
@@ -160,10 +160,10 @@ class AdaptiveBatchScheduler:
operation: BatchOperation,
) -> asyncio.Future:
"""添加操作到队列
Args:
operation: 批量操作
Returns:
Future对象可用于获取结果
"""
@@ -175,11 +175,11 @@ class AdaptiveBatchScheduler:
future = asyncio.get_event_loop().create_future()
future.set_result(cached_result)
return future
# 创建future
future = asyncio.get_event_loop().create_future()
operation.future = future
async with self._lock:
# 检查队列是否已满
total_queued = sum(len(q) for q in self.operation_queues.values())
@@ -191,7 +191,7 @@ class AdaptiveBatchScheduler:
# 添加到优先级队列
self.operation_queues[operation.priority].append(operation)
self.stats.total_operations += 1
return future
async def _scheduler_loop(self) -> None:
@@ -217,10 +217,10 @@ class AdaptiveBatchScheduler:
for _ in range(count):
if queue:
operations.append(queue.popleft())
if not operations:
return
# 执行批量操作
await self._execute_operations(operations)
@@ -231,10 +231,10 @@ class AdaptiveBatchScheduler:
"""执行批量操作"""
if not operations:
return
start_time = time.time()
batch_size = len(operations)
try:
# 检查超时
valid_operations = []
@@ -246,41 +246,41 @@ class AdaptiveBatchScheduler:
self.stats.timeout_count += 1
else:
valid_operations.append(op)
if not valid_operations:
return
# 按操作类型分组
op_groups = defaultdict(list)
for op in valid_operations:
key = f"{op.operation_type}_{op.model_class.__name__}"
op_groups[key].append(op)
# 执行各组操作
for group_key, ops in op_groups.items():
for ops in op_groups.values():
await self._execute_group(ops)
# 更新统计
duration = time.time() - start_time
self.stats.batched_operations += batch_size
self.stats.total_execution_time += duration
self.stats.last_batch_duration = duration
self.stats.last_batch_size = batch_size
if self.stats.batched_operations > 0:
self.stats.avg_batch_size = (
self.stats.batched_operations /
(self.stats.total_execution_time / duration)
)
logger.debug(
f"批量执行完成: {batch_size}个操作, 耗时{duration*1000:.2f}ms"
)
except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info=True)
self.stats.error_count += 1
# 设置所有future的异常
for op in operations:
if op.future and not op.future.done():
@@ -290,9 +290,9 @@ class AdaptiveBatchScheduler:
"""执行同类操作组"""
if not operations:
return
op_type = operations[0].operation_type
try:
if op_type == "select":
await self._execute_select_batch(operations)
@@ -304,7 +304,7 @@ class AdaptiveBatchScheduler:
await self._execute_delete_batch(operations)
else:
raise ValueError(f"未知操作类型: {op_type}")
except Exception as e:
logger.error(f"执行{op_type}操作组失败: {e}", exc_info=True)
for op in operations:
@@ -323,30 +323,30 @@ class AdaptiveBatchScheduler:
stmt = select(op.model_class)
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
if isinstance(value, (list, tuple, set)):
if isinstance(value, list | tuple | set):
stmt = stmt.where(attr.in_(value))
else:
stmt = stmt.where(attr == value)
# 执行查询
result = await session.execute(stmt)
data = result.scalars().all()
# 设置结果
if op.future and not op.future.done():
op.future.set_result(data)
# 缓存结果
cache_key = self._generate_cache_key(op)
self._set_cache(cache_key, data)
# 执行回调
if op.callback:
try:
op.callback(data)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"查询失败: {e}", exc_info=True)
if op.future and not op.future.done():
@@ -363,23 +363,23 @@ class AdaptiveBatchScheduler:
all_data = [op.data for op in operations if op.data]
if not all_data:
return
# 批量插入
stmt = insert(operations[0].model_class).values(all_data)
result = await session.execute(stmt)
await session.execute(stmt)
await session.commit()
# 设置结果
for op in operations:
if op.future and not op.future.done():
op.future.set_result(True)
if op.callback:
try:
op.callback(True)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量插入失败: {e}", exc_info=True)
await session.rollback()
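The insert path above collapses many queued rows into one multi-row INSERT followed by a single commit. A standalone sketch of that shape in SQLAlchemy 2.x async (in-memory SQLite via the aiosqlite driver is an assumption for the demo; the model is illustrative, not from this repo):

import asyncio

from sqlalchemy import Column, Integer, String, insert
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String(64))


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    Session = async_sessionmaker(engine)
    rows = [{"name": f"item-{i}"} for i in range(100)]
    async with Session() as session:
        # One INSERT statement carrying all rows, committed once --
        # the same shape as _execute_insert_batch above.
        await session.execute(insert(Item).values(rows))
        await session.commit()


asyncio.run(main())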
@@ -402,28 +402,28 @@ class AdaptiveBatchScheduler:
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
stmt = stmt.where(attr == value)
if op.data:
stmt = stmt.values(**op.data)
# 执行更新但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
# 所有操作成功后一次性commit
await session.commit()
# 设置所有操作的结果
for op, rowcount in results:
if op.future and not op.future.done():
op.future.set_result(rowcount)
if op.callback:
try:
op.callback(rowcount)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量更新失败: {e}", exc_info=True)
await session.rollback()
@@ -447,25 +447,25 @@ class AdaptiveBatchScheduler:
for key, value in op.conditions.items():
attr = getattr(op.model_class, key)
stmt = stmt.where(attr == value)
# 执行删除但不commit
result = await session.execute(stmt)
results.append((op, result.rowcount))
# 所有操作成功后一次性commit
await session.commit()
# 设置所有操作的结果
for op, rowcount in results:
if op.future and not op.future.done():
op.future.set_result(rowcount)
if op.callback:
try:
op.callback(rowcount)
except Exception as e:
logger.warning(f"回调执行失败: {e}")
except Exception as e:
logger.error(f"批量删除失败: {e}", exc_info=True)
await session.rollback()
@@ -479,7 +479,7 @@ class AdaptiveBatchScheduler:
# 计算拥塞评分
total_queued = sum(len(q) for q in self.operation_queues.values())
self.stats.congestion_score = min(1.0, total_queued / self.max_queue_size)
# 根据拥塞情况调整批次大小
if self.stats.congestion_score > 0.7:
# 高拥塞,增加批次大小
@@ -493,7 +493,7 @@ class AdaptiveBatchScheduler:
self.min_batch_size,
int(self.current_batch_size * 0.9),
)
# 根据批次执行时间调整等待时间
if self.stats.last_batch_duration > 0:
if self.stats.last_batch_duration > self.current_wait_time * 2:
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
]
return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
def _get_from_cache(self, cache_key: str) -> Any | None:
"""从缓存获取结果"""
if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key]
@@ -551,27 +551,27 @@ class AdaptiveBatchScheduler:
# 全局调度器实例
_global_scheduler: Optional[AdaptiveBatchScheduler] = None
_global_scheduler: AdaptiveBatchScheduler | None = None
_scheduler_lock = asyncio.Lock()
async def get_batch_scheduler() -> AdaptiveBatchScheduler:
"""获取全局批量调度器(单例)"""
global _global_scheduler
if _global_scheduler is None:
async with _scheduler_lock:
if _global_scheduler is None:
_global_scheduler = AdaptiveBatchScheduler()
await _global_scheduler.start()
return _global_scheduler
async def close_batch_scheduler() -> None:
"""关闭全局批量调度器"""
global _global_scheduler
if _global_scheduler is not None:
await _global_scheduler.stop()
_global_scheduler = None

View File

@@ -11,8 +11,9 @@
import asyncio
import time
from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar
from typing import Any, Generic, TypeVar
from src.common.logger import get_logger
@@ -24,7 +25,7 @@ T = TypeVar("T")
@dataclass
class CacheEntry(Generic[T]):
"""缓存条目
Attributes:
value: 缓存的值
created_at: 创建时间戳
@@ -42,7 +43,7 @@ class CacheEntry(Generic[T]):
@dataclass
class CacheStats:
"""缓存统计信息
Attributes:
hits: 命中次数
misses: 未命中次数
@@ -70,7 +71,7 @@ class CacheStats:
class LRUCache(Generic[T]):
"""LRU缓存实现
使用OrderedDict实现O(1)的get/set操作
"""
@@ -81,7 +82,7 @@ class LRUCache(Generic[T]):
name: str = "cache",
):
"""初始化LRU缓存
Args:
max_size: 最大缓存条目数
ttl: 过期时间(秒)
@@ -94,18 +95,18 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock()
self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]:
async def get(self, key: str) -> T | None:
"""获取缓存值
Args:
key: 缓存键
Returns:
缓存值如果不存在或已过期返回None
"""
async with self._lock:
entry = self._cache.get(key)
if entry is None:
self._stats.misses += 1
return None
@@ -125,20 +126,20 @@ class LRUCache(Generic[T]):
entry.last_accessed = now
entry.access_count += 1
self._stats.hits += 1
# 移到末尾(最近使用)
self._cache.move_to_end(key)
return entry.value
async def set(
self,
key: str,
value: T,
size: Optional[int] = None,
size: int | None = None,
) -> None:
"""设置缓存值
Args:
key: 缓存键
value: 缓存值
@@ -146,16 +147,16 @@ class LRUCache(Generic[T]):
"""
async with self._lock:
now = time.time()
# 如果键已存在,更新值
if key in self._cache:
old_entry = self._cache[key]
self._stats.total_size -= old_entry.size
# 估算大小
if size is None:
size = self._estimate_size(value)
# 创建新条目
entry = CacheEntry(
value=value,
@@ -164,7 +165,7 @@ class LRUCache(Generic[T]):
access_count=0,
size=size,
)
# 如果缓存已满,淘汰最久未使用的条目
while len(self._cache) >= self.max_size:
oldest_key, oldest_entry = self._cache.popitem(last=False)
@@ -175,7 +176,7 @@ class LRUCache(Generic[T]):
f"[{self.name}] 淘汰缓存条目: {oldest_key} "
f"(访问{oldest_entry.access_count}次)"
)
# 添加新条目
self._cache[key] = entry
self._stats.item_count += 1
@@ -183,10 +184,10 @@ class LRUCache(Generic[T]):
async def delete(self, key: str) -> bool:
"""删除缓存条目
Args:
key: 缓存键
Returns:
是否成功删除
"""
@@ -217,7 +218,7 @@ class LRUCache(Generic[T]):
def _estimate_size(self, value: Any) -> int:
"""估算数据大小(字节)
这是一个简单的估算,实际大小可能不同
"""
import sys
@@ -230,11 +231,11 @@ class LRUCache(Generic[T]):
class MultiLevelCache:
"""多级缓存管理器
实现两级缓存架构:
- L1: 高速缓存小容量短TTL
- L2: 扩展缓存大容量长TTL
查询时先查L1未命中再查L2未命中再从数据源加载
"""
@@ -246,7 +247,7 @@ class MultiLevelCache:
l2_ttl: float = 300,
):
"""初始化多级缓存
Args:
l1_max_size: L1缓存最大条目数
l1_ttl: L1缓存TTL
@@ -255,8 +256,8 @@ class MultiLevelCache:
"""
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_task: asyncio.Task | None = None
logger.info(
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
f"L2({l2_max_size}项/{l2_ttl}s)"
@@ -265,16 +266,16 @@ class MultiLevelCache:
async def get(
self,
key: str,
loader: Optional[Callable[[], Any]] = None,
) -> Optional[Any]:
loader: Callable[[], Any] | None = None,
) -> Any | None:
"""从缓存获取数据
查询顺序L1 -> L2 -> loader
Args:
key: 缓存键
loader: 数据加载函数,当缓存未命中时调用
Returns:
缓存值或加载的值如果都不存在返回None
"""
@@ -307,12 +308,12 @@ class MultiLevelCache:
self,
key: str,
value: Any,
size: Optional[int] = None,
size: int | None = None,
) -> None:
"""设置缓存值
同时写入L1和L2
Args:
key: 缓存键
value: 缓存值
@@ -323,9 +324,9 @@ class MultiLevelCache:
async def delete(self, key: str) -> None:
"""删除缓存条目
同时从L1和L2删除
Args:
key: 缓存键
"""
@@ -347,7 +348,7 @@ class MultiLevelCache:
async def start_cleanup_task(self, interval: float = 60) -> None:
"""启动定期清理任务
Args:
interval: 清理间隔(秒)
"""
@@ -387,27 +388,27 @@ class MultiLevelCache:
# 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None
_global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例)"""
global _global_cache
if _global_cache is None:
async with _cache_lock:
if _global_cache is None:
_global_cache = MultiLevelCache()
await _global_cache.start_cleanup_task()
return _global_cache
async def close_cache() -> None:
"""关闭全局缓存"""
global _global_cache
if _global_cache is not None:
await _global_cache.stop_cleanup_task()
await _global_cache.clear()

View File

@@ -150,7 +150,7 @@ class ConnectionPoolManager:
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
yield connection_info.session
# 🔧 修复:正常退出时提交事务
# 这对SQLite至关重要因为SQLite没有autocommit
if connection_info and connection_info.session:
@@ -249,7 +249,7 @@ class ConnectionPoolManager:
"""获取连接池统计信息"""
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
return {
**self._stats,
"active_connections": len(self._connections),

View File

@@ -10,8 +10,9 @@
import asyncio
import time
from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
@dataclass
class AccessPattern:
"""访问模式统计
Attributes:
key: 数据键
access_count: 访问次数
@@ -42,7 +43,7 @@ class AccessPattern:
class DataPreloader:
"""数据预加载器
通过分析访问模式,预测并预加载可能需要的数据
"""
@@ -53,7 +54,7 @@ class DataPreloader:
max_patterns: int = 1000,
):
"""初始化预加载器
Args:
decay_factor: 时间衰减因子0-1越小衰减越快
preload_threshold: 预加载阈值score超过此值时预加载
@@ -62,7 +63,7 @@ class DataPreloader:
self.decay_factor = decay_factor
self.preload_threshold = preload_threshold
self.max_patterns = max_patterns
# 访问模式跟踪
self._patterns: dict[str, AccessPattern] = {}
# 关联关系key -> [related_keys]
@@ -73,9 +74,9 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
self._lock = asyncio.Lock()
logger.info(
f"数据预加载器初始化: 衰减因子={decay_factor}, "
f"预加载阈值={preload_threshold}"
@@ -84,10 +85,10 @@ class DataPreloader:
async def record_access(
self,
key: str,
related_keys: Optional[list[str]] = None,
related_keys: list[str] | None = None,
) -> None:
"""记录数据访问
Args:
key: 被访问的数据键
related_keys: 关联访问的数据键列表
@@ -95,7 +96,7 @@ class DataPreloader:
async with self._lock:
self._total_accesses += 1
now = time.time()
# 更新或创建访问模式
if key in self._patterns:
pattern = self._patterns[key]
@@ -108,15 +109,15 @@ class DataPreloader:
last_access=now,
)
self._patterns[key] = pattern
# 更新热度评分(时间衰减)
pattern.score = self._calculate_score(pattern)
# 记录关联关系
if related_keys:
self._associations[key].update(related_keys)
pattern.related_keys = list(self._associations[key])
# 如果模式过多,删除评分最低的
if len(self._patterns) > self.max_patterns:
min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
@@ -126,10 +127,10 @@ class DataPreloader:
async def should_preload(self, key: str) -> bool:
"""判断是否应该预加载某个数据
Args:
key: 数据键
Returns:
是否应该预加载
"""
@@ -137,18 +138,18 @@ class DataPreloader:
pattern = self._patterns.get(key)
if pattern is None:
return False
# 更新评分
pattern.score = self._calculate_score(pattern)
return pattern.score >= self.preload_threshold
async def get_preload_keys(self, limit: int = 100) -> list[str]:
"""获取应该预加载的数据键列表
Args:
limit: 最大返回数量
Returns:
按评分排序的数据键列表
"""
@@ -156,14 +157,14 @@ class DataPreloader:
# 更新所有评分
for pattern in self._patterns.values():
pattern.score = self._calculate_score(pattern)
# 按评分排序
sorted_patterns = sorted(
self._patterns.values(),
key=lambda p: p.score,
reverse=True,
)
# 返回超过阈值的键
return [
p.key for p in sorted_patterns[:limit]
@@ -172,10 +173,10 @@ class DataPreloader:
async def get_related_keys(self, key: str) -> list[str]:
"""获取关联数据键
Args:
key: 数据键
Returns:
关联数据键列表
"""
@@ -188,27 +189,27 @@ class DataPreloader:
loader: Callable[[], Awaitable[Any]],
) -> None:
"""预加载数据
Args:
key: 数据键
loader: 异步加载函数
"""
try:
cache = await get_cache()
# 检查缓存中是否已存在
if await cache.l1_cache.get(key) is not None:
return
# 加载数据
logger.debug(f"预加载数据: {key}")
data = await loader()
if data is not None:
# 写入缓存
await cache.set(key, data)
self._preload_count += 1
# 预加载关联数据
related_keys = await self.get_related_keys(key)
for related_key in related_keys[:5]: # 最多预加载5个关联项
@@ -216,7 +217,7 @@ class DataPreloader:
# 这里需要调用者提供关联数据的加载函数
# 暂时只记录,不实际加载
logger.debug(f"发现关联数据: {related_key}")
except Exception as e:
logger.error(f"预加载数据失败 {key}: {e}", exc_info=True)
@@ -226,13 +227,13 @@ class DataPreloader:
loaders: dict[str, Callable[[], Awaitable[Any]]],
) -> None:
"""批量启动预加载任务
Args:
session: 数据库会话
loaders: 数据键到加载函数的映射
"""
preload_keys = await self.get_preload_keys()
for key in preload_keys:
if key in loaders:
loader = loaders[key]
@@ -242,9 +243,9 @@ class DataPreloader:
async def record_hit(self, key: str) -> None:
"""记录预加载命中
当缓存命中的数据是预加载的,调用此方法统计
Args:
key: 数据键
"""
@@ -259,7 +260,7 @@ class DataPreloader:
if self._preload_count > 0
else 0.0
)
return {
"total_accesses": self._total_accesses,
"tracked_patterns": len(self._patterns),
@@ -278,7 +279,7 @@ class DataPreloader:
self._total_accesses = 0
self._preload_count = 0
self._preload_hits = 0
# 取消所有预加载任务
for task in self._preload_tasks:
task.cancel()
@@ -286,38 +287,38 @@ class DataPreloader:
def _calculate_score(self, pattern: AccessPattern) -> float:
"""计算热度评分
使用时间衰减的访问频率:
score = access_count * decay_factor^(time_since_last_access)
Args:
pattern: 访问模式
Returns:
热度评分
"""
now = time.time()
time_diff = now - pattern.last_access
# 时间衰减(以小时为单位)
hours_passed = time_diff / 3600
decay = self.decay_factor ** hours_passed
# 评分 = 访问次数 * 时间衰减
score = pattern.access_count * decay
return score
class CommonDataPreloader:
"""常见数据预加载器
针对特定的数据类型提供预加载策略
"""
def __init__(self, preloader: DataPreloader):
"""初始化
Args:
preloader: 基础预加载器
"""
@@ -330,16 +331,16 @@ class CommonDataPreloader:
platform: str,
) -> None:
"""预加载用户相关数据
包括:个人信息、权限、关系等
Args:
session: 数据库会话
user_id: 用户ID
platform: 平台
"""
from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
# 预加载个人信息
await self._preload_model(
session,
@@ -347,7 +348,7 @@ class CommonDataPreloader:
PersonInfo,
{"platform": platform, "user_id": user_id},
)
# 预加载用户权限
await self._preload_model(
session,
@@ -355,7 +356,7 @@ class CommonDataPreloader:
UserPermissions,
{"platform": platform, "user_id": user_id},
)
# 预加载用户关系
await self._preload_model(
session,
@@ -371,16 +372,16 @@ class CommonDataPreloader:
limit: int = 50,
) -> None:
"""预加载聊天上下文
包括:最近消息、聊天流信息等
Args:
session: 数据库会话
stream_id: 聊天流ID
limit: 消息数量限制
"""
from src.common.database.core.models import ChatStreams, Messages
from src.common.database.core.models import ChatStreams
# 预加载聊天流信息
await self._preload_model(
session,
@@ -388,7 +389,7 @@ class CommonDataPreloader:
ChatStreams,
{"stream_id": stream_id},
)
# 预加载最近消息(这个比较复杂,暂时跳过)
# TODO: 实现消息列表的预加载
@@ -400,7 +401,7 @@ class CommonDataPreloader:
filters: dict[str, Any],
) -> None:
"""预加载模型数据
Args:
session: 数据库会话
cache_key: 缓存键
@@ -413,31 +414,31 @@ class CommonDataPreloader:
stmt = stmt.where(getattr(model_class, key) == value)
result = await session.execute(stmt)
return result.scalar_one_or_none()
await self.preloader.preload_data(cache_key, loader)
# 全局预加载器实例
_global_preloader: Optional[DataPreloader] = None
_global_preloader: DataPreloader | None = None
_preloader_lock = asyncio.Lock()
async def get_preloader() -> DataPreloader:
"""获取全局预加载器实例(单例)"""
global _global_preloader
if _global_preloader is None:
async with _preloader_lock:
if _global_preloader is None:
_global_preloader = DataPreloader()
return _global_preloader
async def close_preloader() -> None:
"""关闭全局预加载器"""
global _global_preloader
if _global_preloader is not None:
await _global_preloader.clear()
_global_preloader = None

View File

@@ -37,29 +37,29 @@ from .monitoring import (
)
__all__ = [
"BatchSchedulerError",
"CacheError",
"ConnectionPoolError",
"DatabaseConnectionError",
# 异常
"DatabaseError",
"DatabaseInitializationError",
"DatabaseConnectionError",
"DatabaseMigrationError",
# 监控
"DatabaseMonitor",
"DatabaseQueryError",
"DatabaseTransactionError",
"DatabaseMigrationError",
"CacheError",
"BatchSchedulerError",
"ConnectionPoolError",
"cached",
"db_operation",
"get_monitor",
"measure_time",
"print_stats",
"record_cache_hit",
"record_cache_miss",
"record_operation",
"reset_stats",
# 装饰器
"retry",
"timeout",
"cached",
"measure_time",
"transactional",
"db_operation",
# 监控
"DatabaseMonitor",
"get_monitor",
"record_operation",
"record_cache_hit",
"record_cache_miss",
"print_stats",
"reset_stats",
]

View File

@@ -10,9 +10,11 @@ import asyncio
import functools
import hashlib
import time
from typing import Any, Awaitable, Callable, Optional, TypeVar
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
from src.common.logger import get_logger
@@ -25,33 +27,33 @@ def generate_cache_key(
**kwargs: Any,
) -> str:
"""生成与@cached装饰器相同的缓存键
用于手动缓存失效等操作
Args:
key_prefix: 缓存键前缀
*args: 位置参数
**kwargs: 关键字参数
Returns:
缓存键字符串
Example:
cache_key = generate_cache_key("person_info", platform, person_id)
await cache.delete(cache_key)
"""
cache_key_parts = [key_prefix]
if args:
args_str = ",".join(str(arg) for arg in args)
args_hash = hashlib.md5(args_str.encode()).hexdigest()[:8]
cache_key_parts.append(f"args:{args_hash}")
if kwargs:
kwargs_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
kwargs_hash = hashlib.md5(kwargs_str.encode()).hexdigest()[:8]
cache_key_parts.append(f"kwargs:{kwargs_hash}")
return ":".join(cache_key_parts)
T = TypeVar("T")
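A quick shape-check of the keys this produces (the md5 prefixes depend on the inputs, so only the structure is asserted; the deep=True kwarg is illustrative):

from src.common.database.utils.decorators import generate_cache_key

key = generate_cache_key("person_info", "qq", "12345", deep=True)
# -> "person_info:args:<8-char md5>:kwargs:<8-char md5>"
parts = key.split(":")
assert parts[0] == "person_info"
assert parts[1] == "args" and len(parts[2]) == 8
assert parts[3] == "kwargs" and len(parts[4]) == 8
print(key)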
@@ -65,15 +67,15 @@ def retry(
exceptions: tuple[type[Exception], ...] = (OperationalError, DBAPIError, SQLTimeoutError),
):
"""重试装饰器
自动重试失败的数据库操作,适用于临时性错误
Args:
max_attempts: 最大尝试次数
delay: 初始延迟时间(秒)
backoff: 延迟倍数(指数退避)
exceptions: 需要重试的异常类型
Example:
@retry(max_attempts=3, delay=1.0)
async def query_data():
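The decorator body sits outside this hunk; what follows is a minimal sketch of the exponential backoff the docstring describes, under the same parameters (an assumption, not the repo's exact implementation):

import asyncio
import functools


def retry_sketch(max_attempts=3, delay=1.0, backoff=2.0, exceptions=(Exception,)):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions:
                    if attempt == max_attempts:
                        raise  # attempts exhausted: surface the last error
                    await asyncio.sleep(wait)
                    wait *= backoff  # 1.0s, 2.0s, 4.0s, ...
        return wrapper
    return decorator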
@@ -114,12 +116,12 @@ def retry(
def timeout(seconds: float):
"""超时装饰器
为数据库操作添加超时控制
Args:
seconds: 超时时间(秒)
Example:
@timeout(30.0)
async def long_query():
@@ -141,21 +143,21 @@ def timeout(seconds: float):
def cached(
ttl: Optional[int] = 300,
key_prefix: Optional[str] = None,
ttl: int | None = 300,
key_prefix: str | None = None,
use_args: bool = True,
use_kwargs: bool = True,
):
"""缓存装饰器
自动缓存函数返回值
Args:
ttl: 缓存过期时间None表示永不过期
key_prefix: 缓存键前缀,默认使用函数名
use_args: 是否将位置参数包含在缓存键中
use_kwargs: 是否将关键字参数包含在缓存键中
Example:
@cached(ttl=60, key_prefix="user_data")
async def get_user_info(user_id: str) -> dict:
@@ -167,7 +169,7 @@ def cached(
async def wrapper(*args: Any, **kwargs: Any) -> T:
# 延迟导入避免循环依赖
from src.common.database.optimization import get_cache
# 生成缓存键
cache_key_parts = [key_prefix or func.__name__]
@@ -207,14 +209,14 @@ def cached(
return decorator
def measure_time(log_slow: Optional[float] = None):
def measure_time(log_slow: float | None = None):
"""性能测量装饰器
测量函数执行时间,可选择性记录慢查询
Args:
log_slow: 慢查询阈值超过此时间会记录warning日志
Example:
@measure_time(log_slow=1.0)
async def complex_query():
@@ -246,19 +248,19 @@ def measure_time(log_slow: Optional[float] = None):
def transactional(auto_commit: bool = True, auto_rollback: bool = True):
"""事务装饰器
自动管理事务的提交和回滚
Args:
auto_commit: 是否自动提交
auto_rollback: 发生异常时是否自动回滚
Example:
@transactional()
async def update_multiple_records(session):
await session.execute(stmt1)
await session.execute(stmt2)
Note:
函数需要接受session参数
"""
@@ -306,20 +308,20 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
# 组合装饰器示例
def db_operation(
retry_attempts: int = 3,
timeout_seconds: Optional[float] = None,
cache_ttl: Optional[int] = None,
timeout_seconds: float | None = None,
cache_ttl: int | None = None,
measure: bool = True,
):
"""组合装饰器
组合多个装饰器,提供完整的数据库操作保护
Args:
retry_attempts: 重试次数
timeout_seconds: 超时时间
cache_ttl: 缓存时间
measure: 是否测量性能
Example:
@db_operation(retry_attempts=3, timeout_seconds=30, cache_ttl=60)
async def important_query():

View File

@@ -4,7 +4,6 @@
"""
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Optional
@@ -22,7 +21,7 @@ class OperationMetrics:
min_time: float = float("inf")
max_time: float = 0.0
error_count: int = 0
last_execution_time: Optional[float] = None
last_execution_time: float | None = None
@property
def avg_time(self) -> float:
@@ -91,7 +90,7 @@ class DatabaseMetrics:
class DatabaseMonitor:
"""数据库监控器
单例模式,收集和报告数据库性能指标
"""
@@ -285,7 +284,7 @@ class DatabaseMonitor:
# 全局监控器实例
_monitor: Optional[DatabaseMonitor] = None
_monitor: DatabaseMonitor | None = None
def get_monitor() -> DatabaseMonitor:

View File

@@ -4,8 +4,8 @@ from datetime import datetime
from PIL import Image
from src.common.database.core.models import LLMUsage
from src.common.database.core import get_db_session
from src.common.database.core.models import LLMUsage
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo

View File

@@ -224,7 +224,7 @@ class MainSystem:
storage_batcher = get_message_storage_batcher()
cleanup_tasks.append(("消息存储批处理器", storage_batcher.stop()))
update_batcher = get_message_update_batcher()
cleanup_tasks.append(("消息更新批处理器", update_batcher.stop()))
except Exception as e:
@@ -502,7 +502,7 @@ MoFox_Bot(第三方修改版)
storage_batcher = get_message_storage_batcher()
await storage_batcher.start()
logger.info("消息存储批处理器已启动")
update_batcher = get_message_update_batcher()
await update_batcher.start()
logger.info("消息更新批处理器已启动")

View File

@@ -256,9 +256,10 @@ class RelationshipFetcher:
str: 格式化后的聊天流印象字符串
"""
try:
from src.common.database.api.specialized import get_or_create_chat_stream
import time
from src.common.database.api.specialized import get_or_create_chat_stream
# 使用优化后的API带缓存
# 从stream_id解析platform或使用默认值
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
@@ -289,7 +290,7 @@ class RelationshipFetcher:
except Exception as e:
logger.warning(f"访问stream对象属性失败: {e}")
stream_data = {}
impression_parts = []
# 1. 聊天环境基本信息

View File

@@ -52,8 +52,8 @@ from typing import Any
import orjson
from sqlalchemy import func, select
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.logger import get_logger
from src.schedule.database import get_active_plans_for_month

View File

@@ -7,7 +7,7 @@ from src.plugin_system.base.component_types import ChatType, CommandInfo, Compon
from src.plugin_system.base.plus_command import PlusCommand
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
pass
logger = get_logger("base_command")

View File

@@ -10,8 +10,8 @@ from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker
from src.common.database.core.models import PermissionNodes, UserPermissions
from src.common.database.core import get_engine
from src.common.database.core.models import PermissionNodes, UserPermissions
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo

View File

@@ -5,8 +5,8 @@
import time
from src.common.database.core.models import UserRelationships
from src.common.database.core import get_db_session
from src.common.database.core.models import UserRelationships
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -7,12 +7,8 @@
import json
from typing import Any, ClassVar
from sqlalchemy import select
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached
from src.common.database.core.models import ChatStreams
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
@@ -358,14 +354,14 @@ class ChatStreamImpressionTool(BaseTool):
"stream_interest_score": impression.get("stream_interest_score", 0.5),
}
)
# 使缓存失效
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
cache = await get_cache()
await cache.delete(generate_cache_key("stream_impression", stream_id))
await cache.delete(generate_cache_key("chat_stream", stream_id))
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
else:
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"

View File

@@ -64,7 +64,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
stream_config = chat_stream.get_raw_id()
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):

View File

@@ -7,13 +7,10 @@ import json
from datetime import datetime
from typing import Any, Literal
from sqlalchemy import select
from src.chat.express.expression_selector import expression_selector
from src.chat.utils.prompt import Prompt
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.database.api.crud import CRUDBase
from src.common.database.core.models import ChatStreams
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -208,7 +205,7 @@ class ProactiveThinkingPlanner:
# 3. 获取bot人设和时间信息
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
# 构建时间信息块
time_block = f"当前时间是 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
@@ -554,11 +551,11 @@ async def execute_proactive_thinking(stream_id: str):
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream:
# 使用 ChatStream 的 get_raw_id() 方法获取配置字符串
stream_config = chat_stream.get_raw_id()
# 执行白名单/黑名单检查
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
logger.debug(f"聊天流 {stream_id} ({stream_config}) 未通过白名单/黑名单检查,跳过主动思考")
@@ -567,7 +564,7 @@ async def execute_proactive_thinking(stream_id: str):
logger.warning(f"无法获取聊天流 {stream_id} 的信息,跳过白名单检查")
except Exception as e:
logger.warning(f"白名单检查时出错: {e},继续执行")
# 0.2 检查安静时段
if proactive_thinking_scheduler._is_in_quiet_hours():
logger.debug("安静时段,跳过")

View File

@@ -3,8 +3,8 @@
from sqlalchemy import delete, func, select, update
from src.common.database.core.models import MonthlyPlan
from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan
from src.common.logger import get_logger
from src.config.config import global_config
@@ -312,7 +312,7 @@ async def delete_plans_older_than(month: str):
logger.info(f"没有找到比 {month} 更早的月度计划需要删除。")
return 0
plan_months = sorted(list(set(p.target_month for p in plans_to_delete)))
plan_months = sorted({p.target_month for p in plans_to_delete})
logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。")
# 然后,执行删除操作

View File

@@ -100,7 +100,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
next_month = datetime(now.year + 1, 1, 1)
else:
next_month = datetime(now.year, now.month + 1, 1)
sleep_seconds = (next_month - now).total_seconds()
logger.info(
f" 下一次月度计划生成任务将在 {sleep_seconds:.2f} 秒后运行 (北京时间 {next_month.strftime('%Y-%m-%d %H:%M:%S')})"
@@ -110,7 +110,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
# 到达月初,先归档上个月的计划
last_month = (next_month - timedelta(days=1)).strftime("%Y-%m")
await self.monthly_plan_manager.plan_manager.archive_current_month_plans(last_month)
# 为当前月生成新计划
current_month = next_month.strftime("%Y-%m")
logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...")

View File

@@ -97,4 +97,4 @@ MONTHLY_PLAN_GENERATION_PROMPT = Prompt(
请你扮演我,以我的身份和兴趣,为 {target_month} 制定合适的月度计划。
""",
)

View File

@@ -5,8 +5,8 @@ from typing import Any
import orjson
from sqlalchemy import select
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.logger import get_logger
from src.config.config import global_config
from src.manager.async_task_manager import AsyncTask, async_task_manager