rufffffff
This commit is contained in:
@@ -16,7 +16,7 @@ models_file = os.path.join(
|
|||||||
print(f"正在清理文件: {models_file}")
|
print(f"正在清理文件: {models_file}")
|
||||||
|
|
||||||
# 读取文件
|
# 读取文件
|
||||||
with open(models_file, "r", encoding="utf-8") as f:
|
with open(models_file, encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
# 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束)
|
# 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束)
|
||||||
@@ -43,7 +43,7 @@ if not found_end:
|
|||||||
with open(models_file, "w", encoding="utf-8") as f:
|
with open(models_file, "w", encoding="utf-8") as f:
|
||||||
f.writelines(keep_lines)
|
f.writelines(keep_lines)
|
||||||
|
|
||||||
print(f"✅ 文件清理完成")
|
print("✅ 文件清理完成")
|
||||||
print(f"保留行数: {len(keep_lines)}")
|
print(f"保留行数: {len(keep_lines)}")
|
||||||
print(f"原始行数: {len(lines)}")
|
print(f"原始行数: {len(lines)}")
|
||||||
print(f"删除行数: {len(lines) - len(keep_lines)}")
|
print(f"删除行数: {len(lines) - len(keep_lines)}")
|
||||||
|
|||||||
@@ -4,20 +4,20 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
# 读取原始文件
|
# 读取原始文件
|
||||||
with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f:
|
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# 找到get_string_field函数的开始和结束
|
# 找到get_string_field函数的开始和结束
|
||||||
get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数')
|
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
|
||||||
get_string_field_end = content.find('\n\nclass ChatStreams(Base):')
|
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
|
||||||
get_string_field = content[get_string_field_start:get_string_field_end]
|
get_string_field = content[get_string_field_start:get_string_field_end]
|
||||||
|
|
||||||
# 找到第一个class定义开始
|
# 找到第一个class定义开始
|
||||||
first_class_pos = content.find('class ChatStreams(Base):')
|
first_class_pos = content.find("class ChatStreams(Base):")
|
||||||
|
|
||||||
# 找到所有class定义,直到遇到非class的def
|
# 找到所有class定义,直到遇到非class的def
|
||||||
# 简单策略:找到所有以"class "开头且继承Base的类
|
# 简单策略:找到所有以"class "开头且继承Base的类
|
||||||
classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)'
|
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
|
||||||
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
|
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
|
||||||
|
|
||||||
if matches:
|
if matches:
|
||||||
@@ -53,14 +53,14 @@ Base = declarative_base()
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
new_content = header + get_string_field + '\n\n' + models_content
|
new_content = header + get_string_field + "\n\n" + models_content
|
||||||
|
|
||||||
# 写入新文件
|
# 写入新文件
|
||||||
with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f:
|
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
|
||||||
f.write(new_content)
|
f.write(new_content)
|
||||||
|
|
||||||
print('✅ Models file rewritten successfully')
|
print("✅ Models file rewritten successfully")
|
||||||
print(f'File size: {len(new_content)} characters')
|
print(f"File size: {len(new_content)} characters")
|
||||||
pattern = r"^class \w+\(Base\):"
|
pattern = r"^class \w+\(Base\):"
|
||||||
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
|
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
|
||||||
print(f'Number of model classes: {model_count}')
|
print(f"Number of model classes: {model_count}")
|
||||||
|
|||||||
@@ -8,54 +8,53 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
# 定义导入映射规则
|
# 定义导入映射规则
|
||||||
IMPORT_MAPPINGS = {
|
IMPORT_MAPPINGS = {
|
||||||
# 模型导入
|
# 模型导入
|
||||||
r'from src\.common\.database\.sqlalchemy_models import (.+)':
|
r"from src\.common\.database\.sqlalchemy_models import (.+)":
|
||||||
r'from src.common.database.core.models import \1',
|
r"from src.common.database.core.models import \1",
|
||||||
|
|
||||||
# API导入 - 需要特殊处理
|
# API导入 - 需要特殊处理
|
||||||
r'from src\.common\.database\.sqlalchemy_database_api import (.+)':
|
r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
|
||||||
r'from src.common.database.compatibility import \1',
|
r"from src.common.database.compatibility import \1",
|
||||||
|
|
||||||
# get_db_session 从 sqlalchemy_database_api 导入
|
# get_db_session 从 sqlalchemy_database_api 导入
|
||||||
r'from src\.common\.database\.sqlalchemy_database_api import get_db_session':
|
r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
|
||||||
r'from src.common.database.core import get_db_session',
|
r"from src.common.database.core import get_db_session",
|
||||||
|
|
||||||
# get_db_session 从 sqlalchemy_models 导入
|
# get_db_session 从 sqlalchemy_models 导入
|
||||||
r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)':
|
r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
|
||||||
lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}'
|
lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
|
||||||
if 'get_db_session' in m.group(0) else m.group(0),
|
if "get_db_session" in m.group(0) else m.group(0),
|
||||||
|
|
||||||
# get_engine 导入
|
# get_engine 导入
|
||||||
r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)':
|
r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
|
||||||
lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}',
|
lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
|
||||||
|
|
||||||
# Base 导入
|
# Base 导入
|
||||||
r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)':
|
r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
|
||||||
lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}',
|
lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
|
||||||
|
|
||||||
# initialize_database 导入
|
# initialize_database 导入
|
||||||
r'from src\.common\.database\.sqlalchemy_models import initialize_database':
|
r"from src\.common\.database\.sqlalchemy_models import initialize_database":
|
||||||
r'from src.common.database.core import check_and_migrate_database as initialize_database',
|
r"from src.common.database.core import check_and_migrate_database as initialize_database",
|
||||||
|
|
||||||
# database.py 导入
|
# database.py 导入
|
||||||
r'from src\.common\.database\.database import stop_database':
|
r"from src\.common\.database\.database import stop_database":
|
||||||
r'from src.common.database.core import close_engine as stop_database',
|
r"from src.common.database.core import close_engine as stop_database",
|
||||||
|
|
||||||
r'from src\.common\.database\.database import initialize_sql_database':
|
r"from src\.common\.database\.database import initialize_sql_database":
|
||||||
r'from src.common.database.core import check_and_migrate_database as initialize_sql_database',
|
r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 需要排除的文件
|
# 需要排除的文件
|
||||||
EXCLUDE_PATTERNS = [
|
EXCLUDE_PATTERNS = [
|
||||||
'**/database_refactoring_plan.md', # 文档文件
|
"**/database_refactoring_plan.md", # 文档文件
|
||||||
'**/old/**', # 旧文件目录
|
"**/old/**", # 旧文件目录
|
||||||
'**/sqlalchemy_*.py', # 旧的数据库文件本身
|
"**/sqlalchemy_*.py", # 旧的数据库文件本身
|
||||||
'**/database.py', # 旧的database文件
|
"**/database.py", # 旧的database文件
|
||||||
'**/db_*.py', # 旧的db文件
|
"**/db_*.py", # 旧的db文件
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -67,7 +66,7 @@ def should_exclude(file_path: Path) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]:
|
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
|
||||||
"""更新单个文件中的导入语句
|
"""更新单个文件中的导入语句
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -78,7 +77,7 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
|
|||||||
(修改次数, 修改详情列表)
|
(修改次数, 修改详情列表)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
content = file_path.read_text(encoding='utf-8')
|
content = file_path.read_text(encoding="utf-8")
|
||||||
original_content = content
|
original_content = content
|
||||||
changes = []
|
changes = []
|
||||||
|
|
||||||
@@ -103,7 +102,7 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
|
|||||||
# 如果有修改且不是dry_run,写回文件
|
# 如果有修改且不是dry_run,写回文件
|
||||||
if content != original_content:
|
if content != original_content:
|
||||||
if not dry_run:
|
if not dry_run:
|
||||||
file_path.write_text(content, encoding='utf-8')
|
file_path.write_text(content, encoding="utf-8")
|
||||||
return len(changes) // 2, changes
|
return len(changes) // 2, changes
|
||||||
|
|
||||||
return 0, []
|
return 0, []
|
||||||
@@ -155,7 +154,7 @@ def main():
|
|||||||
total_changes += count
|
total_changes += count
|
||||||
|
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print(f"\n📊 统计:")
|
print("\n📊 统计:")
|
||||||
print(f" - 需要更新的文件: {len(files_to_update)}")
|
print(f" - 需要更新的文件: {len(files_to_update)}")
|
||||||
print(f" - 总修改次数: {total_changes}")
|
print(f" - 总修改次数: {total_changes}")
|
||||||
|
|
||||||
@@ -163,7 +162,7 @@ def main():
|
|||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
response = input("\n是否执行更新?(yes/no): ").strip().lower()
|
response = input("\n是否执行更新?(yes/no): ").strip().lower()
|
||||||
|
|
||||||
if response != 'yes':
|
if response != "yes":
|
||||||
print("❌ 已取消更新")
|
print("❌ 已取消更新")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -263,8 +263,8 @@ class AntiPromptInjector:
|
|||||||
try:
|
try:
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
|
|
||||||
from src.common.database.core.models import Messages
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import Messages
|
||||||
|
|
||||||
message_id = message_data.get("message_id")
|
message_id = message_data.get("message_id")
|
||||||
if not message_id:
|
if not message_id:
|
||||||
@@ -291,8 +291,8 @@ class AntiPromptInjector:
|
|||||||
try:
|
try:
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from src.common.database.core.models import Messages
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import Messages
|
||||||
|
|
||||||
message_id = message_data.get("message_id")
|
message_id = message_data.get("message_id")
|
||||||
if not message_id:
|
if not message_id:
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
|
|||||||
|
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
|
|
||||||
from src.common.database.core.models import AntiInjectionStats
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import AntiInjectionStats
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import datetime
|
|||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.core.models import BanUser
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import BanUser
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from ..types import DetectionResult
|
from ..types import DetectionResult
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from rich.traceback import install
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import Emoji, Images
|
from src.common.database.core.models import Emoji, Images
|
||||||
from src.common.database.api.crud import CRUDBase
|
|
||||||
from src.common.database.utils.decorators import cached
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
|||||||
@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.utils.decorators import cached
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("energy_system")
|
logger = get_logger("energy_system")
|
||||||
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
|||||||
|
|
||||||
# 从数据库获取聊天流兴趣分数
|
# 从数据库获取聊天流兴趣分数
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from src.common.database.core.models import ChatStreams
|
from src.common.database.core.models import ChatStreams
|
||||||
|
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ class ExpressionLearner:
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 存储到数据库 Expression 表
|
# 存储到数据库 Expression 表
|
||||||
crud = CRUDBase(Expression)
|
CRUDBase(Expression)
|
||||||
for chat_id, expr_list in chat_dict.items():
|
for chat_id, expr_list in chat_dict.items():
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for new_expr in expr_list:
|
for new_expr in expr_list:
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ from json_repair import repair_json
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.database.api.crud import CRUDBase
|
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import Expression
|
from src.common.database.core.models import Expression
|
||||||
from src.common.database.utils.decorators import cached
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|||||||
@@ -728,7 +728,6 @@ class MemorySystem:
|
|||||||
context = context or {}
|
context = context or {}
|
||||||
|
|
||||||
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
|
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
|
||||||
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
|
||||||
|
|
||||||
self.status = MemorySystemStatus.RETRIEVING
|
self.status = MemorySystemStatus.RETRIEVING
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ import time
|
|||||||
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.database.api.crud import CRUDBase
|
||||||
|
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams # 新增导入
|
from src.common.database.core.models import ChatStreams # 新增导入
|
||||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config # 新增导入
|
from src.config.config import global_config # 新增导入
|
||||||
|
|
||||||
|
|||||||
@@ -105,8 +105,8 @@ class MessageStorageBatcher:
|
|||||||
for msg_data in messages_to_store:
|
for msg_data in messages_to_store:
|
||||||
try:
|
try:
|
||||||
message_dict = await self._prepare_message_dict(
|
message_dict = await self._prepare_message_dict(
|
||||||
msg_data['message'],
|
msg_data["message"],
|
||||||
msg_data['chat_stream']
|
msg_data["chat_stream"]
|
||||||
)
|
)
|
||||||
if message_dict:
|
if message_dict:
|
||||||
messages_dicts.append(message_dict)
|
messages_dicts.append(message_dict)
|
||||||
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
|
|||||||
is_picid = message.is_picid
|
is_picid = message.is_picid
|
||||||
is_notify = message.is_notify
|
is_notify = message.is_notify
|
||||||
is_command = message.is_command
|
is_command = message.is_command
|
||||||
is_public_notice = getattr(message, 'is_public_notice', False)
|
is_public_notice = getattr(message, "is_public_notice", False)
|
||||||
notice_type = getattr(message, 'notice_type', None)
|
notice_type = getattr(message, "notice_type", None)
|
||||||
actions = getattr(message, 'actions', None)
|
actions = getattr(message, "actions", None)
|
||||||
should_reply = getattr(message, 'should_reply', None)
|
should_reply = getattr(message, "should_reply", None)
|
||||||
should_act = getattr(message, 'should_act', None)
|
should_act = getattr(message, "should_act", None)
|
||||||
additional_config = getattr(message, 'additional_config', None)
|
additional_config = getattr(message, "additional_config", None)
|
||||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||||
|
|
||||||
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
|
|||||||
|
|
||||||
|
|
||||||
# 全局批处理器实例
|
# 全局批处理器实例
|
||||||
_message_storage_batcher: Optional[MessageStorageBatcher] = None
|
_message_storage_batcher: MessageStorageBatcher | None = None
|
||||||
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -488,8 +488,8 @@ class MessageStorage:
|
|||||||
if use_batch:
|
if use_batch:
|
||||||
batcher = get_message_storage_batcher()
|
batcher = get_message_storage_batcher()
|
||||||
await batcher.add_message({
|
await batcher.add_message({
|
||||||
'message': message,
|
"message": message,
|
||||||
'chat_stream': chat_stream
|
"chat_stream": chat_stream
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from collections import defaultdict
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.common.database.compatibility import db_get, db_query, db_save
|
from src.common.database.compatibility import db_get, db_query
|
||||||
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.manager.async_task_manager import AsyncTask
|
from src.manager.async_task_manager import AsyncTask
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from PIL import Image
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_, select
|
||||||
|
|
||||||
from src.common.database.core.models import ImageDescriptions, Images
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import ImageDescriptions, Images
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ from typing import Any
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from src.common.database.core.models import Videos
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import Videos
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|||||||
@@ -9,6 +9,37 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# ===== 核心层 =====
|
# ===== 核心层 =====
|
||||||
|
# ===== API层 =====
|
||||||
|
from src.common.database.api import (
|
||||||
|
AggregateQuery,
|
||||||
|
CRUDBase,
|
||||||
|
QueryBuilder,
|
||||||
|
# ChatStreams API
|
||||||
|
get_active_streams,
|
||||||
|
# Messages API
|
||||||
|
get_chat_history,
|
||||||
|
get_message_count,
|
||||||
|
# PersonInfo API
|
||||||
|
get_or_create_person,
|
||||||
|
# ActionRecords API
|
||||||
|
get_recent_actions,
|
||||||
|
# LLMUsage API
|
||||||
|
get_usage_statistics,
|
||||||
|
record_llm_usage,
|
||||||
|
# 业务API
|
||||||
|
save_message,
|
||||||
|
store_action_info,
|
||||||
|
update_person_affinity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ===== 兼容层(向后兼容旧API)=====
|
||||||
|
from src.common.database.compatibility import (
|
||||||
|
MODEL_MAPPING,
|
||||||
|
build_filters,
|
||||||
|
db_get,
|
||||||
|
db_query,
|
||||||
|
db_save,
|
||||||
|
)
|
||||||
from src.common.database.core import (
|
from src.common.database.core import (
|
||||||
Base,
|
Base,
|
||||||
check_and_migrate_database,
|
check_and_migrate_database,
|
||||||
@@ -27,29 +58,6 @@ from src.common.database.optimization import (
|
|||||||
get_preloader,
|
get_preloader,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ===== API层 =====
|
|
||||||
from src.common.database.api import (
|
|
||||||
AggregateQuery,
|
|
||||||
CRUDBase,
|
|
||||||
QueryBuilder,
|
|
||||||
# ActionRecords API
|
|
||||||
get_recent_actions,
|
|
||||||
# ChatStreams API
|
|
||||||
get_active_streams,
|
|
||||||
# Messages API
|
|
||||||
get_chat_history,
|
|
||||||
get_message_count,
|
|
||||||
# PersonInfo API
|
|
||||||
get_or_create_person,
|
|
||||||
# LLMUsage API
|
|
||||||
get_usage_statistics,
|
|
||||||
record_llm_usage,
|
|
||||||
# 业务API
|
|
||||||
save_message,
|
|
||||||
store_action_info,
|
|
||||||
update_person_affinity,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ===== Utils层 =====
|
# ===== Utils层 =====
|
||||||
from src.common.database.utils import (
|
from src.common.database.utils import (
|
||||||
cached,
|
cached,
|
||||||
@@ -66,61 +74,52 @@ from src.common.database.utils import (
|
|||||||
transactional,
|
transactional,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ===== 兼容层(向后兼容旧API)=====
|
|
||||||
from src.common.database.compatibility import (
|
|
||||||
MODEL_MAPPING,
|
|
||||||
build_filters,
|
|
||||||
db_get,
|
|
||||||
db_query,
|
|
||||||
db_save,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 核心层
|
|
||||||
"Base",
|
|
||||||
"get_engine",
|
|
||||||
"get_session_factory",
|
|
||||||
"get_db_session",
|
|
||||||
"check_and_migrate_database",
|
|
||||||
# 优化层
|
|
||||||
"MultiLevelCache",
|
|
||||||
"DataPreloader",
|
|
||||||
"AdaptiveBatchScheduler",
|
|
||||||
"get_cache",
|
|
||||||
"get_preloader",
|
|
||||||
"get_batch_scheduler",
|
|
||||||
# API层 - 基础类
|
|
||||||
"CRUDBase",
|
|
||||||
"QueryBuilder",
|
|
||||||
"AggregateQuery",
|
|
||||||
# API层 - 业务API
|
|
||||||
"store_action_info",
|
|
||||||
"get_recent_actions",
|
|
||||||
"get_chat_history",
|
|
||||||
"get_message_count",
|
|
||||||
"save_message",
|
|
||||||
"get_or_create_person",
|
|
||||||
"update_person_affinity",
|
|
||||||
"get_active_streams",
|
|
||||||
"record_llm_usage",
|
|
||||||
"get_usage_statistics",
|
|
||||||
# Utils层
|
|
||||||
"retry",
|
|
||||||
"timeout",
|
|
||||||
"cached",
|
|
||||||
"measure_time",
|
|
||||||
"transactional",
|
|
||||||
"db_operation",
|
|
||||||
"get_monitor",
|
|
||||||
"record_operation",
|
|
||||||
"record_cache_hit",
|
|
||||||
"record_cache_miss",
|
|
||||||
"print_stats",
|
|
||||||
"reset_stats",
|
|
||||||
# 兼容层
|
# 兼容层
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
|
"AdaptiveBatchScheduler",
|
||||||
|
"AggregateQuery",
|
||||||
|
# 核心层
|
||||||
|
"Base",
|
||||||
|
# API层 - 基础类
|
||||||
|
"CRUDBase",
|
||||||
|
"DataPreloader",
|
||||||
|
# 优化层
|
||||||
|
"MultiLevelCache",
|
||||||
|
"QueryBuilder",
|
||||||
"build_filters",
|
"build_filters",
|
||||||
|
"cached",
|
||||||
|
"check_and_migrate_database",
|
||||||
|
"db_get",
|
||||||
|
"db_operation",
|
||||||
"db_query",
|
"db_query",
|
||||||
"db_save",
|
"db_save",
|
||||||
"db_get",
|
"get_active_streams",
|
||||||
|
"get_batch_scheduler",
|
||||||
|
"get_cache",
|
||||||
|
"get_chat_history",
|
||||||
|
"get_db_session",
|
||||||
|
"get_engine",
|
||||||
|
"get_message_count",
|
||||||
|
"get_monitor",
|
||||||
|
"get_or_create_person",
|
||||||
|
"get_preloader",
|
||||||
|
"get_recent_actions",
|
||||||
|
"get_session_factory",
|
||||||
|
"get_usage_statistics",
|
||||||
|
"measure_time",
|
||||||
|
"print_stats",
|
||||||
|
"record_cache_hit",
|
||||||
|
"record_cache_miss",
|
||||||
|
"record_llm_usage",
|
||||||
|
"record_operation",
|
||||||
|
"reset_stats",
|
||||||
|
# Utils层
|
||||||
|
"retry",
|
||||||
|
"save_message",
|
||||||
|
# API层 - 业务API
|
||||||
|
"store_action_info",
|
||||||
|
"timeout",
|
||||||
|
"transactional",
|
||||||
|
"update_person_affinity",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,49 +11,49 @@ from src.common.database.api.query import AggregateQuery, QueryBuilder
|
|||||||
|
|
||||||
# 业务特定API
|
# 业务特定API
|
||||||
from src.common.database.api.specialized import (
|
from src.common.database.api.specialized import (
|
||||||
# ActionRecords
|
|
||||||
get_recent_actions,
|
|
||||||
store_action_info,
|
|
||||||
# ChatStreams
|
# ChatStreams
|
||||||
get_active_streams,
|
get_active_streams,
|
||||||
get_or_create_chat_stream,
|
|
||||||
# LLMUsage
|
|
||||||
get_usage_statistics,
|
|
||||||
record_llm_usage,
|
|
||||||
# Messages
|
# Messages
|
||||||
get_chat_history,
|
get_chat_history,
|
||||||
get_message_count,
|
get_message_count,
|
||||||
save_message,
|
get_or_create_chat_stream,
|
||||||
# PersonInfo
|
# PersonInfo
|
||||||
get_or_create_person,
|
get_or_create_person,
|
||||||
update_person_affinity,
|
# ActionRecords
|
||||||
|
get_recent_actions,
|
||||||
|
# LLMUsage
|
||||||
|
get_usage_statistics,
|
||||||
# UserRelationships
|
# UserRelationships
|
||||||
get_user_relationship,
|
get_user_relationship,
|
||||||
|
record_llm_usage,
|
||||||
|
save_message,
|
||||||
|
store_action_info,
|
||||||
|
update_person_affinity,
|
||||||
update_relationship_affinity,
|
update_relationship_affinity,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AggregateQuery",
|
||||||
# 基础类
|
# 基础类
|
||||||
"CRUDBase",
|
"CRUDBase",
|
||||||
"QueryBuilder",
|
"QueryBuilder",
|
||||||
"AggregateQuery",
|
"get_active_streams",
|
||||||
# ActionRecords API
|
|
||||||
"store_action_info",
|
|
||||||
"get_recent_actions",
|
|
||||||
# Messages API
|
# Messages API
|
||||||
"get_chat_history",
|
"get_chat_history",
|
||||||
"get_message_count",
|
"get_message_count",
|
||||||
"save_message",
|
|
||||||
# PersonInfo API
|
|
||||||
"get_or_create_person",
|
|
||||||
"update_person_affinity",
|
|
||||||
# ChatStreams API
|
# ChatStreams API
|
||||||
"get_or_create_chat_stream",
|
"get_or_create_chat_stream",
|
||||||
"get_active_streams",
|
# PersonInfo API
|
||||||
# LLMUsage API
|
"get_or_create_person",
|
||||||
"record_llm_usage",
|
"get_recent_actions",
|
||||||
"get_usage_statistics",
|
"get_usage_statistics",
|
||||||
# UserRelationships API
|
# UserRelationships API
|
||||||
"get_user_relationship",
|
"get_user_relationship",
|
||||||
|
# LLMUsage API
|
||||||
|
"record_llm_usage",
|
||||||
|
"save_message",
|
||||||
|
# ActionRecords API
|
||||||
|
"store_action_info",
|
||||||
|
"update_person_affinity",
|
||||||
"update_relationship_affinity",
|
"update_relationship_affinity",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ class CRUDBase:
|
|||||||
# 应用过滤条件
|
# 应用过滤条件
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if hasattr(self.model, key):
|
if hasattr(self.model, key):
|
||||||
if isinstance(value, (list, tuple, set)):
|
if isinstance(value, list | tuple | set):
|
||||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||||
else:
|
else:
|
||||||
stmt = stmt.where(getattr(self.model, key) == value)
|
stmt = stmt.where(getattr(self.model, key) == value)
|
||||||
@@ -397,7 +397,7 @@ class CRUDBase:
|
|||||||
# 应用过滤条件
|
# 应用过滤条件
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if hasattr(self.model, key):
|
if hasattr(self.model, key):
|
||||||
if isinstance(value, (list, tuple, set)):
|
if isinstance(value, list | tuple | set):
|
||||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||||
else:
|
else:
|
||||||
stmt = stmt.where(getattr(self.model, key) == value)
|
stmt = stmt.where(getattr(self.model, key) == value)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
@@ -42,9 +42,9 @@ async def store_action_info(
|
|||||||
action_prompt_display: str = "",
|
action_prompt_display: str = "",
|
||||||
action_done: bool = True,
|
action_done: bool = True,
|
||||||
thinking_id: str = "",
|
thinking_id: str = "",
|
||||||
action_data: Optional[dict] = None,
|
action_data: dict | None = None,
|
||||||
action_name: str = "",
|
action_name: str = "",
|
||||||
) -> Optional[dict[str, Any]]:
|
) -> dict[str, Any] | None:
|
||||||
"""存储动作信息到数据库
|
"""存储动作信息到数据库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -167,7 +167,7 @@ async def get_message_count(stream_id: str) -> int:
|
|||||||
async def save_message(
|
async def save_message(
|
||||||
message_data: dict[str, Any],
|
message_data: dict[str, Any],
|
||||||
use_batch: bool = True,
|
use_batch: bool = True,
|
||||||
) -> Optional[Messages]:
|
) -> Messages | None:
|
||||||
"""保存消息
|
"""保存消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -185,8 +185,8 @@ async def save_message(
|
|||||||
async def get_or_create_person(
|
async def get_or_create_person(
|
||||||
platform: str,
|
platform: str,
|
||||||
person_id: str,
|
person_id: str,
|
||||||
defaults: Optional[dict[str, Any]] = None,
|
defaults: dict[str, Any] | None = None,
|
||||||
) -> tuple[Optional[PersonInfo], bool]:
|
) -> tuple[PersonInfo | None, bool]:
|
||||||
"""获取或创建人员信息
|
"""获取或创建人员信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -255,8 +255,8 @@ async def update_person_affinity(
|
|||||||
async def get_or_create_chat_stream(
|
async def get_or_create_chat_stream(
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
defaults: Optional[dict[str, Any]] = None,
|
defaults: dict[str, Any] | None = None,
|
||||||
) -> tuple[Optional[ChatStreams], bool]:
|
) -> tuple[ChatStreams | None, bool]:
|
||||||
"""获取或创建聊天流
|
"""获取或创建聊天流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -275,7 +275,7 @@ async def get_or_create_chat_stream(
|
|||||||
|
|
||||||
|
|
||||||
async def get_active_streams(
|
async def get_active_streams(
|
||||||
platform: Optional[str] = None,
|
platform: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> list[ChatStreams]:
|
) -> list[ChatStreams]:
|
||||||
"""获取活跃的聊天流
|
"""获取活跃的聊天流
|
||||||
@@ -300,18 +300,18 @@ async def record_llm_usage(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
input_tokens: int,
|
input_tokens: int,
|
||||||
output_tokens: int,
|
output_tokens: int,
|
||||||
stream_id: Optional[str] = None,
|
stream_id: str | None = None,
|
||||||
platform: Optional[str] = None,
|
platform: str | None = None,
|
||||||
user_id: str = "system",
|
user_id: str = "system",
|
||||||
request_type: str = "chat",
|
request_type: str = "chat",
|
||||||
model_assign_name: Optional[str] = None,
|
model_assign_name: str | None = None,
|
||||||
model_api_provider: Optional[str] = None,
|
model_api_provider: str | None = None,
|
||||||
endpoint: str = "/v1/chat/completions",
|
endpoint: str = "/v1/chat/completions",
|
||||||
cost: float = 0.0,
|
cost: float = 0.0,
|
||||||
status: str = "success",
|
status: str = "success",
|
||||||
time_cost: Optional[float] = None,
|
time_cost: float | None = None,
|
||||||
use_batch: bool = True,
|
use_batch: bool = True,
|
||||||
) -> Optional[LLMUsage]:
|
) -> LLMUsage | None:
|
||||||
"""记录LLM使用情况
|
"""记录LLM使用情况
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -354,9 +354,9 @@ async def record_llm_usage(
|
|||||||
|
|
||||||
|
|
||||||
async def get_usage_statistics(
|
async def get_usage_statistics(
|
||||||
start_time: Optional[float] = None,
|
start_time: float | None = None,
|
||||||
end_time: Optional[float] = None,
|
end_time: float | None = None,
|
||||||
model_name: Optional[str] = None,
|
model_name: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""获取使用统计
|
"""获取使用统计
|
||||||
|
|
||||||
@@ -374,8 +374,7 @@ async def get_usage_statistics(
|
|||||||
|
|
||||||
# 添加时间过滤
|
# 添加时间过滤
|
||||||
if start_time:
|
if start_time:
|
||||||
async with get_db_session() as session:
|
async with get_db_session():
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
conditions = []
|
conditions = []
|
||||||
if start_time:
|
if start_time:
|
||||||
@@ -407,7 +406,7 @@ async def get_user_relationship(
|
|||||||
platform: str,
|
platform: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
target_id: str,
|
target_id: str,
|
||||||
) -> Optional[UserRelationships]:
|
) -> UserRelationships | None:
|
||||||
"""获取用户关系
|
"""获取用户关系
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -14,14 +14,14 @@ from .adapter import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 从 core 重新导出的函数
|
|
||||||
"get_db_session",
|
|
||||||
"get_engine",
|
|
||||||
# 兼容层适配器
|
# 兼容层适配器
|
||||||
"MODEL_MAPPING",
|
"MODEL_MAPPING",
|
||||||
"build_filters",
|
"build_filters",
|
||||||
|
"db_get",
|
||||||
"db_query",
|
"db_query",
|
||||||
"db_save",
|
"db_save",
|
||||||
"db_get",
|
# 从 core 重新导出的函数
|
||||||
|
"get_db_session",
|
||||||
|
"get_engine",
|
||||||
"store_action_info",
|
"store_action_info",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,15 +4,13 @@
|
|||||||
保持原有函数签名和行为不变
|
保持原有函数签名和行为不变
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
from typing import Any
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from sqlalchemy import and_, asc, desc, select
|
|
||||||
|
|
||||||
from src.common.database.api import (
|
from src.common.database.api import (
|
||||||
CRUDBase,
|
CRUDBase,
|
||||||
QueryBuilder,
|
QueryBuilder,
|
||||||
|
)
|
||||||
|
from src.common.database.api import (
|
||||||
store_action_info as new_store_action_info,
|
store_action_info as new_store_action_info,
|
||||||
)
|
)
|
||||||
from src.common.database.core.models import (
|
from src.common.database.core.models import (
|
||||||
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
|
|||||||
Messages,
|
Messages,
|
||||||
MonthlyPlan,
|
MonthlyPlan,
|
||||||
OnlineTime,
|
OnlineTime,
|
||||||
PersonInfo,
|
|
||||||
PermissionNodes,
|
PermissionNodes,
|
||||||
|
PersonInfo,
|
||||||
Schedule,
|
Schedule,
|
||||||
ThinkingLog,
|
ThinkingLog,
|
||||||
UserPermissions,
|
UserPermissions,
|
||||||
UserRelationships,
|
UserRelationships,
|
||||||
Videos,
|
Videos,
|
||||||
)
|
)
|
||||||
from src.common.database.core.session import get_db_session
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("database.compatibility")
|
logger = get_logger("database.compatibility")
|
||||||
@@ -145,12 +142,12 @@ def _model_to_dict(instance) -> dict[str, Any]:
|
|||||||
|
|
||||||
async def db_query(
|
async def db_query(
|
||||||
model_class,
|
model_class,
|
||||||
data: Optional[dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
query_type: Optional[str] = "get",
|
query_type: str | None = "get",
|
||||||
filters: Optional[dict[str, Any]] = None,
|
filters: dict[str, Any] | None = None,
|
||||||
limit: Optional[int] = None,
|
limit: int | None = None,
|
||||||
order_by: Optional[list[str]] = None,
|
order_by: list[str] | None = None,
|
||||||
single_result: Optional[bool] = False,
|
single_result: bool | None = False,
|
||||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||||
"""执行异步数据库查询操作(兼容旧API)
|
"""执行异步数据库查询操作(兼容旧API)
|
||||||
|
|
||||||
@@ -286,7 +283,7 @@ async def db_save(
|
|||||||
data: dict[str, Any],
|
data: dict[str, Any],
|
||||||
key_field: str,
|
key_field: str,
|
||||||
key_value: Any,
|
key_value: Any,
|
||||||
) -> Optional[dict[str, Any]]:
|
) -> dict[str, Any] | None:
|
||||||
"""保存或更新记录(兼容旧API)
|
"""保存或更新记录(兼容旧API)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -319,10 +316,10 @@ async def db_save(
|
|||||||
|
|
||||||
async def db_get(
|
async def db_get(
|
||||||
model_class,
|
model_class,
|
||||||
filters: Optional[dict[str, Any]] = None,
|
filters: dict[str, Any] | None = None,
|
||||||
limit: Optional[int] = None,
|
limit: int | None = None,
|
||||||
order_by: Optional[str] = None,
|
order_by: str | None = None,
|
||||||
single_result: Optional[bool] = False,
|
single_result: bool | None = False,
|
||||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||||
"""从数据库获取记录(兼容旧API)
|
"""从数据库获取记录(兼容旧API)
|
||||||
|
|
||||||
@@ -353,9 +350,9 @@ async def store_action_info(
|
|||||||
action_prompt_display: str = "",
|
action_prompt_display: str = "",
|
||||||
action_done: bool = True,
|
action_done: bool = True,
|
||||||
thinking_id: str = "",
|
thinking_id: str = "",
|
||||||
action_data: Optional[dict] = None,
|
action_data: dict | None = None,
|
||||||
action_name: str = "",
|
action_name: str = "",
|
||||||
) -> Optional[dict[str, Any]]:
|
) -> dict[str, Any] | None:
|
||||||
"""存储动作信息到数据库(兼容旧API)
|
"""存储动作信息到数据库(兼容旧API)
|
||||||
|
|
||||||
直接使用新的specialized API
|
直接使用新的specialized API
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -25,19 +25,19 @@ class DatabaseConfig:
|
|||||||
engine_kwargs: dict[str, Any]
|
engine_kwargs: dict[str, Any]
|
||||||
|
|
||||||
# SQLite特定配置
|
# SQLite特定配置
|
||||||
sqlite_path: Optional[str] = None
|
sqlite_path: str | None = None
|
||||||
|
|
||||||
# MySQL特定配置
|
# MySQL特定配置
|
||||||
mysql_host: Optional[str] = None
|
mysql_host: str | None = None
|
||||||
mysql_port: Optional[int] = None
|
mysql_port: int | None = None
|
||||||
mysql_user: Optional[str] = None
|
mysql_user: str | None = None
|
||||||
mysql_password: Optional[str] = None
|
mysql_password: str | None = None
|
||||||
mysql_database: Optional[str] = None
|
mysql_database: str | None = None
|
||||||
mysql_charset: str = "utf8mb4"
|
mysql_charset: str = "utf8mb4"
|
||||||
mysql_unix_socket: Optional[str] = None
|
mysql_unix_socket: str | None = None
|
||||||
|
|
||||||
|
|
||||||
_database_config: Optional[DatabaseConfig] = None
|
_database_config: DatabaseConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_database_config() -> DatabaseConfig:
|
def get_database_config() -> DatabaseConfig:
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from .models import (
|
|||||||
ChatStreams,
|
ChatStreams,
|
||||||
Emoji,
|
Emoji,
|
||||||
Expression,
|
Expression,
|
||||||
get_string_field,
|
|
||||||
GraphEdges,
|
GraphEdges,
|
||||||
GraphNodes,
|
GraphNodes,
|
||||||
ImageDescriptions,
|
ImageDescriptions,
|
||||||
@@ -37,30 +36,17 @@ from .models import (
|
|||||||
UserPermissions,
|
UserPermissions,
|
||||||
UserRelationships,
|
UserRelationships,
|
||||||
Videos,
|
Videos,
|
||||||
|
get_string_field,
|
||||||
)
|
)
|
||||||
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
|
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine
|
|
||||||
"get_engine",
|
|
||||||
"close_engine",
|
|
||||||
"get_engine_info",
|
|
||||||
# Session
|
|
||||||
"get_db_session",
|
|
||||||
"get_db_session_direct",
|
|
||||||
"get_session_factory",
|
|
||||||
"reset_session_factory",
|
|
||||||
# Migration
|
|
||||||
"check_and_migrate_database",
|
|
||||||
"create_all_tables",
|
|
||||||
"drop_all_tables",
|
|
||||||
# Models - Base
|
|
||||||
"Base",
|
|
||||||
"get_string_field",
|
|
||||||
# Models - Tables (按字母顺序)
|
# Models - Tables (按字母顺序)
|
||||||
"ActionRecords",
|
"ActionRecords",
|
||||||
"AntiInjectionStats",
|
"AntiInjectionStats",
|
||||||
"BanUser",
|
"BanUser",
|
||||||
|
# Models - Base
|
||||||
|
"Base",
|
||||||
"BotPersonalityInterests",
|
"BotPersonalityInterests",
|
||||||
"CacheEntries",
|
"CacheEntries",
|
||||||
"ChatStreams",
|
"ChatStreams",
|
||||||
@@ -83,4 +69,18 @@ __all__ = [
|
|||||||
"UserPermissions",
|
"UserPermissions",
|
||||||
"UserRelationships",
|
"UserRelationships",
|
||||||
"Videos",
|
"Videos",
|
||||||
|
# Migration
|
||||||
|
"check_and_migrate_database",
|
||||||
|
"close_engine",
|
||||||
|
"create_all_tables",
|
||||||
|
"drop_all_tables",
|
||||||
|
# Session
|
||||||
|
"get_db_session",
|
||||||
|
"get_db_session_direct",
|
||||||
|
# Engine
|
||||||
|
"get_engine",
|
||||||
|
"get_engine_info",
|
||||||
|
"get_session_factory",
|
||||||
|
"get_string_field",
|
||||||
|
"reset_session_factory",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@@ -18,8 +17,8 @@ from ..utils.exceptions import DatabaseInitializationError
|
|||||||
logger = get_logger("database.engine")
|
logger = get_logger("database.engine")
|
||||||
|
|
||||||
# 全局引擎实例
|
# 全局引擎实例
|
||||||
_engine: Optional[AsyncEngine] = None
|
_engine: AsyncEngine | None = None
|
||||||
_engine_lock: Optional[asyncio.Lock] = None
|
_engine_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
|
|
||||||
async def get_engine() -> AsyncEngine:
|
async def get_engine() -> AsyncEngine:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -18,8 +17,8 @@ from .engine import get_engine
|
|||||||
logger = get_logger("database.session")
|
logger = get_logger("database.session")
|
||||||
|
|
||||||
# 全局会话工厂
|
# 全局会话工厂
|
||||||
_session_factory: Optional[async_sessionmaker] = None
|
_session_factory: async_sessionmaker | None = None
|
||||||
_factory_lock: Optional[asyncio.Lock] = None
|
_factory_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
|
|
||||||
async def get_session_factory() -> async_sessionmaker:
|
async def get_session_factory() -> async_sessionmaker:
|
||||||
|
|||||||
@@ -11,17 +11,17 @@ from .batch_scheduler import (
|
|||||||
AdaptiveBatchScheduler,
|
AdaptiveBatchScheduler,
|
||||||
BatchOperation,
|
BatchOperation,
|
||||||
BatchStats,
|
BatchStats,
|
||||||
|
Priority,
|
||||||
close_batch_scheduler,
|
close_batch_scheduler,
|
||||||
get_batch_scheduler,
|
get_batch_scheduler,
|
||||||
Priority,
|
|
||||||
)
|
)
|
||||||
from .cache_manager import (
|
from .cache_manager import (
|
||||||
CacheEntry,
|
CacheEntry,
|
||||||
CacheStats,
|
CacheStats,
|
||||||
close_cache,
|
|
||||||
get_cache,
|
|
||||||
LRUCache,
|
LRUCache,
|
||||||
MultiLevelCache,
|
MultiLevelCache,
|
||||||
|
close_cache,
|
||||||
|
get_cache,
|
||||||
)
|
)
|
||||||
from .connection_pool import (
|
from .connection_pool import (
|
||||||
ConnectionPoolManager,
|
ConnectionPoolManager,
|
||||||
@@ -31,36 +31,36 @@ from .connection_pool import (
|
|||||||
)
|
)
|
||||||
from .preloader import (
|
from .preloader import (
|
||||||
AccessPattern,
|
AccessPattern,
|
||||||
close_preloader,
|
|
||||||
CommonDataPreloader,
|
CommonDataPreloader,
|
||||||
DataPreloader,
|
DataPreloader,
|
||||||
|
close_preloader,
|
||||||
get_preloader,
|
get_preloader,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Connection Pool
|
|
||||||
"ConnectionPoolManager",
|
|
||||||
"get_connection_pool_manager",
|
|
||||||
"start_connection_pool",
|
|
||||||
"stop_connection_pool",
|
|
||||||
# Cache
|
|
||||||
"MultiLevelCache",
|
|
||||||
"LRUCache",
|
|
||||||
"CacheEntry",
|
|
||||||
"CacheStats",
|
|
||||||
"get_cache",
|
|
||||||
"close_cache",
|
|
||||||
# Preloader
|
|
||||||
"DataPreloader",
|
|
||||||
"CommonDataPreloader",
|
|
||||||
"AccessPattern",
|
"AccessPattern",
|
||||||
"get_preloader",
|
|
||||||
"close_preloader",
|
|
||||||
# Batch Scheduler
|
# Batch Scheduler
|
||||||
"AdaptiveBatchScheduler",
|
"AdaptiveBatchScheduler",
|
||||||
"BatchOperation",
|
"BatchOperation",
|
||||||
"BatchStats",
|
"BatchStats",
|
||||||
|
"CacheEntry",
|
||||||
|
"CacheStats",
|
||||||
|
"CommonDataPreloader",
|
||||||
|
# Connection Pool
|
||||||
|
"ConnectionPoolManager",
|
||||||
|
# Preloader
|
||||||
|
"DataPreloader",
|
||||||
|
"LRUCache",
|
||||||
|
# Cache
|
||||||
|
"MultiLevelCache",
|
||||||
"Priority",
|
"Priority",
|
||||||
"get_batch_scheduler",
|
|
||||||
"close_batch_scheduler",
|
"close_batch_scheduler",
|
||||||
|
"close_cache",
|
||||||
|
"close_preloader",
|
||||||
|
"get_batch_scheduler",
|
||||||
|
"get_cache",
|
||||||
|
"get_connection_pool_manager",
|
||||||
|
"get_preloader",
|
||||||
|
"start_connection_pool",
|
||||||
|
"stop_connection_pool",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,12 +10,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import delete, insert, select, update
|
from sqlalchemy import delete, insert, select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from src.common.database.core.session import get_db_session
|
from src.common.database.core.session import get_db_session
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -40,12 +40,12 @@ class BatchOperation:
|
|||||||
operation_type: str # 'select', 'insert', 'update', 'delete'
|
operation_type: str # 'select', 'insert', 'update', 'delete'
|
||||||
model_class: type
|
model_class: type
|
||||||
conditions: dict[str, Any] = field(default_factory=dict)
|
conditions: dict[str, Any] = field(default_factory=dict)
|
||||||
data: Optional[dict[str, Any]] = None
|
data: dict[str, Any] | None = None
|
||||||
callback: Optional[Callable] = None
|
callback: Callable | None = None
|
||||||
future: Optional[asyncio.Future] = None
|
future: asyncio.Future | None = None
|
||||||
timestamp: float = field(default_factory=time.time)
|
timestamp: float = field(default_factory=time.time)
|
||||||
priority: Priority = Priority.NORMAL
|
priority: Priority = Priority.NORMAL
|
||||||
timeout: Optional[float] = None # 超时时间(秒)
|
timeout: float | None = None # 超时时间(秒)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -111,7 +111,7 @@ class AdaptiveBatchScheduler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 调度控制
|
# 调度控制
|
||||||
self._scheduler_task: Optional[asyncio.Task] = None
|
self._scheduler_task: asyncio.Task | None = None
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
@@ -257,7 +257,7 @@ class AdaptiveBatchScheduler:
|
|||||||
op_groups[key].append(op)
|
op_groups[key].append(op)
|
||||||
|
|
||||||
# 执行各组操作
|
# 执行各组操作
|
||||||
for group_key, ops in op_groups.items():
|
for ops in op_groups.values():
|
||||||
await self._execute_group(ops)
|
await self._execute_group(ops)
|
||||||
|
|
||||||
# 更新统计
|
# 更新统计
|
||||||
@@ -323,7 +323,7 @@ class AdaptiveBatchScheduler:
|
|||||||
stmt = select(op.model_class)
|
stmt = select(op.model_class)
|
||||||
for key, value in op.conditions.items():
|
for key, value in op.conditions.items():
|
||||||
attr = getattr(op.model_class, key)
|
attr = getattr(op.model_class, key)
|
||||||
if isinstance(value, (list, tuple, set)):
|
if isinstance(value, list | tuple | set):
|
||||||
stmt = stmt.where(attr.in_(value))
|
stmt = stmt.where(attr.in_(value))
|
||||||
else:
|
else:
|
||||||
stmt = stmt.where(attr == value)
|
stmt = stmt.where(attr == value)
|
||||||
@@ -366,7 +366,7 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
# 批量插入
|
# 批量插入
|
||||||
stmt = insert(operations[0].model_class).values(all_data)
|
stmt = insert(operations[0].model_class).values(all_data)
|
||||||
result = await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# 设置结果
|
# 设置结果
|
||||||
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
|
|||||||
]
|
]
|
||||||
return "|".join(key_parts)
|
return "|".join(key_parts)
|
||||||
|
|
||||||
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
|
def _get_from_cache(self, cache_key: str) -> Any | None:
|
||||||
"""从缓存获取结果"""
|
"""从缓存获取结果"""
|
||||||
if cache_key in self._result_cache:
|
if cache_key in self._result_cache:
|
||||||
result, timestamp = self._result_cache[cache_key]
|
result, timestamp = self._result_cache[cache_key]
|
||||||
@@ -551,7 +551,7 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
|
|
||||||
# 全局调度器实例
|
# 全局调度器实例
|
||||||
_global_scheduler: Optional[AdaptiveBatchScheduler] = None
|
_global_scheduler: AdaptiveBatchScheduler | None = None
|
||||||
_scheduler_lock = asyncio.Lock()
|
_scheduler_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Generic, Optional, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -94,7 +95,7 @@ class LRUCache(Generic[T]):
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._stats = CacheStats()
|
self._stats = CacheStats()
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[T]:
|
async def get(self, key: str) -> T | None:
|
||||||
"""获取缓存值
|
"""获取缓存值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -135,7 +136,7 @@ class LRUCache(Generic[T]):
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: T,
|
value: T,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -255,7 +256,7 @@ class MultiLevelCache:
|
|||||||
"""
|
"""
|
||||||
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
|
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
|
||||||
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
|
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
|
||||||
self._cleanup_task: Optional[asyncio.Task] = None
|
self._cleanup_task: asyncio.Task | None = None
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
|
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
|
||||||
@@ -265,8 +266,8 @@ class MultiLevelCache:
|
|||||||
async def get(
|
async def get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
loader: Optional[Callable[[], Any]] = None,
|
loader: Callable[[], Any] | None = None,
|
||||||
) -> Optional[Any]:
|
) -> Any | None:
|
||||||
"""从缓存获取数据
|
"""从缓存获取数据
|
||||||
|
|
||||||
查询顺序:L1 -> L2 -> loader
|
查询顺序:L1 -> L2 -> loader
|
||||||
@@ -307,7 +308,7 @@ class MultiLevelCache:
|
|||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
size: Optional[int] = None,
|
size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""设置缓存值
|
"""设置缓存值
|
||||||
|
|
||||||
@@ -387,7 +388,7 @@ class MultiLevelCache:
|
|||||||
|
|
||||||
|
|
||||||
# 全局缓存实例
|
# 全局缓存实例
|
||||||
_global_cache: Optional[MultiLevelCache] = None
|
_global_cache: MultiLevelCache | None = None
|
||||||
_cache_lock = asyncio.Lock()
|
_cache_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -84,7 +85,7 @@ class DataPreloader:
|
|||||||
async def record_access(
|
async def record_access(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
related_keys: Optional[list[str]] = None,
|
related_keys: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录数据访问
|
"""记录数据访问
|
||||||
|
|
||||||
@@ -379,7 +380,7 @@ class CommonDataPreloader:
|
|||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
limit: 消息数量限制
|
limit: 消息数量限制
|
||||||
"""
|
"""
|
||||||
from src.common.database.core.models import ChatStreams, Messages
|
from src.common.database.core.models import ChatStreams
|
||||||
|
|
||||||
# 预加载聊天流信息
|
# 预加载聊天流信息
|
||||||
await self._preload_model(
|
await self._preload_model(
|
||||||
@@ -418,7 +419,7 @@ class CommonDataPreloader:
|
|||||||
|
|
||||||
|
|
||||||
# 全局预加载器实例
|
# 全局预加载器实例
|
||||||
_global_preloader: Optional[DataPreloader] = None
|
_global_preloader: DataPreloader | None = None
|
||||||
_preloader_lock = asyncio.Lock()
|
_preloader_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -37,29 +37,29 @@ from .monitoring import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"BatchSchedulerError",
|
||||||
|
"CacheError",
|
||||||
|
"ConnectionPoolError",
|
||||||
|
"DatabaseConnectionError",
|
||||||
# 异常
|
# 异常
|
||||||
"DatabaseError",
|
"DatabaseError",
|
||||||
"DatabaseInitializationError",
|
"DatabaseInitializationError",
|
||||||
"DatabaseConnectionError",
|
"DatabaseMigrationError",
|
||||||
|
# 监控
|
||||||
|
"DatabaseMonitor",
|
||||||
"DatabaseQueryError",
|
"DatabaseQueryError",
|
||||||
"DatabaseTransactionError",
|
"DatabaseTransactionError",
|
||||||
"DatabaseMigrationError",
|
"cached",
|
||||||
"CacheError",
|
"db_operation",
|
||||||
"BatchSchedulerError",
|
"get_monitor",
|
||||||
"ConnectionPoolError",
|
"measure_time",
|
||||||
|
"print_stats",
|
||||||
|
"record_cache_hit",
|
||||||
|
"record_cache_miss",
|
||||||
|
"record_operation",
|
||||||
|
"reset_stats",
|
||||||
# 装饰器
|
# 装饰器
|
||||||
"retry",
|
"retry",
|
||||||
"timeout",
|
"timeout",
|
||||||
"cached",
|
|
||||||
"measure_time",
|
|
||||||
"transactional",
|
"transactional",
|
||||||
"db_operation",
|
|
||||||
# 监控
|
|
||||||
"DatabaseMonitor",
|
|
||||||
"get_monitor",
|
|
||||||
"record_operation",
|
|
||||||
"record_cache_hit",
|
|
||||||
"record_cache_miss",
|
|
||||||
"print_stats",
|
|
||||||
"reset_stats",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ import asyncio
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError
|
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||||
|
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -141,8 +143,8 @@ def timeout(seconds: float):
|
|||||||
|
|
||||||
|
|
||||||
def cached(
|
def cached(
|
||||||
ttl: Optional[int] = 300,
|
ttl: int | None = 300,
|
||||||
key_prefix: Optional[str] = None,
|
key_prefix: str | None = None,
|
||||||
use_args: bool = True,
|
use_args: bool = True,
|
||||||
use_kwargs: bool = True,
|
use_kwargs: bool = True,
|
||||||
):
|
):
|
||||||
@@ -207,7 +209,7 @@ def cached(
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def measure_time(log_slow: Optional[float] = None):
|
def measure_time(log_slow: float | None = None):
|
||||||
"""性能测量装饰器
|
"""性能测量装饰器
|
||||||
|
|
||||||
测量函数执行时间,可选择性记录慢查询
|
测量函数执行时间,可选择性记录慢查询
|
||||||
@@ -306,8 +308,8 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
|||||||
# 组合装饰器示例
|
# 组合装饰器示例
|
||||||
def db_operation(
|
def db_operation(
|
||||||
retry_attempts: int = 3,
|
retry_attempts: int = 3,
|
||||||
timeout_seconds: Optional[float] = None,
|
timeout_seconds: float | None = None,
|
||||||
cache_ttl: Optional[int] = None,
|
cache_ttl: int | None = None,
|
||||||
measure: bool = True,
|
measure: bool = True,
|
||||||
):
|
):
|
||||||
"""组合装饰器
|
"""组合装饰器
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ class OperationMetrics:
|
|||||||
min_time: float = float("inf")
|
min_time: float = float("inf")
|
||||||
max_time: float = 0.0
|
max_time: float = 0.0
|
||||||
error_count: int = 0
|
error_count: int = 0
|
||||||
last_execution_time: Optional[float] = None
|
last_execution_time: float | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def avg_time(self) -> float:
|
def avg_time(self) -> float:
|
||||||
@@ -285,7 +284,7 @@ class DatabaseMonitor:
|
|||||||
|
|
||||||
|
|
||||||
# 全局监控器实例
|
# 全局监控器实例
|
||||||
_monitor: Optional[DatabaseMonitor] = None
|
_monitor: DatabaseMonitor | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_monitor() -> DatabaseMonitor:
|
def get_monitor() -> DatabaseMonitor:
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from datetime import datetime
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from src.common.database.core.models import LLMUsage
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import LLMUsage
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.api_ada_configs import ModelInfo
|
from src.config.api_ada_configs import ModelInfo
|
||||||
|
|
||||||
|
|||||||
@@ -256,9 +256,10 @@ class RelationshipFetcher:
|
|||||||
str: 格式化后的聊天流印象字符串
|
str: 格式化后的聊天流印象字符串
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||||
|
|
||||||
# 使用优化后的API(带缓存)
|
# 使用优化后的API(带缓存)
|
||||||
# 从stream_id解析platform,或使用默认值
|
# 从stream_id解析platform,或使用默认值
|
||||||
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
|
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
|
||||||
|
|||||||
@@ -52,8 +52,8 @@ from typing import Any
|
|||||||
import orjson
|
import orjson
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.schedule.database import get_active_plans_for_month
|
from src.schedule.database import get_active_plans_for_month
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from src.plugin_system.base.component_types import ChatType, CommandInfo, Compon
|
|||||||
from src.plugin_system.base.plus_command import PlusCommand
|
from src.plugin_system.base.plus_command import PlusCommand
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
pass
|
||||||
|
|
||||||
logger = get_logger("base_command")
|
logger = get_logger("base_command")
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from sqlalchemy import delete, select
|
|||||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||||
|
|
||||||
from src.common.database.core.models import PermissionNodes, UserPermissions
|
|
||||||
from src.common.database.core import get_engine
|
from src.common.database.core import get_engine
|
||||||
|
from src.common.database.core.models import PermissionNodes, UserPermissions
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.common.database.core.models import UserRelationships
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import UserRelationships
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,8 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from src.common.database.compatibility import get_db_session
|
|
||||||
from src.common.database.core.models import ChatStreams
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.utils.decorators import cached
|
from src.common.database.core.models import ChatStreams
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|||||||
@@ -7,13 +7,10 @@ import json
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.chat.utils.prompt import Prompt
|
from src.chat.utils.prompt import Prompt
|
||||||
from src.common.database.compatibility import get_db_session
|
|
||||||
from src.common.database.core.models import ChatStreams
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
from src.common.database.api.crud import CRUDBase
|
||||||
|
from src.common.database.core.models import ChatStreams
|
||||||
from src.common.database.utils.decorators import cached
|
from src.common.database.utils.decorators import cached
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
from sqlalchemy import delete, func, select, update
|
from sqlalchemy import delete, func, select, update
|
||||||
|
|
||||||
from src.common.database.core.models import MonthlyPlan
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import MonthlyPlan
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -312,7 +312,7 @@ async def delete_plans_older_than(month: str):
|
|||||||
logger.info(f"没有找到比 {month} 更早的月度计划需要删除。")
|
logger.info(f"没有找到比 {month} 更早的月度计划需要删除。")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
plan_months = sorted(list(set(p.target_month for p in plans_to_delete)))
|
plan_months = sorted({p.target_month for p in plans_to_delete})
|
||||||
logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。")
|
logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。")
|
||||||
|
|
||||||
# 然后,执行删除操作
|
# 然后,执行删除操作
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import Any
|
|||||||
import orjson
|
import orjson
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
|
||||||
from src.common.database.core import get_db_session
|
from src.common.database.core import get_db_session
|
||||||
|
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||||
|
|||||||
Reference in New Issue
Block a user