rufffffff

This commit is contained in:
明天好像没什么
2025-11-01 21:10:01 +08:00
committed by Windpicker-owo
parent 05daf869d1
commit ff6dc542e1
50 changed files with 742 additions and 759 deletions

View File

@@ -16,7 +16,7 @@ models_file = os.path.join(
print(f"正在清理文件: {models_file}") print(f"正在清理文件: {models_file}")
# 读取文件 # 读取文件
with open(models_file, "r", encoding="utf-8") as f: with open(models_file, encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
# 找到最后一个模型类的结束位置MonthlyPlan的 __table_args__ 结束) # 找到最后一个模型类的结束位置MonthlyPlan的 __table_args__ 结束)
@@ -43,7 +43,7 @@ if not found_end:
with open(models_file, "w", encoding="utf-8") as f: with open(models_file, "w", encoding="utf-8") as f:
f.writelines(keep_lines) f.writelines(keep_lines)
print(f"✅ 文件清理完成") print("✅ 文件清理完成")
print(f"保留行数: {len(keep_lines)}") print(f"保留行数: {len(keep_lines)}")
print(f"原始行数: {len(lines)}") print(f"原始行数: {len(lines)}")
print(f"删除行数: {len(lines) - len(keep_lines)}") print(f"删除行数: {len(lines) - len(keep_lines)}")

View File

@@ -4,20 +4,20 @@
import re import re
# 读取原始文件 # 读取原始文件
with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f: with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
content = f.read() content = f.read()
# 找到get_string_field函数的开始和结束 # 找到get_string_field函数的开始和结束
get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数') get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find('\n\nclass ChatStreams(Base):') get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end] get_string_field = content[get_string_field_start:get_string_field_end]
# 找到第一个class定义开始 # 找到第一个class定义开始
first_class_pos = content.find('class ChatStreams(Base):') first_class_pos = content.find("class ChatStreams(Base):")
# 找到所有class定义直到遇到非class的def # 找到所有class定义直到遇到非class的def
# 简单策略:找到所有以"class "开头且继承Base的类 # 简单策略:找到所有以"class "开头且继承Base的类
classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)' classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL)) matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
if matches: if matches:
@@ -53,14 +53,14 @@ Base = declarative_base()
''' '''
new_content = header + get_string_field + '\n\n' + models_content new_content = header + get_string_field + "\n\n" + models_content
# 写入新文件 # 写入新文件
with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f: with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
f.write(new_content) f.write(new_content)
print('✅ Models file rewritten successfully') print("✅ Models file rewritten successfully")
print(f'File size: {len(new_content)} characters') print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):" pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE)) model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f'Number of model classes: {model_count}') print(f"Number of model classes: {model_count}")

View File

@@ -8,54 +8,53 @@
import re import re
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple
# 定义导入映射规则 # 定义导入映射规则
IMPORT_MAPPINGS = { IMPORT_MAPPINGS = {
# 模型导入 # 模型导入
r'from src\.common\.database\.sqlalchemy_models import (.+)': r"from src\.common\.database\.sqlalchemy_models import (.+)":
r'from src.common.database.core.models import \1', r"from src.common.database.core.models import \1",
# API导入 - 需要特殊处理 # API导入 - 需要特殊处理
r'from src\.common\.database\.sqlalchemy_database_api import (.+)': r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
r'from src.common.database.compatibility import \1', r"from src.common.database.compatibility import \1",
# get_db_session 从 sqlalchemy_database_api 导入 # get_db_session 从 sqlalchemy_database_api 导入
r'from src\.common\.database\.sqlalchemy_database_api import get_db_session': r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
r'from src.common.database.core import get_db_session', r"from src.common.database.core import get_db_session",
# get_db_session 从 sqlalchemy_models 导入 # get_db_session 从 sqlalchemy_models 导入
r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)': r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}' lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
if 'get_db_session' in m.group(0) else m.group(0), if "get_db_session" in m.group(0) else m.group(0),
# get_engine 导入 # get_engine 导入
r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)': r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}', lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
# Base 导入 # Base 导入
r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)': r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}', lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
# initialize_database 导入 # initialize_database 导入
r'from src\.common\.database\.sqlalchemy_models import initialize_database': r"from src\.common\.database\.sqlalchemy_models import initialize_database":
r'from src.common.database.core import check_and_migrate_database as initialize_database', r"from src.common.database.core import check_and_migrate_database as initialize_database",
# database.py 导入 # database.py 导入
r'from src\.common\.database\.database import stop_database': r"from src\.common\.database\.database import stop_database":
r'from src.common.database.core import close_engine as stop_database', r"from src.common.database.core import close_engine as stop_database",
r'from src\.common\.database\.database import initialize_sql_database': r"from src\.common\.database\.database import initialize_sql_database":
r'from src.common.database.core import check_and_migrate_database as initialize_sql_database', r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
} }
# 需要排除的文件 # 需要排除的文件
EXCLUDE_PATTERNS = [ EXCLUDE_PATTERNS = [
'**/database_refactoring_plan.md', # 文档文件 "**/database_refactoring_plan.md", # 文档文件
'**/old/**', # 旧文件目录 "**/old/**", # 旧文件目录
'**/sqlalchemy_*.py', # 旧的数据库文件本身 "**/sqlalchemy_*.py", # 旧的数据库文件本身
'**/database.py', # 旧的database文件 "**/database.py", # 旧的database文件
'**/db_*.py', # 旧的db文件 "**/db_*.py", # 旧的db文件
] ]
@@ -67,7 +66,7 @@ def should_exclude(file_path: Path) -> bool:
return False return False
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]: def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
"""更新单个文件中的导入语句 """更新单个文件中的导入语句
Args: Args:
@@ -78,7 +77,7 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
(修改次数, 修改详情列表) (修改次数, 修改详情列表)
""" """
try: try:
content = file_path.read_text(encoding='utf-8') content = file_path.read_text(encoding="utf-8")
original_content = content original_content = content
changes = [] changes = []
@@ -103,7 +102,7 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
# 如果有修改且不是dry_run写回文件 # 如果有修改且不是dry_run写回文件
if content != original_content: if content != original_content:
if not dry_run: if not dry_run:
file_path.write_text(content, encoding='utf-8') file_path.write_text(content, encoding="utf-8")
return len(changes) // 2, changes return len(changes) // 2, changes
return 0, [] return 0, []
@@ -155,7 +154,7 @@ def main():
total_changes += count total_changes += count
print("\n" + "="*80) print("\n" + "="*80)
print(f"\n📊 统计:") print("\n📊 统计:")
print(f" - 需要更新的文件: {len(files_to_update)}") print(f" - 需要更新的文件: {len(files_to_update)}")
print(f" - 总修改次数: {total_changes}") print(f" - 总修改次数: {total_changes}")
@@ -163,7 +162,7 @@ def main():
print("\n" + "="*80) print("\n" + "="*80)
response = input("\n是否执行更新?(yes/no): ").strip().lower() response = input("\n是否执行更新?(yes/no): ").strip().lower()
if response != 'yes': if response != "yes":
print("❌ 已取消更新") print("❌ 已取消更新")
return return

View File

@@ -263,8 +263,8 @@ class AntiPromptInjector:
try: try:
from sqlalchemy import delete from sqlalchemy import delete
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id") message_id = message_data.get("message_id")
if not message_id: if not message_id:
@@ -291,8 +291,8 @@ class AntiPromptInjector:
try: try:
from sqlalchemy import update from sqlalchemy import update
from src.common.database.core.models import Messages
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import Messages
message_id = message_data.get("message_id") message_id = message_data.get("message_id")
if not message_id: if not message_id:

View File

@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
from sqlalchemy import delete, select from sqlalchemy import delete, select
from src.common.database.core.models import AntiInjectionStats
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import AntiInjectionStats
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config

View File

@@ -8,8 +8,8 @@ import datetime
from sqlalchemy import select from sqlalchemy import select
from src.common.database.core.models import BanUser
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import BanUser
from src.common.logger import get_logger from src.common.logger import get_logger
from ..types import DetectionResult from ..types import DetectionResult

View File

@@ -15,9 +15,9 @@ from rich.traceback import install
from sqlalchemy import select from sqlalchemy import select
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Emoji, Images from src.common.database.core.models import Emoji, Images
from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached from src.common.database.utils.decorators import cached
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config

View File

@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, TypedDict from typing import Any, TypedDict
from src.common.logger import get_logger
from src.common.database.api.crud import CRUDBase from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
logger = get_logger("energy_system") logger = get_logger("energy_system")
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
# 从数据库获取聊天流兴趣分数 # 从数据库获取聊天流兴趣分数
try: try:
from sqlalchemy import select
from src.common.database.core.models import ChatStreams from src.common.database.core.models import ChatStreams

View File

@@ -383,7 +383,7 @@ class ExpressionLearner:
current_time = time.time() current_time = time.time()
# 存储到数据库 Expression 表 # 存储到数据库 Expression 表
crud = CRUDBase(Expression) CRUDBase(Expression)
for chat_id, expr_list in chat_dict.items(): for chat_id, expr_list in chat_dict.items():
async with get_db_session() as session: async with get_db_session() as session:
for new_expr in expr_list: for new_expr in expr_list:

View File

@@ -9,10 +9,8 @@ from json_repair import repair_json
from sqlalchemy import select from sqlalchemy import select
from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.api.crud import CRUDBase
from src.common.database.compatibility import get_db_session from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression from src.common.database.core.models import Expression
from src.common.database.utils.decorators import cached
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest

View File

@@ -728,7 +728,6 @@ class MemorySystem:
context = context or {} context = context or {}
# 所有记忆完全共享,统一使用 global 作用域,不区分用户 # 所有记忆完全共享,统一使用 global 作用域,不区分用户
resolved_user_id = GLOBAL_MEMORY_SCOPE
self.status = MemorySystemStatus.RETRIEVING self.status = MemorySystemStatus.RETRIEVING
start_time = time.time() start_time = time.time()

View File

@@ -4,15 +4,14 @@ import time
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from rich.traceback import install from rich.traceback import install
from sqlalchemy import select
from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.api.crud import CRUDBase
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.compatibility import get_db_session from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # 新增导入 from src.common.database.core.models import ChatStreams # 新增导入
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config # 新增导入 from src.config.config import global_config # 新增导入

View File

@@ -105,8 +105,8 @@ class MessageStorageBatcher:
for msg_data in messages_to_store: for msg_data in messages_to_store:
try: try:
message_dict = await self._prepare_message_dict( message_dict = await self._prepare_message_dict(
msg_data['message'], msg_data["message"],
msg_data['chat_stream'] msg_data["chat_stream"]
) )
if message_dict: if message_dict:
messages_dicts.append(message_dict) messages_dicts.append(message_dict)
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
is_picid = message.is_picid is_picid = message.is_picid
is_notify = message.is_notify is_notify = message.is_notify
is_command = message.is_command is_command = message.is_command
is_public_notice = getattr(message, 'is_public_notice', False) is_public_notice = getattr(message, "is_public_notice", False)
notice_type = getattr(message, 'notice_type', None) notice_type = getattr(message, "notice_type", None)
actions = getattr(message, 'actions', None) actions = getattr(message, "actions", None)
should_reply = getattr(message, 'should_reply', None) should_reply = getattr(message, "should_reply", None)
should_act = getattr(message, 'should_act', None) should_act = getattr(message, "should_act", None)
additional_config = getattr(message, 'additional_config', None) additional_config = getattr(message, "additional_config", None)
key_words = MessageStorage._serialize_keywords(message.key_words) key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
# 全局批处理器实例 # 全局批处理器实例
_message_storage_batcher: Optional[MessageStorageBatcher] = None _message_storage_batcher: MessageStorageBatcher | None = None
_message_update_batcher: Optional["MessageUpdateBatcher"] = None _message_update_batcher: Optional["MessageUpdateBatcher"] = None
@@ -488,8 +488,8 @@ class MessageStorage:
if use_batch: if use_batch:
batcher = get_message_storage_batcher() batcher = get_message_storage_batcher()
await batcher.add_message({ await batcher.add_message({
'message': message, "message": message,
'chat_stream': chat_stream "chat_stream": chat_stream
}) })
return return

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any from typing import Any
from src.common.database.compatibility import db_get, db_query, db_save from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import LLMUsage, Messages, OnlineTime from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask

View File

@@ -12,8 +12,8 @@ from PIL import Image
from rich.traceback import install from rich.traceback import install
from sqlalchemy import and_, select from sqlalchemy import and_, select
from src.common.database.core.models import ImageDescriptions, Images
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import ImageDescriptions, Images
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest

View File

@@ -25,8 +25,8 @@ from typing import Any
from PIL import Image from PIL import Image
from src.common.database.core.models import Videos
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import Videos
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest

View File

@@ -9,6 +9,37 @@
""" """
# ===== 核心层 ===== # ===== 核心层 =====
# ===== API层 =====
from src.common.database.api import (
AggregateQuery,
CRUDBase,
QueryBuilder,
# ChatStreams API
get_active_streams,
# Messages API
get_chat_history,
get_message_count,
# PersonInfo API
get_or_create_person,
# ActionRecords API
get_recent_actions,
# LLMUsage API
get_usage_statistics,
record_llm_usage,
# 业务API
save_message,
store_action_info,
update_person_affinity,
)
# ===== 兼容层向后兼容旧API=====
from src.common.database.compatibility import (
MODEL_MAPPING,
build_filters,
db_get,
db_query,
db_save,
)
from src.common.database.core import ( from src.common.database.core import (
Base, Base,
check_and_migrate_database, check_and_migrate_database,
@@ -27,29 +58,6 @@ from src.common.database.optimization import (
get_preloader, get_preloader,
) )
# ===== API层 =====
from src.common.database.api import (
AggregateQuery,
CRUDBase,
QueryBuilder,
# ActionRecords API
get_recent_actions,
# ChatStreams API
get_active_streams,
# Messages API
get_chat_history,
get_message_count,
# PersonInfo API
get_or_create_person,
# LLMUsage API
get_usage_statistics,
record_llm_usage,
# 业务API
save_message,
store_action_info,
update_person_affinity,
)
# ===== Utils层 ===== # ===== Utils层 =====
from src.common.database.utils import ( from src.common.database.utils import (
cached, cached,
@@ -66,61 +74,52 @@ from src.common.database.utils import (
transactional, transactional,
) )
# ===== 兼容层向后兼容旧API=====
from src.common.database.compatibility import (
MODEL_MAPPING,
build_filters,
db_get,
db_query,
db_save,
)
__all__ = [ __all__ = [
# 核心层
"Base",
"get_engine",
"get_session_factory",
"get_db_session",
"check_and_migrate_database",
# 优化层
"MultiLevelCache",
"DataPreloader",
"AdaptiveBatchScheduler",
"get_cache",
"get_preloader",
"get_batch_scheduler",
# API层 - 基础类
"CRUDBase",
"QueryBuilder",
"AggregateQuery",
# API层 - 业务API
"store_action_info",
"get_recent_actions",
"get_chat_history",
"get_message_count",
"save_message",
"get_or_create_person",
"update_person_affinity",
"get_active_streams",
"record_llm_usage",
"get_usage_statistics",
# Utils层
"retry",
"timeout",
"cached",
"measure_time",
"transactional",
"db_operation",
"get_monitor",
"record_operation",
"record_cache_hit",
"record_cache_miss",
"print_stats",
"reset_stats",
# 兼容层 # 兼容层
"MODEL_MAPPING", "MODEL_MAPPING",
"AdaptiveBatchScheduler",
"AggregateQuery",
# 核心层
"Base",
# API层 - 基础类
"CRUDBase",
"DataPreloader",
# 优化层
"MultiLevelCache",
"QueryBuilder",
"build_filters", "build_filters",
"cached",
"check_and_migrate_database",
"db_get",
"db_operation",
"db_query", "db_query",
"db_save", "db_save",
"db_get", "get_active_streams",
"get_batch_scheduler",
"get_cache",
"get_chat_history",
"get_db_session",
"get_engine",
"get_message_count",
"get_monitor",
"get_or_create_person",
"get_preloader",
"get_recent_actions",
"get_session_factory",
"get_usage_statistics",
"measure_time",
"print_stats",
"record_cache_hit",
"record_cache_miss",
"record_llm_usage",
"record_operation",
"reset_stats",
# Utils层
"retry",
"save_message",
# API层 - 业务API
"store_action_info",
"timeout",
"transactional",
"update_person_affinity",
] ]

View File

@@ -11,49 +11,49 @@ from src.common.database.api.query import AggregateQuery, QueryBuilder
# 业务特定API # 业务特定API
from src.common.database.api.specialized import ( from src.common.database.api.specialized import (
# ActionRecords
get_recent_actions,
store_action_info,
# ChatStreams # ChatStreams
get_active_streams, get_active_streams,
get_or_create_chat_stream,
# LLMUsage
get_usage_statistics,
record_llm_usage,
# Messages # Messages
get_chat_history, get_chat_history,
get_message_count, get_message_count,
save_message, get_or_create_chat_stream,
# PersonInfo # PersonInfo
get_or_create_person, get_or_create_person,
update_person_affinity, # ActionRecords
get_recent_actions,
# LLMUsage
get_usage_statistics,
# UserRelationships # UserRelationships
get_user_relationship, get_user_relationship,
record_llm_usage,
save_message,
store_action_info,
update_person_affinity,
update_relationship_affinity, update_relationship_affinity,
) )
__all__ = [ __all__ = [
"AggregateQuery",
# 基础类 # 基础类
"CRUDBase", "CRUDBase",
"QueryBuilder", "QueryBuilder",
"AggregateQuery", "get_active_streams",
# ActionRecords API
"store_action_info",
"get_recent_actions",
# Messages API # Messages API
"get_chat_history", "get_chat_history",
"get_message_count", "get_message_count",
"save_message",
# PersonInfo API
"get_or_create_person",
"update_person_affinity",
# ChatStreams API # ChatStreams API
"get_or_create_chat_stream", "get_or_create_chat_stream",
"get_active_streams", # PersonInfo API
# LLMUsage API "get_or_create_person",
"record_llm_usage", "get_recent_actions",
"get_usage_statistics", "get_usage_statistics",
# UserRelationships API # UserRelationships API
"get_user_relationship", "get_user_relationship",
# LLMUsage API
"record_llm_usage",
"save_message",
# ActionRecords API
"store_action_info",
"update_person_affinity",
"update_relationship_affinity", "update_relationship_affinity",
] ]

View File

@@ -206,7 +206,7 @@ class CRUDBase:
# 应用过滤条件 # 应用过滤条件
for key, value in filters.items(): for key, value in filters.items():
if hasattr(self.model, key): if hasattr(self.model, key):
if isinstance(value, (list, tuple, set)): if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value)) stmt = stmt.where(getattr(self.model, key).in_(value))
else: else:
stmt = stmt.where(getattr(self.model, key) == value) stmt = stmt.where(getattr(self.model, key) == value)
@@ -397,7 +397,7 @@ class CRUDBase:
# 应用过滤条件 # 应用过滤条件
for key, value in filters.items(): for key, value in filters.items():
if hasattr(self.model, key): if hasattr(self.model, key):
if isinstance(value, (list, tuple, set)): if isinstance(value, list | tuple | set):
stmt = stmt.where(getattr(self.model, key).in_(value)) stmt = stmt.where(getattr(self.model, key).in_(value))
else: else:
stmt = stmt.where(getattr(self.model, key) == value) stmt = stmt.where(getattr(self.model, key) == value)

View File

@@ -4,7 +4,7 @@
""" """
import time import time
from typing import Any, Optional from typing import Any
import orjson import orjson
@@ -42,9 +42,9 @@ async def store_action_info(
action_prompt_display: str = "", action_prompt_display: str = "",
action_done: bool = True, action_done: bool = True,
thinking_id: str = "", thinking_id: str = "",
action_data: Optional[dict] = None, action_data: dict | None = None,
action_name: str = "", action_name: str = "",
) -> Optional[dict[str, Any]]: ) -> dict[str, Any] | None:
"""存储动作信息到数据库 """存储动作信息到数据库
Args: Args:
@@ -167,7 +167,7 @@ async def get_message_count(stream_id: str) -> int:
async def save_message( async def save_message(
message_data: dict[str, Any], message_data: dict[str, Any],
use_batch: bool = True, use_batch: bool = True,
) -> Optional[Messages]: ) -> Messages | None:
"""保存消息 """保存消息
Args: Args:
@@ -185,8 +185,8 @@ async def save_message(
async def get_or_create_person( async def get_or_create_person(
platform: str, platform: str,
person_id: str, person_id: str,
defaults: Optional[dict[str, Any]] = None, defaults: dict[str, Any] | None = None,
) -> tuple[Optional[PersonInfo], bool]: ) -> tuple[PersonInfo | None, bool]:
"""获取或创建人员信息 """获取或创建人员信息
Args: Args:
@@ -255,8 +255,8 @@ async def update_person_affinity(
async def get_or_create_chat_stream( async def get_or_create_chat_stream(
stream_id: str, stream_id: str,
platform: str, platform: str,
defaults: Optional[dict[str, Any]] = None, defaults: dict[str, Any] | None = None,
) -> tuple[Optional[ChatStreams], bool]: ) -> tuple[ChatStreams | None, bool]:
"""获取或创建聊天流 """获取或创建聊天流
Args: Args:
@@ -275,7 +275,7 @@ async def get_or_create_chat_stream(
async def get_active_streams( async def get_active_streams(
platform: Optional[str] = None, platform: str | None = None,
limit: int = 100, limit: int = 100,
) -> list[ChatStreams]: ) -> list[ChatStreams]:
"""获取活跃的聊天流 """获取活跃的聊天流
@@ -300,18 +300,18 @@ async def record_llm_usage(
model_name: str, model_name: str,
input_tokens: int, input_tokens: int,
output_tokens: int, output_tokens: int,
stream_id: Optional[str] = None, stream_id: str | None = None,
platform: Optional[str] = None, platform: str | None = None,
user_id: str = "system", user_id: str = "system",
request_type: str = "chat", request_type: str = "chat",
model_assign_name: Optional[str] = None, model_assign_name: str | None = None,
model_api_provider: Optional[str] = None, model_api_provider: str | None = None,
endpoint: str = "/v1/chat/completions", endpoint: str = "/v1/chat/completions",
cost: float = 0.0, cost: float = 0.0,
status: str = "success", status: str = "success",
time_cost: Optional[float] = None, time_cost: float | None = None,
use_batch: bool = True, use_batch: bool = True,
) -> Optional[LLMUsage]: ) -> LLMUsage | None:
"""记录LLM使用情况 """记录LLM使用情况
Args: Args:
@@ -354,9 +354,9 @@ async def record_llm_usage(
async def get_usage_statistics( async def get_usage_statistics(
start_time: Optional[float] = None, start_time: float | None = None,
end_time: Optional[float] = None, end_time: float | None = None,
model_name: Optional[str] = None, model_name: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""获取使用统计 """获取使用统计
@@ -374,8 +374,7 @@ async def get_usage_statistics(
# 添加时间过滤 # 添加时间过滤
if start_time: if start_time:
async with get_db_session() as session: async with get_db_session():
from sqlalchemy import and_
conditions = [] conditions = []
if start_time: if start_time:
@@ -407,7 +406,7 @@ async def get_user_relationship(
platform: str, platform: str,
user_id: str, user_id: str,
target_id: str, target_id: str,
) -> Optional[UserRelationships]: ) -> UserRelationships | None:
"""获取用户关系 """获取用户关系
Args: Args:

View File

@@ -14,14 +14,14 @@ from .adapter import (
) )
__all__ = [ __all__ = [
# 从 core 重新导出的函数
"get_db_session",
"get_engine",
# 兼容层适配器 # 兼容层适配器
"MODEL_MAPPING", "MODEL_MAPPING",
"build_filters", "build_filters",
"db_get",
"db_query", "db_query",
"db_save", "db_save",
"db_get", # 从 core 重新导出的函数
"get_db_session",
"get_engine",
"store_action_info", "store_action_info",
] ]

View File

@@ -4,15 +4,13 @@
保持原有函数签名和行为不变 保持原有函数签名和行为不变
""" """
import time from typing import Any
from typing import Any, Optional
import orjson
from sqlalchemy import and_, asc, desc, select
from src.common.database.api import ( from src.common.database.api import (
CRUDBase, CRUDBase,
QueryBuilder, QueryBuilder,
)
from src.common.database.api import (
store_action_info as new_store_action_info, store_action_info as new_store_action_info,
) )
from src.common.database.core.models import ( from src.common.database.core.models import (
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
Messages, Messages,
MonthlyPlan, MonthlyPlan,
OnlineTime, OnlineTime,
PersonInfo,
PermissionNodes, PermissionNodes,
PersonInfo,
Schedule, Schedule,
ThinkingLog, ThinkingLog,
UserPermissions, UserPermissions,
UserRelationships, UserRelationships,
Videos, Videos,
) )
from src.common.database.core.session import get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("database.compatibility") logger = get_logger("database.compatibility")
@@ -145,12 +142,12 @@ def _model_to_dict(instance) -> dict[str, Any]:
async def db_query( async def db_query(
model_class, model_class,
data: Optional[dict[str, Any]] = None, data: dict[str, Any] | None = None,
query_type: Optional[str] = "get", query_type: str | None = "get",
filters: Optional[dict[str, Any]] = None, filters: dict[str, Any] | None = None,
limit: Optional[int] = None, limit: int | None = None,
order_by: Optional[list[str]] = None, order_by: list[str] | None = None,
single_result: Optional[bool] = False, single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None: ) -> list[dict[str, Any]] | dict[str, Any] | None:
"""执行异步数据库查询操作兼容旧API """执行异步数据库查询操作兼容旧API
@@ -286,7 +283,7 @@ async def db_save(
data: dict[str, Any], data: dict[str, Any],
key_field: str, key_field: str,
key_value: Any, key_value: Any,
) -> Optional[dict[str, Any]]: ) -> dict[str, Any] | None:
"""保存或更新记录兼容旧API """保存或更新记录兼容旧API
Args: Args:
@@ -319,10 +316,10 @@ async def db_save(
async def db_get( async def db_get(
model_class, model_class,
filters: Optional[dict[str, Any]] = None, filters: dict[str, Any] | None = None,
limit: Optional[int] = None, limit: int | None = None,
order_by: Optional[str] = None, order_by: str | None = None,
single_result: Optional[bool] = False, single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None: ) -> list[dict[str, Any]] | dict[str, Any] | None:
"""从数据库获取记录兼容旧API """从数据库获取记录兼容旧API
@@ -353,9 +350,9 @@ async def store_action_info(
action_prompt_display: str = "", action_prompt_display: str = "",
action_done: bool = True, action_done: bool = True,
thinking_id: str = "", thinking_id: str = "",
action_data: Optional[dict] = None, action_data: dict | None = None,
action_name: str = "", action_name: str = "",
) -> Optional[dict[str, Any]]: ) -> dict[str, Any] | None:
"""存储动作信息到数据库兼容旧API """存储动作信息到数据库兼容旧API
直接使用新的specialized API 直接使用新的specialized API

View File

@@ -5,7 +5,7 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any
from urllib.parse import quote_plus from urllib.parse import quote_plus
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -25,19 +25,19 @@ class DatabaseConfig:
engine_kwargs: dict[str, Any] engine_kwargs: dict[str, Any]
# SQLite特定配置 # SQLite特定配置
sqlite_path: Optional[str] = None sqlite_path: str | None = None
# MySQL特定配置 # MySQL特定配置
mysql_host: Optional[str] = None mysql_host: str | None = None
mysql_port: Optional[int] = None mysql_port: int | None = None
mysql_user: Optional[str] = None mysql_user: str | None = None
mysql_password: Optional[str] = None mysql_password: str | None = None
mysql_database: Optional[str] = None mysql_database: str | None = None
mysql_charset: str = "utf8mb4" mysql_charset: str = "utf8mb4"
mysql_unix_socket: Optional[str] = None mysql_unix_socket: str | None = None
_database_config: Optional[DatabaseConfig] = None _database_config: DatabaseConfig | None = None
def get_database_config() -> DatabaseConfig: def get_database_config() -> DatabaseConfig:

View File

@@ -19,7 +19,6 @@ from .models import (
ChatStreams, ChatStreams,
Emoji, Emoji,
Expression, Expression,
get_string_field,
GraphEdges, GraphEdges,
GraphNodes, GraphNodes,
ImageDescriptions, ImageDescriptions,
@@ -37,30 +36,17 @@ from .models import (
UserPermissions, UserPermissions,
UserRelationships, UserRelationships,
Videos, Videos,
get_string_field,
) )
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
__all__ = [ __all__ = [
# Engine
"get_engine",
"close_engine",
"get_engine_info",
# Session
"get_db_session",
"get_db_session_direct",
"get_session_factory",
"reset_session_factory",
# Migration
"check_and_migrate_database",
"create_all_tables",
"drop_all_tables",
# Models - Base
"Base",
"get_string_field",
# Models - Tables (按字母顺序) # Models - Tables (按字母顺序)
"ActionRecords", "ActionRecords",
"AntiInjectionStats", "AntiInjectionStats",
"BanUser", "BanUser",
# Models - Base
"Base",
"BotPersonalityInterests", "BotPersonalityInterests",
"CacheEntries", "CacheEntries",
"ChatStreams", "ChatStreams",
@@ -83,4 +69,18 @@ __all__ = [
"UserPermissions", "UserPermissions",
"UserRelationships", "UserRelationships",
"Videos", "Videos",
# Migration
"check_and_migrate_database",
"close_engine",
"create_all_tables",
"drop_all_tables",
# Session
"get_db_session",
"get_db_session_direct",
# Engine
"get_engine",
"get_engine_info",
"get_session_factory",
"get_string_field",
"reset_session_factory",
] ]

View File

@@ -5,7 +5,6 @@
import asyncio import asyncio
import os import os
from typing import Optional
from urllib.parse import quote_plus from urllib.parse import quote_plus
from sqlalchemy import text from sqlalchemy import text
@@ -18,8 +17,8 @@ from ..utils.exceptions import DatabaseInitializationError
logger = get_logger("database.engine") logger = get_logger("database.engine")
# 全局引擎实例 # 全局引擎实例
_engine: Optional[AsyncEngine] = None _engine: AsyncEngine | None = None
_engine_lock: Optional[asyncio.Lock] = None _engine_lock: asyncio.Lock | None = None
async def get_engine() -> AsyncEngine: async def get_engine() -> AsyncEngine:

View File

@@ -6,7 +6,6 @@
import asyncio import asyncio
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -18,8 +17,8 @@ from .engine import get_engine
logger = get_logger("database.session") logger = get_logger("database.session")
# 全局会话工厂 # 全局会话工厂
_session_factory: Optional[async_sessionmaker] = None _session_factory: async_sessionmaker | None = None
_factory_lock: Optional[asyncio.Lock] = None _factory_lock: asyncio.Lock | None = None
async def get_session_factory() -> async_sessionmaker: async def get_session_factory() -> async_sessionmaker:

View File

@@ -11,17 +11,17 @@ from .batch_scheduler import (
AdaptiveBatchScheduler, AdaptiveBatchScheduler,
BatchOperation, BatchOperation,
BatchStats, BatchStats,
Priority,
close_batch_scheduler, close_batch_scheduler,
get_batch_scheduler, get_batch_scheduler,
Priority,
) )
from .cache_manager import ( from .cache_manager import (
CacheEntry, CacheEntry,
CacheStats, CacheStats,
close_cache,
get_cache,
LRUCache, LRUCache,
MultiLevelCache, MultiLevelCache,
close_cache,
get_cache,
) )
from .connection_pool import ( from .connection_pool import (
ConnectionPoolManager, ConnectionPoolManager,
@@ -31,36 +31,36 @@ from .connection_pool import (
) )
from .preloader import ( from .preloader import (
AccessPattern, AccessPattern,
close_preloader,
CommonDataPreloader, CommonDataPreloader,
DataPreloader, DataPreloader,
close_preloader,
get_preloader, get_preloader,
) )
__all__ = [ __all__ = [
# Connection Pool
"ConnectionPoolManager",
"get_connection_pool_manager",
"start_connection_pool",
"stop_connection_pool",
# Cache
"MultiLevelCache",
"LRUCache",
"CacheEntry",
"CacheStats",
"get_cache",
"close_cache",
# Preloader
"DataPreloader",
"CommonDataPreloader",
"AccessPattern", "AccessPattern",
"get_preloader",
"close_preloader",
# Batch Scheduler # Batch Scheduler
"AdaptiveBatchScheduler", "AdaptiveBatchScheduler",
"BatchOperation", "BatchOperation",
"BatchStats", "BatchStats",
"CacheEntry",
"CacheStats",
"CommonDataPreloader",
# Connection Pool
"ConnectionPoolManager",
# Preloader
"DataPreloader",
"LRUCache",
# Cache
"MultiLevelCache",
"Priority", "Priority",
"get_batch_scheduler",
"close_batch_scheduler", "close_batch_scheduler",
"close_cache",
"close_preloader",
"get_batch_scheduler",
"get_cache",
"get_connection_pool_manager",
"get_preloader",
"start_connection_pool",
"stop_connection_pool",
] ]

View File

@@ -10,12 +10,12 @@
import asyncio import asyncio
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import IntEnum from enum import IntEnum
from typing import Any, Callable, Optional, TypeVar from typing import Any, TypeVar
from sqlalchemy import delete, insert, select, update from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.core.session import get_db_session from src.common.database.core.session import get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -40,12 +40,12 @@ class BatchOperation:
operation_type: str # 'select', 'insert', 'update', 'delete' operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: type model_class: type
conditions: dict[str, Any] = field(default_factory=dict) conditions: dict[str, Any] = field(default_factory=dict)
data: Optional[dict[str, Any]] = None data: dict[str, Any] | None = None
callback: Optional[Callable] = None callback: Callable | None = None
future: Optional[asyncio.Future] = None future: asyncio.Future | None = None
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
priority: Priority = Priority.NORMAL priority: Priority = Priority.NORMAL
timeout: Optional[float] = None # 超时时间(秒) timeout: float | None = None # 超时时间(秒)
@dataclass @dataclass
@@ -111,7 +111,7 @@ class AdaptiveBatchScheduler:
} }
# 调度控制 # 调度控制
self._scheduler_task: Optional[asyncio.Task] = None self._scheduler_task: asyncio.Task | None = None
self._is_running = False self._is_running = False
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@@ -257,7 +257,7 @@ class AdaptiveBatchScheduler:
op_groups[key].append(op) op_groups[key].append(op)
# 执行各组操作 # 执行各组操作
for group_key, ops in op_groups.items(): for ops in op_groups.values():
await self._execute_group(ops) await self._execute_group(ops)
# 更新统计 # 更新统计
@@ -323,7 +323,7 @@ class AdaptiveBatchScheduler:
stmt = select(op.model_class) stmt = select(op.model_class)
for key, value in op.conditions.items(): for key, value in op.conditions.items():
attr = getattr(op.model_class, key) attr = getattr(op.model_class, key)
if isinstance(value, (list, tuple, set)): if isinstance(value, list | tuple | set):
stmt = stmt.where(attr.in_(value)) stmt = stmt.where(attr.in_(value))
else: else:
stmt = stmt.where(attr == value) stmt = stmt.where(attr == value)
@@ -366,7 +366,7 @@ class AdaptiveBatchScheduler:
# 批量插入 # 批量插入
stmt = insert(operations[0].model_class).values(all_data) stmt = insert(operations[0].model_class).values(all_data)
result = await session.execute(stmt) await session.execute(stmt)
await session.commit() await session.commit()
# 设置结果 # 设置结果
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
] ]
return "|".join(key_parts) return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]: def _get_from_cache(self, cache_key: str) -> Any | None:
"""从缓存获取结果""" """从缓存获取结果"""
if cache_key in self._result_cache: if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key] result, timestamp = self._result_cache[cache_key]
@@ -551,7 +551,7 @@ class AdaptiveBatchScheduler:
# 全局调度器实例 # 全局调度器实例
_global_scheduler: Optional[AdaptiveBatchScheduler] = None _global_scheduler: AdaptiveBatchScheduler | None = None
_scheduler_lock = asyncio.Lock() _scheduler_lock = asyncio.Lock()

View File

@@ -11,8 +11,9 @@
import asyncio import asyncio
import time import time
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar from typing import Any, Generic, TypeVar
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -94,7 +95,7 @@ class LRUCache(Generic[T]):
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._stats = CacheStats() self._stats = CacheStats()
async def get(self, key: str) -> Optional[T]: async def get(self, key: str) -> T | None:
"""获取缓存值 """获取缓存值
Args: Args:
@@ -135,7 +136,7 @@ class LRUCache(Generic[T]):
self, self,
key: str, key: str,
value: T, value: T,
size: Optional[int] = None, size: int | None = None,
) -> None: ) -> None:
"""设置缓存值 """设置缓存值
@@ -255,7 +256,7 @@ class MultiLevelCache:
""" """
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1") self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2") self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
self._cleanup_task: Optional[asyncio.Task] = None self._cleanup_task: asyncio.Task | None = None
logger.info( logger.info(
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) " f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
@@ -265,8 +266,8 @@ class MultiLevelCache:
async def get( async def get(
self, self,
key: str, key: str,
loader: Optional[Callable[[], Any]] = None, loader: Callable[[], Any] | None = None,
) -> Optional[Any]: ) -> Any | None:
"""从缓存获取数据 """从缓存获取数据
查询顺序L1 -> L2 -> loader 查询顺序L1 -> L2 -> loader
@@ -307,7 +308,7 @@ class MultiLevelCache:
self, self,
key: str, key: str,
value: Any, value: Any,
size: Optional[int] = None, size: int | None = None,
) -> None: ) -> None:
"""设置缓存值 """设置缓存值
@@ -387,7 +388,7 @@ class MultiLevelCache:
# 全局缓存实例 # 全局缓存实例
_global_cache: Optional[MultiLevelCache] = None _global_cache: MultiLevelCache | None = None
_cache_lock = asyncio.Lock() _cache_lock = asyncio.Lock()

View File

@@ -10,8 +10,9 @@
import asyncio import asyncio
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional from typing import Any
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -84,7 +85,7 @@ class DataPreloader:
async def record_access( async def record_access(
self, self,
key: str, key: str,
related_keys: Optional[list[str]] = None, related_keys: list[str] | None = None,
) -> None: ) -> None:
"""记录数据访问 """记录数据访问
@@ -379,7 +380,7 @@ class CommonDataPreloader:
stream_id: 聊天流ID stream_id: 聊天流ID
limit: 消息数量限制 limit: 消息数量限制
""" """
from src.common.database.core.models import ChatStreams, Messages from src.common.database.core.models import ChatStreams
# 预加载聊天流信息 # 预加载聊天流信息
await self._preload_model( await self._preload_model(
@@ -418,7 +419,7 @@ class CommonDataPreloader:
# 全局预加载器实例 # 全局预加载器实例
_global_preloader: Optional[DataPreloader] = None _global_preloader: DataPreloader | None = None
_preloader_lock = asyncio.Lock() _preloader_lock = asyncio.Lock()

View File

@@ -37,29 +37,29 @@ from .monitoring import (
) )
__all__ = [ __all__ = [
"BatchSchedulerError",
"CacheError",
"ConnectionPoolError",
"DatabaseConnectionError",
# 异常 # 异常
"DatabaseError", "DatabaseError",
"DatabaseInitializationError", "DatabaseInitializationError",
"DatabaseConnectionError", "DatabaseMigrationError",
# 监控
"DatabaseMonitor",
"DatabaseQueryError", "DatabaseQueryError",
"DatabaseTransactionError", "DatabaseTransactionError",
"DatabaseMigrationError", "cached",
"CacheError", "db_operation",
"BatchSchedulerError", "get_monitor",
"ConnectionPoolError", "measure_time",
"print_stats",
"record_cache_hit",
"record_cache_miss",
"record_operation",
"reset_stats",
# 装饰器 # 装饰器
"retry", "retry",
"timeout", "timeout",
"cached",
"measure_time",
"transactional", "transactional",
"db_operation",
# 监控
"DatabaseMonitor",
"get_monitor",
"record_operation",
"record_cache_hit",
"record_cache_miss",
"print_stats",
"reset_stats",
] ]

View File

@@ -10,9 +10,11 @@ import asyncio
import functools import functools
import hashlib import hashlib
import time import time
from typing import Any, Awaitable, Callable, Optional, TypeVar from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -141,8 +143,8 @@ def timeout(seconds: float):
def cached( def cached(
ttl: Optional[int] = 300, ttl: int | None = 300,
key_prefix: Optional[str] = None, key_prefix: str | None = None,
use_args: bool = True, use_args: bool = True,
use_kwargs: bool = True, use_kwargs: bool = True,
): ):
@@ -207,7 +209,7 @@ def cached(
return decorator return decorator
def measure_time(log_slow: Optional[float] = None): def measure_time(log_slow: float | None = None):
"""性能测量装饰器 """性能测量装饰器
测量函数执行时间,可选择性记录慢查询 测量函数执行时间,可选择性记录慢查询
@@ -306,8 +308,8 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
# 组合装饰器示例 # 组合装饰器示例
def db_operation( def db_operation(
retry_attempts: int = 3, retry_attempts: int = 3,
timeout_seconds: Optional[float] = None, timeout_seconds: float | None = None,
cache_ttl: Optional[int] = None, cache_ttl: int | None = None,
measure: bool = True, measure: bool = True,
): ):
"""组合装饰器 """组合装饰器

View File

@@ -4,7 +4,6 @@
""" """
import time import time
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional from typing import Any, Optional
@@ -22,7 +21,7 @@ class OperationMetrics:
min_time: float = float("inf") min_time: float = float("inf")
max_time: float = 0.0 max_time: float = 0.0
error_count: int = 0 error_count: int = 0
last_execution_time: Optional[float] = None last_execution_time: float | None = None
@property @property
def avg_time(self) -> float: def avg_time(self) -> float:
@@ -285,7 +284,7 @@ class DatabaseMonitor:
# 全局监控器实例 # 全局监控器实例
_monitor: Optional[DatabaseMonitor] = None _monitor: DatabaseMonitor | None = None
def get_monitor() -> DatabaseMonitor: def get_monitor() -> DatabaseMonitor:

View File

@@ -4,8 +4,8 @@ from datetime import datetime
from PIL import Image from PIL import Image
from src.common.database.core.models import LLMUsage
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import LLMUsage
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo from src.config.api_ada_configs import ModelInfo

View File

@@ -256,9 +256,10 @@ class RelationshipFetcher:
str: 格式化后的聊天流印象字符串 str: 格式化后的聊天流印象字符串
""" """
try: try:
from src.common.database.api.specialized import get_or_create_chat_stream
import time import time
from src.common.database.api.specialized import get_or_create_chat_stream
# 使用优化后的API带缓存 # 使用优化后的API带缓存
# 从stream_id解析platform或使用默认值 # 从stream_id解析platform或使用默认值
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown" platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"

View File

@@ -52,8 +52,8 @@ from typing import Any
import orjson import orjson
from sqlalchemy import func, select from sqlalchemy import func, select
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.logger import get_logger from src.common.logger import get_logger
from src.schedule.database import get_active_plans_for_month from src.schedule.database import get_active_plans_for_month

View File

@@ -7,7 +7,7 @@ from src.plugin_system.base.component_types import ChatType, CommandInfo, Compon
from src.plugin_system.base.plus_command import PlusCommand from src.plugin_system.base.plus_command import PlusCommand
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream pass
logger = get_logger("base_command") logger = get_logger("base_command")

View File

@@ -10,8 +10,8 @@ from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_sessionmaker
from src.common.database.core.models import PermissionNodes, UserPermissions
from src.common.database.core import get_engine from src.common.database.core import get_engine
from src.common.database.core.models import PermissionNodes, UserPermissions
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo

View File

@@ -5,8 +5,8 @@
import time import time
from src.common.database.core.models import UserRelationships
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import UserRelationships
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config

View File

@@ -7,12 +7,8 @@
import json import json
from typing import Any, ClassVar from typing import Any, ClassVar
from sqlalchemy import select
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.database.api.crud import CRUDBase from src.common.database.api.crud import CRUDBase
from src.common.database.utils.decorators import cached from src.common.database.core.models import ChatStreams
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import model_config from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest

View File

@@ -7,13 +7,10 @@ import json
from datetime import datetime from datetime import datetime
from typing import Any, Literal from typing import Any, Literal
from sqlalchemy import select
from src.chat.express.expression_selector import expression_selector from src.chat.express.expression_selector import expression_selector
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.database.api.crud import CRUDBase from src.common.database.api.crud import CRUDBase
from src.common.database.core.models import ChatStreams
from src.common.database.utils.decorators import cached from src.common.database.utils.decorators import cached
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config

View File

@@ -3,8 +3,8 @@
from sqlalchemy import delete, func, select, update from sqlalchemy import delete, func, select, update
from src.common.database.core.models import MonthlyPlan
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -312,7 +312,7 @@ async def delete_plans_older_than(month: str):
logger.info(f"没有找到比 {month} 更早的月度计划需要删除。") logger.info(f"没有找到比 {month} 更早的月度计划需要删除。")
return 0 return 0
plan_months = sorted(list(set(p.target_month for p in plans_to_delete))) plan_months = sorted({p.target_month for p in plans_to_delete})
logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。") logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。")
# 然后,执行删除操作 # 然后,执行删除操作

View File

@@ -5,8 +5,8 @@ from typing import Any
import orjson import orjson
from sqlalchemy import select from sqlalchemy import select
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.database.core import get_db_session from src.common.database.core import get_db_session
from src.common.database.core.models import MonthlyPlan, Schedule
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager