diff --git a/bot.py b/bot.py index 5fbd894cd..827d1e61e 100644 --- a/bot.py +++ b/bot.py @@ -282,7 +282,7 @@ class DatabaseManager: async def __aenter__(self): """异步上下文管理器入口""" try: - from src.common.database.database import initialize_sql_database + from src.common.database.core import check_and_migrate_database as initialize_sql_database from src.config.config import global_config logger.info("正在初始化数据库连接...") @@ -560,7 +560,7 @@ class MaiBotMain: logger.info("正在初始化数据库表结构...") try: start_time = time.time() - from src.common.database.sqlalchemy_models import initialize_database + from src.common.database.core.models import initialize_database await initialize_database() elapsed_time = time.time() - start_time diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py index c3ed2785e..d1e8a47b6 100644 --- a/scripts/check_expression_database.py +++ b/scripts/check_expression_database.py @@ -11,8 +11,8 @@ sys.path.insert(0, str(project_root)) from sqlalchemy import func, select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Expression +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Expression async def check_database(): diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py index eb4cec41e..980f3a07a 100644 --- a/scripts/check_style_field.py +++ b/scripts/check_style_field.py @@ -10,8 +10,8 @@ sys.path.insert(0, str(project_root)) from sqlalchemy import select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Expression +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Expression async def analyze_style_fields(): diff --git a/scripts/update_database_imports.py b/scripts/update_database_imports.py new file mode 100644 index 000000000..2e8df9bf5 --- /dev/null +++ b/scripts/update_database_imports.py @@ -0,0 +1,186 @@ +"""批量更新数据库导入语句的脚本 + +将旧的数据库导入路径更新为新的重构后的路径: +- sqlalchemy_models -> core, core.models +- sqlalchemy_database_api -> compatibility +- database.database -> core +""" + +import re +from pathlib import Path +from typing import Dict, List, Tuple + +# 定义导入映射规则 +IMPORT_MAPPINGS = { + # 模型导入 + r'from src\.common\.database\.sqlalchemy_models import (.+)': + r'from src.common.database.core.models import \1', + + # API导入 - 需要特殊处理 + r'from src\.common\.database\.sqlalchemy_database_api import (.+)': + r'from src.common.database.compatibility import \1', + + # get_db_session 从 sqlalchemy_database_api 导入 + r'from src\.common\.database\.sqlalchemy_database_api import get_db_session': + r'from src.common.database.core import get_db_session', + + # get_db_session 从 sqlalchemy_models 导入 + r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)': + lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}' + if 'get_db_session' in m.group(0) else m.group(0), + + # get_engine 导入 + r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)': + lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}', + + # Base 导入 + r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)': + lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}', + + # initialize_database 导入 + r'from src\.common\.database\.sqlalchemy_models import initialize_database': + r'from src.common.database.core import check_and_migrate_database as initialize_database', + + # database.py 导入 + r'from src\.common\.database\.database import stop_database': + r'from src.common.database.core import close_engine as stop_database', + + r'from src\.common\.database\.database import initialize_sql_database': + r'from src.common.database.core import check_and_migrate_database as initialize_sql_database', +} + +# 需要排除的文件 +EXCLUDE_PATTERNS = [ + '**/database_refactoring_plan.md', # 文档文件 + '**/old/**', # 旧文件目录 + '**/sqlalchemy_*.py', # 旧的数据库文件本身 + '**/database.py', # 旧的database文件 + '**/db_*.py', # 旧的db文件 +] + + +def should_exclude(file_path: Path) -> bool: + """检查文件是否应该被排除""" + for pattern in EXCLUDE_PATTERNS: + if file_path.match(pattern): + return True + return False + + +def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]: + """更新单个文件中的导入语句 + + Args: + file_path: 文件路径 + dry_run: 是否只是预览而不实际修改 + + Returns: + (修改次数, 修改详情列表) + """ + try: + content = file_path.read_text(encoding='utf-8') + original_content = content + changes = [] + + # 应用每个映射规则 + for pattern, replacement in IMPORT_MAPPINGS.items(): + matches = list(re.finditer(pattern, content)) + for match in matches: + old_line = match.group(0) + + # 处理函数类型的替换 + if callable(replacement): + new_line_result = replacement(match) + new_line = new_line_result if isinstance(new_line_result, str) else old_line + else: + new_line = re.sub(pattern, replacement, old_line) + + if old_line != new_line and isinstance(new_line, str): + content = content.replace(old_line, new_line, 1) + changes.append(f" - {old_line}") + changes.append(f" + {new_line}") + + # 如果有修改且不是dry_run,写回文件 + if content != original_content: + if not dry_run: + file_path.write_text(content, encoding='utf-8') + return len(changes) // 2, changes + + return 0, [] + + except Exception as e: + print(f"❌ 处理文件 {file_path} 时出错: {e}") + return 0, [] + + +def main(): + """主函数""" + print("🔍 搜索需要更新导入的文件...") + + # 获取项目根目录 + root_dir = Path(__file__).parent.parent + + # 搜索所有Python文件 + all_python_files = list(root_dir.rglob("*.py")) + + # 过滤掉排除的文件 + target_files = [f for f in all_python_files if not should_exclude(f)] + + print(f"📊 找到 {len(target_files)} 个Python文件需要检查") + print("\n" + "="*80) + + # 第一遍:预览模式 + print("\n🔍 预览模式 - 检查需要更新的文件...\n") + + files_to_update = [] + for file_path in target_files: + count, changes = update_imports_in_file(file_path, dry_run=True) + if count > 0: + files_to_update.append((file_path, count, changes)) + + if not files_to_update: + print("✅ 没有文件需要更新!") + return + + print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n") + + total_changes = 0 + for file_path, count, changes in files_to_update: + rel_path = file_path.relative_to(root_dir) + print(f"\n📄 {rel_path} ({count} 处修改)") + for change in changes[:10]: # 最多显示前5对修改 + print(change) + if len(changes) > 10: + print(f" ... 还有 {len(changes) - 10} 行") + total_changes += count + + print("\n" + "="*80) + print(f"\n📊 统计:") + print(f" - 需要更新的文件: {len(files_to_update)}") + print(f" - 总修改次数: {total_changes}") + + # 询问是否继续 + print("\n" + "="*80) + response = input("\n是否执行更新?(yes/no): ").strip().lower() + + if response != 'yes': + print("❌ 已取消更新") + return + + # 第二遍:实际更新 + print("\n✨ 开始更新文件...\n") + + success_count = 0 + for file_path, _, _ in files_to_update: + count, _ = update_imports_in_file(file_path, dry_run=False) + if count > 0: + rel_path = file_path.relative_to(root_dir) + print(f"✅ {rel_path} ({count} 处修改)") + success_count += 1 + + print("\n" + "="*80) + print(f"\n🎉 完成!成功更新 {success_count} 个文件") + + +if __name__ == "__main__": + main() diff --git a/src/api/statistic_router.py b/src/api/statistic_router.py index feda3e911..c65ca1f90 100644 --- a/src/api/statistic_router.py +++ b/src/api/statistic_router.py @@ -4,8 +4,8 @@ from typing import Any, Literal from fastapi import APIRouter, HTTPException, Query -from src.common.database.sqlalchemy_database_api import db_get -from src.common.database.sqlalchemy_models import LLMUsage +from src.common.database.compatibility import db_get +from src.common.database.core.models import LLMUsage from src.common.logger import get_logger from src.config.config import model_config diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 0c946e805..146d6d23b 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -263,7 +263,7 @@ class AntiPromptInjector: try: from sqlalchemy import delete - from src.common.database.sqlalchemy_models import Messages, get_db_session + from src.common.database.core.models import Messages, get_db_session message_id = message_data.get("message_id") if not message_id: @@ -290,7 +290,7 @@ class AntiPromptInjector: try: from sqlalchemy import update - from src.common.database.sqlalchemy_models import Messages, get_db_session + from src.common.database.core.models import Messages, get_db_session message_id = message_data.get("message_id") if not message_id: diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 6871ebecf..50ba52052 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -9,7 +9,7 @@ from typing import Any, TypeVar, cast from sqlalchemy import delete, select -from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session +from src.common.database.core.models import AntiInjectionStats, get_db_session from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index 34bf185c6..4f0711e66 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -8,7 +8,7 @@ import datetime from sqlalchemy import select -from src.common.database.sqlalchemy_models import BanUser, get_db_session +from src.common.database.core.models import BanUser, get_db_session from src.common.logger import get_logger from ..types import DetectionResult diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 22ec31538..df7a50df1 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -15,8 +15,8 @@ from rich.traceback import install from sqlalchemy import select from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Emoji, Images +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Emoji, Images from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/chat/energy_system/energy_manager.py b/src/chat/energy_system/energy_manager.py index 079147812..671575769 100644 --- a/src/chat/energy_system/energy_manager.py +++ b/src/chat/energy_system/energy_manager.py @@ -203,8 +203,8 @@ class RelationshipEnergyCalculator(EnergyCalculator): try: from sqlalchemy import select - from src.common.database.sqlalchemy_database_api import get_db_session - from src.common.database.sqlalchemy_models import ChatStreams + from src.common.database.compatibility import get_db_session + from src.common.database.core.models import ChatStreams async with get_db_session() as session: stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id) diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index da587a181..da0b2e7c6 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -10,8 +10,8 @@ from sqlalchemy import select from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Expression +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Expression from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 2c9dc63f6..7ae894dbf 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -9,8 +9,8 @@ from json_repair import repair_json from sqlalchemy import select from src.chat.utils.prompt import Prompt, global_prompt_manager -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Expression +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Expression from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index a37f777b5..958a0305b 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -649,8 +649,8 @@ class BotInterestManager: # 导入SQLAlchemy相关模块 import orjson - from src.common.database.sqlalchemy_database_api import get_db_session - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + from src.common.database.compatibility import get_db_session + from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests async with get_db_session() as session: # 查询最新的兴趣标签配置 @@ -731,8 +731,8 @@ class BotInterestManager: # 导入SQLAlchemy相关模块 import orjson - from src.common.database.sqlalchemy_database_api import get_db_session - from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests + from src.common.database.compatibility import get_db_session + from src.common.database.core.models import BotPersonalityInterests as DBBotPersonalityInterests # 将兴趣标签转换为JSON格式 tags_data = [] diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index 4bbe93e9c..adea3a607 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -9,8 +9,8 @@ from collections import defaultdict from dataclasses import dataclass, field from typing import Any -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import ChatStreams from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 4f6fbb3d7..789cdc3c5 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -9,8 +9,8 @@ from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from src.common.data_models.database_data_model import DatabaseMessages -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import ChatStreams # 新增导入 from src.common.logger import get_logger from src.config.config import global_config # 新增导入 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 1969aba3f..02be78320 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -8,8 +8,8 @@ import orjson from sqlalchemy import desc, select, update from src.common.data_models.database_data_model import DatabaseMessages -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import Images, Messages +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import Images, Messages from src.common.logger import get_logger from .chat_stream import ChatStream @@ -367,7 +367,7 @@ class MessageStorage: logger.debug(f"消息ID更新已加入批处理队列: {mmc_message_id} -> {qq_message_id}") else: # 直接更新(保留原有逻辑用于特殊情况) - from src.common.database.sqlalchemy_models import get_db_session + from src.common.database.core.models import get_db_session async with get_db_session() as session: matched_message = ( @@ -510,7 +510,7 @@ class MessageStorage: async with get_db_session() as session: from sqlalchemy import select, update - from src.common.database.sqlalchemy_models import Messages + from src.common.database.core.models import Messages # 查找需要修复的记录:interest_value为0、null或很小的值 query = ( diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 4cbf4ee11..fb95e4fd1 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -8,8 +8,8 @@ from rich.traceback import install from sqlalchemy import and_, select from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ActionRecords, Images +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import ActionRecords, Images from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages from src.config.config import global_config @@ -990,7 +990,7 @@ async def build_readable_messages( # 从第一条消息中获取chat_id chat_id = copy_messages[0].get("chat_id") if copy_messages else None - from src.common.database.sqlalchemy_database_api import get_db_session + from src.common.database.compatibility import get_db_session async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 8e451113f..985b58026 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -3,8 +3,8 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any -from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save -from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime +from src.common.database.compatibility import db_get, db_query, db_save +from src.common.database.core.models import LLMUsage, Messages, OnlineTime from src.common.logger import get_logger from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 227a45c18..19d8cc1bb 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -12,7 +12,7 @@ from PIL import Image from rich.traceback import install from sqlalchemy import and_, select -from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session +from src.common.database.core.models import ImageDescriptions, Images, get_db_session from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 5d99d9ca8..ca402d2cf 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -25,7 +25,7 @@ from typing import Any from PIL import Image -from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore +from src.common.database.core.models import Videos, get_db_session # type: ignore from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index e8f3b7715..d28ad6f1b 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -8,8 +8,8 @@ import numpy as np import orjson from src.common.config_helpers import resolve_embedding_dimension -from src.common.database.sqlalchemy_database_api import db_query, db_save -from src.common.database.sqlalchemy_models import CacheEntries +from src.common.database.compatibility import db_query, db_save +from src.common.database.core.models import CacheEntries from src.common.logger import get_logger from src.common.vector_db import vector_db_service from src.config.config import global_config, model_config diff --git a/src/common/message_repository.py b/src/common/message_repository.py index b97c000d5..94ff4bac9 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -5,10 +5,10 @@ from typing import Any from sqlalchemy import func, not_, select from sqlalchemy.orm import DeclarativeBase -from src.common.database.sqlalchemy_database_api import get_db_session +from src.common.database.compatibility import get_db_session # from src.common.database.database_model import Messages -from src.common.database.sqlalchemy_models import Messages +from src.common.database.core.models import Messages from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 9855b2446..ad6ff0396 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -4,7 +4,7 @@ from datetime import datetime from PIL import Image -from src.common.database.sqlalchemy_models import LLMUsage, get_db_session +from src.common.database.core.models import LLMUsage, get_db_session from src.common.logger import get_logger from src.config.api_ada_configs import ModelInfo diff --git a/src/main.py b/src/main.py index c11180e43..d5b09edfb 100644 --- a/src/main.py +++ b/src/main.py @@ -220,7 +220,7 @@ class MainSystem: # 停止数据库服务 try: - from src.common.database.database import stop_database + from src.common.database.core import close_engine as stop_database cleanup_tasks.append(("数据库服务", stop_database())) except Exception as e: diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 4c4c3a133..36b432769 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -9,8 +9,8 @@ import orjson from json_repair import repair_json from sqlalchemy import select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import PersonInfo +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import PersonInfo from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index add5039fe..840044c89 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -181,8 +181,8 @@ class RelationshipFetcher: # 5. 从UserRelationships表获取完整关系信息(新系统) try: - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import UserRelationships + from src.common.database.compatibility import db_query + from src.common.database.core.models import UserRelationships # 查询用户关系数据(修复:添加 await) user_id = str(await person_info_manager.get_value(person_id, "user_id")) @@ -243,8 +243,8 @@ class RelationshipFetcher: str: 格式化后的聊天流印象字符串 """ try: - from src.common.database.sqlalchemy_database_api import db_query - from src.common.database.sqlalchemy_models import ChatStreams + from src.common.database.compatibility import db_query + from src.common.database.core.models import ChatStreams # 查询聊天流数据 streams = await db_query( diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index aa6714655..4dc377a81 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -9,7 +9,7 @@ 注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理 """ -from src.common.database.sqlalchemy_database_api import MODEL_MAPPING, db_get, db_query, db_save, store_action_info +from src.common.database.compatibility import MODEL_MAPPING, db_get, db_query, db_save, store_action_info # 保持向后兼容性 __all__ = ["MODEL_MAPPING", "db_get", "db_query", "db_save", "store_action_info"] diff --git a/src/plugin_system/apis/schedule_api.py b/src/plugin_system/apis/schedule_api.py index 2b456456c..8eae53dcb 100644 --- a/src/plugin_system/apis/schedule_api.py +++ b/src/plugin_system/apis/schedule_api.py @@ -52,7 +52,7 @@ from typing import Any import orjson from sqlalchemy import func, select -from src.common.database.sqlalchemy_models import MonthlyPlan, Schedule, get_db_session +from src.common.database.core.models import MonthlyPlan, Schedule, get_db_session from src.common.logger import get_logger from src.schedule.database import get_active_plans_for_month diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py index 038c7407c..c7bc40010 100644 --- a/src/plugin_system/core/permission_manager.py +++ b/src/plugin_system/core/permission_manager.py @@ -10,7 +10,7 @@ from sqlalchemy import delete, select from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.ext.asyncio import async_sessionmaker -from src.common.database.sqlalchemy_models import PermissionNodes, UserPermissions, get_engine +from src.common.database.core.models import PermissionNodes, UserPermissions, get_engine from src.common.logger import get_logger from src.config.config import global_config from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo diff --git a/src/plugin_system/services/relationship_service.py b/src/plugin_system/services/relationship_service.py index e88e04ac2..11b0d8605 100644 --- a/src/plugin_system/services/relationship_service.py +++ b/src/plugin_system/services/relationship_service.py @@ -5,7 +5,7 @@ import time -from src.common.database.sqlalchemy_models import UserRelationships, get_db_session +from src.common.database.core.models import UserRelationships, get_db_session from src.common.logger import get_logger from src.config.config import global_config diff --git a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py index 3074e8b76..d6a66913d 100644 --- a/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/chat_stream_impression_tool.py @@ -9,8 +9,8 @@ from typing import Any, ClassVar from sqlalchemy import select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import ChatStreams from src.common.logger import get_logger from src.config.config import model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py index e172c4600..6a26a8bbe 100644 --- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py +++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_executor.py @@ -11,8 +11,8 @@ from sqlalchemy import select from src.chat.express.expression_selector import expression_selector from src.chat.utils.prompt import Prompt -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import ChatStreams from src.common.logger import get_logger from src.config.config import global_config, model_config from src.individuality.individuality import Individuality diff --git a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py index aa9286251..6c659141d 100644 --- a/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/user_profile_tool.py @@ -10,8 +10,8 @@ from typing import Any, ClassVar import orjson from sqlalchemy import select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import UserRelationships +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import UserRelationships from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py index 7cf0e7c93..c4059f33d 100644 --- a/src/plugins/built_in/maizone_refactored/services/scheduler_service.py +++ b/src/plugins/built_in/maizone_refactored/services/scheduler_service.py @@ -11,8 +11,8 @@ from collections.abc import Callable from sqlalchemy import select -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus +from src.common.database.compatibility import get_db_session +from src.common.database.core.models import MaiZoneScheduleStatus from src.common.logger import get_logger from src.schedule.schedule_manager import schedule_manager diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index 3ff20c2b2..ccc1731b5 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -9,7 +9,7 @@ from json_repair import repair_json from lunar_python import Lunar from src.chat.utils.prompt import global_prompt_manager -from src.common.database.sqlalchemy_models import MonthlyPlan +from src.common.database.core.models import MonthlyPlan from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index 477ce421d..d578619e8 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -5,7 +5,7 @@ from typing import Any import orjson from sqlalchemy import select -from src.common.database.sqlalchemy_models import MonthlyPlan, Schedule, get_db_session +from src.common.database.core.models import MonthlyPlan, Schedule, get_db_session from src.common.logger import get_logger from src.config.config import global_config from src.manager.async_task_manager import AsyncTask, async_task_manager