rufffffff
This commit is contained in:
@@ -16,7 +16,7 @@ models_file = os.path.join(
|
||||
print(f"正在清理文件: {models_file}")
|
||||
|
||||
# 读取文件
|
||||
with open(models_file, "r", encoding="utf-8") as f:
|
||||
with open(models_file, encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束)
|
||||
@@ -26,7 +26,7 @@ found_end = False
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
keep_lines.append(line)
|
||||
|
||||
|
||||
# 检查是否到达 MonthlyPlan 的 __table_args__ 结束
|
||||
if i > 580 and line.strip() == ")":
|
||||
# 再检查前一行是否有 Index 相关内容
|
||||
@@ -43,7 +43,7 @@ if not found_end:
|
||||
with open(models_file, "w", encoding="utf-8") as f:
|
||||
f.writelines(keep_lines)
|
||||
|
||||
print(f"✅ 文件清理完成")
|
||||
print("✅ 文件清理完成")
|
||||
print(f"保留行数: {len(keep_lines)}")
|
||||
print(f"原始行数: {len(lines)}")
|
||||
print(f"删除行数: {len(lines) - len(keep_lines)}")
|
||||
|
||||
@@ -4,20 +4,20 @@
|
||||
import re
|
||||
|
||||
# 读取原始文件
|
||||
with open('src/common/database/sqlalchemy_models.py', 'r', encoding='utf-8') as f:
|
||||
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 找到get_string_field函数的开始和结束
|
||||
get_string_field_start = content.find('# MySQL兼容的字段类型辅助函数')
|
||||
get_string_field_end = content.find('\n\nclass ChatStreams(Base):')
|
||||
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
|
||||
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
|
||||
get_string_field = content[get_string_field_start:get_string_field_end]
|
||||
|
||||
# 找到第一个class定义开始
|
||||
first_class_pos = content.find('class ChatStreams(Base):')
|
||||
first_class_pos = content.find("class ChatStreams(Base):")
|
||||
|
||||
# 找到所有class定义,直到遇到非class的def
|
||||
# 简单策略:找到所有以"class "开头且继承Base的类
|
||||
classes_pattern = r'class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)'
|
||||
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
|
||||
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
|
||||
|
||||
if matches:
|
||||
@@ -53,14 +53,14 @@ Base = declarative_base()
|
||||
|
||||
'''
|
||||
|
||||
new_content = header + get_string_field + '\n\n' + models_content
|
||||
new_content = header + get_string_field + "\n\n" + models_content
|
||||
|
||||
# 写入新文件
|
||||
with open('src/common/database/core/models.py', 'w', encoding='utf-8') as f:
|
||||
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
print('✅ Models file rewritten successfully')
|
||||
print(f'File size: {len(new_content)} characters')
|
||||
print("✅ Models file rewritten successfully")
|
||||
print(f"File size: {len(new_content)} characters")
|
||||
pattern = r"^class \w+\(Base\):"
|
||||
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
|
||||
print(f'Number of model classes: {model_count}')
|
||||
print(f"Number of model classes: {model_count}")
|
||||
|
||||
@@ -8,54 +8,53 @@
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
# 定义导入映射规则
|
||||
IMPORT_MAPPINGS = {
|
||||
# 模型导入
|
||||
r'from src\.common\.database\.sqlalchemy_models import (.+)':
|
||||
r'from src.common.database.core.models import \1',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_models import (.+)":
|
||||
r"from src.common.database.core.models import \1",
|
||||
|
||||
# API导入 - 需要特殊处理
|
||||
r'from src\.common\.database\.sqlalchemy_database_api import (.+)':
|
||||
r'from src.common.database.compatibility import \1',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
|
||||
r"from src.common.database.compatibility import \1",
|
||||
|
||||
# get_db_session 从 sqlalchemy_database_api 导入
|
||||
r'from src\.common\.database\.sqlalchemy_database_api import get_db_session':
|
||||
r'from src.common.database.core import get_db_session',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
|
||||
r"from src.common.database.core import get_db_session",
|
||||
|
||||
# get_db_session 从 sqlalchemy_models 导入
|
||||
r'from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)':
|
||||
lambda m: f'from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}'
|
||||
if 'get_db_session' in m.group(0) else m.group(0),
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
|
||||
lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
|
||||
if "get_db_session" in m.group(0) else m.group(0),
|
||||
|
||||
# get_engine 导入
|
||||
r'from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)':
|
||||
lambda m: f'from src.common.database.core import {m.group(1)}get_engine{m.group(2)}',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
|
||||
lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
|
||||
|
||||
# Base 导入
|
||||
r'from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)':
|
||||
lambda m: f'from src.common.database.core.models import {m.group(1)}Base{m.group(2)}',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
|
||||
lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
|
||||
|
||||
# initialize_database 导入
|
||||
r'from src\.common\.database\.sqlalchemy_models import initialize_database':
|
||||
r'from src.common.database.core import check_and_migrate_database as initialize_database',
|
||||
|
||||
r"from src\.common\.database\.sqlalchemy_models import initialize_database":
|
||||
r"from src.common.database.core import check_and_migrate_database as initialize_database",
|
||||
|
||||
# database.py 导入
|
||||
r'from src\.common\.database\.database import stop_database':
|
||||
r'from src.common.database.core import close_engine as stop_database',
|
||||
|
||||
r'from src\.common\.database\.database import initialize_sql_database':
|
||||
r'from src.common.database.core import check_and_migrate_database as initialize_sql_database',
|
||||
r"from src\.common\.database\.database import stop_database":
|
||||
r"from src.common.database.core import close_engine as stop_database",
|
||||
|
||||
r"from src\.common\.database\.database import initialize_sql_database":
|
||||
r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
|
||||
}
|
||||
|
||||
# 需要排除的文件
|
||||
EXCLUDE_PATTERNS = [
|
||||
'**/database_refactoring_plan.md', # 文档文件
|
||||
'**/old/**', # 旧文件目录
|
||||
'**/sqlalchemy_*.py', # 旧的数据库文件本身
|
||||
'**/database.py', # 旧的database文件
|
||||
'**/db_*.py', # 旧的db文件
|
||||
"**/database_refactoring_plan.md", # 文档文件
|
||||
"**/old/**", # 旧文件目录
|
||||
"**/sqlalchemy_*.py", # 旧的数据库文件本身
|
||||
"**/database.py", # 旧的database文件
|
||||
"**/db_*.py", # 旧的db文件
|
||||
]
|
||||
|
||||
|
||||
@@ -67,47 +66,47 @@ def should_exclude(file_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int, List[str]]:
|
||||
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
|
||||
"""更新单个文件中的导入语句
|
||||
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
dry_run: 是否只是预览而不实际修改
|
||||
|
||||
|
||||
Returns:
|
||||
(修改次数, 修改详情列表)
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
original_content = content
|
||||
changes = []
|
||||
|
||||
|
||||
# 应用每个映射规则
|
||||
for pattern, replacement in IMPORT_MAPPINGS.items():
|
||||
matches = list(re.finditer(pattern, content))
|
||||
for match in matches:
|
||||
old_line = match.group(0)
|
||||
|
||||
|
||||
# 处理函数类型的替换
|
||||
if callable(replacement):
|
||||
new_line_result = replacement(match)
|
||||
new_line = new_line_result if isinstance(new_line_result, str) else old_line
|
||||
else:
|
||||
new_line = re.sub(pattern, replacement, old_line)
|
||||
|
||||
|
||||
if old_line != new_line and isinstance(new_line, str):
|
||||
content = content.replace(old_line, new_line, 1)
|
||||
changes.append(f" - {old_line}")
|
||||
changes.append(f" + {new_line}")
|
||||
|
||||
|
||||
# 如果有修改且不是dry_run,写回文件
|
||||
if content != original_content:
|
||||
if not dry_run:
|
||||
file_path.write_text(content, encoding='utf-8')
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
return len(changes) // 2, changes
|
||||
|
||||
|
||||
return 0, []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 处理文件 {file_path} 时出错: {e}")
|
||||
return 0, []
|
||||
@@ -116,34 +115,34 @@ def update_imports_in_file(file_path: Path, dry_run: bool = True) -> Tuple[int,
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("🔍 搜索需要更新导入的文件...")
|
||||
|
||||
|
||||
# 获取项目根目录
|
||||
root_dir = Path(__file__).parent.parent
|
||||
|
||||
|
||||
# 搜索所有Python文件
|
||||
all_python_files = list(root_dir.rglob("*.py"))
|
||||
|
||||
|
||||
# 过滤掉排除的文件
|
||||
target_files = [f for f in all_python_files if not should_exclude(f)]
|
||||
|
||||
|
||||
print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
# 第一遍:预览模式
|
||||
print("\n🔍 预览模式 - 检查需要更新的文件...\n")
|
||||
|
||||
|
||||
files_to_update = []
|
||||
for file_path in target_files:
|
||||
count, changes = update_imports_in_file(file_path, dry_run=True)
|
||||
if count > 0:
|
||||
files_to_update.append((file_path, count, changes))
|
||||
|
||||
|
||||
if not files_to_update:
|
||||
print("✅ 没有文件需要更新!")
|
||||
return
|
||||
|
||||
|
||||
print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")
|
||||
|
||||
|
||||
total_changes = 0
|
||||
for file_path, count, changes in files_to_update:
|
||||
rel_path = file_path.relative_to(root_dir)
|
||||
@@ -153,23 +152,23 @@ def main():
|
||||
if len(changes) > 10:
|
||||
print(f" ... 还有 {len(changes) - 10} 行")
|
||||
total_changes += count
|
||||
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(f"\n📊 统计:")
|
||||
print("\n📊 统计:")
|
||||
print(f" - 需要更新的文件: {len(files_to_update)}")
|
||||
print(f" - 总修改次数: {total_changes}")
|
||||
|
||||
|
||||
# 询问是否继续
|
||||
print("\n" + "="*80)
|
||||
response = input("\n是否执行更新?(yes/no): ").strip().lower()
|
||||
|
||||
if response != 'yes':
|
||||
|
||||
if response != "yes":
|
||||
print("❌ 已取消更新")
|
||||
return
|
||||
|
||||
|
||||
# 第二遍:实际更新
|
||||
print("\n✨ 开始更新文件...\n")
|
||||
|
||||
|
||||
success_count = 0
|
||||
for file_path, _, _ in files_to_update:
|
||||
count, _ = update_imports_in_file(file_path, dry_run=False)
|
||||
@@ -177,7 +176,7 @@ def main():
|
||||
rel_path = file_path.relative_to(root_dir)
|
||||
print(f"✅ {rel_path} ({count} 处修改)")
|
||||
success_count += 1
|
||||
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(f"\n🎉 完成!成功更新 {success_count} 个文件")
|
||||
|
||||
|
||||
@@ -263,8 +263,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
@@ -291,8 +291,8 @@ class AntiPromptInjector:
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.core.models import Messages
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Messages
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
|
||||
@@ -9,8 +9,8 @@ from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from src.common.database.core.models import AntiInjectionStats
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import AntiInjectionStats
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.core.models import BanUser
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import BanUser
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
@@ -15,9 +15,9 @@ from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Emoji, Images
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -215,7 +215,7 @@ class MaiEmoji:
|
||||
else:
|
||||
await crud.delete(will_delete_emoji.id)
|
||||
result = 1 # Successfully deleted one record
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
try:
|
||||
# 使用CRUD进行查询
|
||||
crud = CRUDBase(Emoji)
|
||||
|
||||
|
||||
if emoji_hash:
|
||||
# 查询特定hash的表情包
|
||||
emoji_record = await crud.get_by(emoji_hash=emoji_hash)
|
||||
|
||||
@@ -9,9 +9,8 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("energy_system")
|
||||
@@ -203,7 +202,6 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
|
||||
# 从数据库获取聊天流兴趣分数
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.core.models import ChatStreams
|
||||
|
||||
|
||||
@@ -236,12 +236,12 @@ class ExpressionLearner:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式(带10分钟缓存)
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
|
||||
|
||||
优化: 使用CRUD和缓存,减少数据库访问
|
||||
"""
|
||||
# 使用静态方法以正确处理缓存键
|
||||
return await self._get_expressions_by_chat_id_cached(self.chat_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@cached(ttl=600, key_prefix="chat_expressions")
|
||||
async def _get_expressions_by_chat_id_cached(chat_id: str) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
@@ -278,7 +278,7 @@ class ExpressionLearner:
|
||||
async def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
|
||||
|
||||
优化: 使用CRUD批量处理所有更改,最后统一提交
|
||||
"""
|
||||
try:
|
||||
@@ -288,7 +288,7 @@ class ExpressionLearner:
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
|
||||
# 需要手动操作的情况下使用session
|
||||
async with get_db_session() as session:
|
||||
# 批量处理所有修改
|
||||
@@ -391,7 +391,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
crud = CRUDBase(Expression)
|
||||
CRUDBase(Expression)
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
async with get_db_session() as session:
|
||||
for new_expr in expr_list:
|
||||
@@ -437,10 +437,10 @@ class ExpressionLearner:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
await session.delete(expr)
|
||||
|
||||
|
||||
# 提交后清除相关缓存
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 清除该chat_id的表达方式缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
|
||||
@@ -9,10 +9,8 @@ from json_repair import repair_json
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import Expression
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -152,7 +150,7 @@ class ExpressionSelector:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
|
||||
# 使用CRUD查询(由于需要IN条件,使用session)
|
||||
async with get_db_session() as session:
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
@@ -224,7 +222,7 @@ class ExpressionSelector:
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
affected_chat_ids.add(source_id)
|
||||
|
||||
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
@@ -247,7 +245,7 @@ class ExpressionSelector:
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 清除所有受影响的chat_id的缓存
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
|
||||
@@ -728,7 +728,6 @@ class MemorySystem:
|
||||
context = context or {}
|
||||
|
||||
# 所有记忆完全共享,统一使用 global 作用域,不区分用户
|
||||
resolved_user_id = GLOBAL_MEMORY_SCOPE
|
||||
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
start_time = time.time()
|
||||
|
||||
@@ -4,15 +4,14 @@ import time
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams # 新增导入
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
@@ -708,7 +707,7 @@ class ChatManager:
|
||||
# 使用CRUD批量查询
|
||||
crud = CRUDBase(ChatStreams)
|
||||
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
|
||||
|
||||
|
||||
for model_instance in all_streams:
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
|
||||
@@ -22,14 +22,14 @@ logger = get_logger("message_storage")
|
||||
class MessageStorageBatcher:
|
||||
"""
|
||||
消息存储批处理器
|
||||
|
||||
|
||||
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
|
||||
"""
|
||||
初始化批处理器
|
||||
|
||||
|
||||
Args:
|
||||
batch_size: 批量大小,达到此数量立即写入
|
||||
flush_interval: 自动刷新间隔(秒)
|
||||
@@ -51,7 +51,7 @@ class MessageStorageBatcher:
|
||||
async def stop(self):
|
||||
"""停止批处理器"""
|
||||
self._running = False
|
||||
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
@@ -67,7 +67,7 @@ class MessageStorageBatcher:
|
||||
async def add_message(self, message_data: dict):
|
||||
"""
|
||||
添加消息到批处理队列
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 包含消息对象和chat_stream的字典
|
||||
{
|
||||
@@ -97,23 +97,23 @@ class MessageStorageBatcher:
|
||||
|
||||
start_time = time.time()
|
||||
success_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
|
||||
messages_dicts = []
|
||||
|
||||
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_dict = await self._prepare_message_dict(
|
||||
msg_data['message'],
|
||||
msg_data['chat_stream']
|
||||
msg_data["message"],
|
||||
msg_data["chat_stream"]
|
||||
)
|
||||
if message_dict:
|
||||
messages_dicts.append(message_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息数据失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 批量写入数据库 - 使用高效的批量INSERT
|
||||
if messages_dicts:
|
||||
from sqlalchemy import insert
|
||||
@@ -122,7 +122,7 @@ class MessageStorageBatcher:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
success_count = len(messages_dicts)
|
||||
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
||||
@@ -134,18 +134,18 @@ class MessageStorageBatcher:
|
||||
|
||||
async def _prepare_message_dict(self, message, chat_stream):
|
||||
"""准备消息字典数据(用于批量INSERT)
|
||||
|
||||
|
||||
这个方法准备字典而不是ORM对象,性能更高
|
||||
"""
|
||||
message_obj = await self._prepare_message_object(message, chat_stream)
|
||||
if message_obj is None:
|
||||
return None
|
||||
|
||||
|
||||
# 将ORM对象转换为字典(只包含列字段)
|
||||
message_dict = {}
|
||||
for column in Messages.__table__.columns:
|
||||
message_dict[column.name] = getattr(message_obj, column.name)
|
||||
|
||||
|
||||
return message_dict
|
||||
|
||||
async def _prepare_message_object(self, message, chat_stream):
|
||||
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
is_public_notice = getattr(message, 'is_public_notice', False)
|
||||
notice_type = getattr(message, 'notice_type', None)
|
||||
actions = getattr(message, 'actions', None)
|
||||
should_reply = getattr(message, 'should_reply', None)
|
||||
should_act = getattr(message, 'should_act', None)
|
||||
additional_config = getattr(message, 'additional_config', None)
|
||||
is_public_notice = getattr(message, "is_public_notice", False)
|
||||
notice_type = getattr(message, "notice_type", None)
|
||||
actions = getattr(message, "actions", None)
|
||||
should_reply = getattr(message, "should_reply", None)
|
||||
should_act = getattr(message, "should_act", None)
|
||||
additional_config = getattr(message, "additional_config", None)
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
|
||||
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_storage_batcher: Optional[MessageStorageBatcher] = None
|
||||
_message_storage_batcher: MessageStorageBatcher | None = None
|
||||
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher:
|
||||
class MessageUpdateBatcher:
|
||||
"""
|
||||
消息更新批处理器
|
||||
|
||||
|
||||
优化: 将多个消息ID更新操作批量处理,减少数据库连接次数
|
||||
"""
|
||||
|
||||
@@ -478,7 +478,7 @@ class MessageStorage:
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
|
||||
"""
|
||||
存储消息到数据库
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
chat_stream: 聊天流对象
|
||||
@@ -488,11 +488,11 @@ class MessageStorage:
|
||||
if use_batch:
|
||||
batcher = get_message_storage_batcher()
|
||||
await batcher.add_message({
|
||||
'message': message,
|
||||
'chat_stream': chat_stream
|
||||
"message": message,
|
||||
"chat_stream": chat_stream
|
||||
})
|
||||
return
|
||||
|
||||
|
||||
# 直接写入模式(保留用于特殊场景)
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -675,9 +675,9 @@ class MessageStorage:
|
||||
async def update_message(message_data: dict, use_batch: bool = True):
|
||||
"""
|
||||
更新消息ID(从消息字典)
|
||||
|
||||
|
||||
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 消息数据字典
|
||||
use_batch: 是否使用批处理(默认True)
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.compatibility import db_get, db_query, db_save
|
||||
from src.common.database.compatibility import db_get, db_query
|
||||
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
@@ -12,8 +12,8 @@ from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.common.database.core.models import ImageDescriptions, Images
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import ImageDescriptions, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -25,8 +25,8 @@ from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from src.common.database.core.models import Videos
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import Videos
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -9,6 +9,37 @@
|
||||
"""
|
||||
|
||||
# ===== 核心层 =====
|
||||
# ===== API层 =====
|
||||
from src.common.database.api import (
|
||||
AggregateQuery,
|
||||
CRUDBase,
|
||||
QueryBuilder,
|
||||
# ChatStreams API
|
||||
get_active_streams,
|
||||
# Messages API
|
||||
get_chat_history,
|
||||
get_message_count,
|
||||
# PersonInfo API
|
||||
get_or_create_person,
|
||||
# ActionRecords API
|
||||
get_recent_actions,
|
||||
# LLMUsage API
|
||||
get_usage_statistics,
|
||||
record_llm_usage,
|
||||
# 业务API
|
||||
save_message,
|
||||
store_action_info,
|
||||
update_person_affinity,
|
||||
)
|
||||
|
||||
# ===== 兼容层(向后兼容旧API)=====
|
||||
from src.common.database.compatibility import (
|
||||
MODEL_MAPPING,
|
||||
build_filters,
|
||||
db_get,
|
||||
db_query,
|
||||
db_save,
|
||||
)
|
||||
from src.common.database.core import (
|
||||
Base,
|
||||
check_and_migrate_database,
|
||||
@@ -27,29 +58,6 @@ from src.common.database.optimization import (
|
||||
get_preloader,
|
||||
)
|
||||
|
||||
# ===== API层 =====
|
||||
from src.common.database.api import (
|
||||
AggregateQuery,
|
||||
CRUDBase,
|
||||
QueryBuilder,
|
||||
# ActionRecords API
|
||||
get_recent_actions,
|
||||
# ChatStreams API
|
||||
get_active_streams,
|
||||
# Messages API
|
||||
get_chat_history,
|
||||
get_message_count,
|
||||
# PersonInfo API
|
||||
get_or_create_person,
|
||||
# LLMUsage API
|
||||
get_usage_statistics,
|
||||
record_llm_usage,
|
||||
# 业务API
|
||||
save_message,
|
||||
store_action_info,
|
||||
update_person_affinity,
|
||||
)
|
||||
|
||||
# ===== Utils层 =====
|
||||
from src.common.database.utils import (
|
||||
cached,
|
||||
@@ -66,61 +74,52 @@ from src.common.database.utils import (
|
||||
transactional,
|
||||
)
|
||||
|
||||
# ===== 兼容层(向后兼容旧API)=====
|
||||
from src.common.database.compatibility import (
|
||||
MODEL_MAPPING,
|
||||
build_filters,
|
||||
db_get,
|
||||
db_query,
|
||||
db_save,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心层
|
||||
"Base",
|
||||
"get_engine",
|
||||
"get_session_factory",
|
||||
"get_db_session",
|
||||
"check_and_migrate_database",
|
||||
# 优化层
|
||||
"MultiLevelCache",
|
||||
"DataPreloader",
|
||||
"AdaptiveBatchScheduler",
|
||||
"get_cache",
|
||||
"get_preloader",
|
||||
"get_batch_scheduler",
|
||||
# API层 - 基础类
|
||||
"CRUDBase",
|
||||
"QueryBuilder",
|
||||
"AggregateQuery",
|
||||
# API层 - 业务API
|
||||
"store_action_info",
|
||||
"get_recent_actions",
|
||||
"get_chat_history",
|
||||
"get_message_count",
|
||||
"save_message",
|
||||
"get_or_create_person",
|
||||
"update_person_affinity",
|
||||
"get_active_streams",
|
||||
"record_llm_usage",
|
||||
"get_usage_statistics",
|
||||
# Utils层
|
||||
"retry",
|
||||
"timeout",
|
||||
"cached",
|
||||
"measure_time",
|
||||
"transactional",
|
||||
"db_operation",
|
||||
"get_monitor",
|
||||
"record_operation",
|
||||
"record_cache_hit",
|
||||
"record_cache_miss",
|
||||
"print_stats",
|
||||
"reset_stats",
|
||||
# 兼容层
|
||||
"MODEL_MAPPING",
|
||||
"AdaptiveBatchScheduler",
|
||||
"AggregateQuery",
|
||||
# 核心层
|
||||
"Base",
|
||||
# API层 - 基础类
|
||||
"CRUDBase",
|
||||
"DataPreloader",
|
||||
# 优化层
|
||||
"MultiLevelCache",
|
||||
"QueryBuilder",
|
||||
"build_filters",
|
||||
"cached",
|
||||
"check_and_migrate_database",
|
||||
"db_get",
|
||||
"db_operation",
|
||||
"db_query",
|
||||
"db_save",
|
||||
"db_get",
|
||||
"get_active_streams",
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_chat_history",
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
"get_message_count",
|
||||
"get_monitor",
|
||||
"get_or_create_person",
|
||||
"get_preloader",
|
||||
"get_recent_actions",
|
||||
"get_session_factory",
|
||||
"get_usage_statistics",
|
||||
"measure_time",
|
||||
"print_stats",
|
||||
"record_cache_hit",
|
||||
"record_cache_miss",
|
||||
"record_llm_usage",
|
||||
"record_operation",
|
||||
"reset_stats",
|
||||
# Utils层
|
||||
"retry",
|
||||
"save_message",
|
||||
# API层 - 业务API
|
||||
"store_action_info",
|
||||
"timeout",
|
||||
"transactional",
|
||||
"update_person_affinity",
|
||||
]
|
||||
|
||||
@@ -11,49 +11,49 @@ from src.common.database.api.query import AggregateQuery, QueryBuilder
|
||||
|
||||
# 业务特定API
|
||||
from src.common.database.api.specialized import (
|
||||
# ActionRecords
|
||||
get_recent_actions,
|
||||
store_action_info,
|
||||
# ChatStreams
|
||||
get_active_streams,
|
||||
get_or_create_chat_stream,
|
||||
# LLMUsage
|
||||
get_usage_statistics,
|
||||
record_llm_usage,
|
||||
# Messages
|
||||
get_chat_history,
|
||||
get_message_count,
|
||||
save_message,
|
||||
get_or_create_chat_stream,
|
||||
# PersonInfo
|
||||
get_or_create_person,
|
||||
update_person_affinity,
|
||||
# ActionRecords
|
||||
get_recent_actions,
|
||||
# LLMUsage
|
||||
get_usage_statistics,
|
||||
# UserRelationships
|
||||
get_user_relationship,
|
||||
record_llm_usage,
|
||||
save_message,
|
||||
store_action_info,
|
||||
update_person_affinity,
|
||||
update_relationship_affinity,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AggregateQuery",
|
||||
# 基础类
|
||||
"CRUDBase",
|
||||
"QueryBuilder",
|
||||
"AggregateQuery",
|
||||
# ActionRecords API
|
||||
"store_action_info",
|
||||
"get_recent_actions",
|
||||
"get_active_streams",
|
||||
# Messages API
|
||||
"get_chat_history",
|
||||
"get_message_count",
|
||||
"save_message",
|
||||
# PersonInfo API
|
||||
"get_or_create_person",
|
||||
"update_person_affinity",
|
||||
# ChatStreams API
|
||||
"get_or_create_chat_stream",
|
||||
"get_active_streams",
|
||||
# LLMUsage API
|
||||
"record_llm_usage",
|
||||
# PersonInfo API
|
||||
"get_or_create_person",
|
||||
"get_recent_actions",
|
||||
"get_usage_statistics",
|
||||
# UserRelationships API
|
||||
"get_user_relationship",
|
||||
# LLMUsage API
|
||||
"record_llm_usage",
|
||||
"save_message",
|
||||
# ActionRecords API
|
||||
"store_action_info",
|
||||
"update_person_affinity",
|
||||
"update_relationship_affinity",
|
||||
]
|
||||
|
||||
@@ -110,12 +110,12 @@ class CRUDBase:
|
||||
if instance is not None:
|
||||
# ✅ 在 session 内部转换为字典,此时所有字段都可安全访问
|
||||
instance_dict = _model_to_dict(instance)
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
# 从字典重建对象返回(detached状态,所有字段已加载)
|
||||
return _dict_to_model(self.model, instance_dict)
|
||||
|
||||
@@ -159,12 +159,12 @@ class CRUDBase:
|
||||
if instance is not None:
|
||||
# ✅ 在 session 内部转换为字典,此时所有字段都可安全访问
|
||||
instance_dict = _model_to_dict(instance)
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
# 从字典重建对象返回(detached状态,所有字段已加载)
|
||||
return _dict_to_model(self.model, instance_dict)
|
||||
|
||||
@@ -206,7 +206,7 @@ class CRUDBase:
|
||||
# 应用过滤条件
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
@@ -219,12 +219,12 @@ class CRUDBase:
|
||||
|
||||
# ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
|
||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
||||
|
||||
@@ -266,13 +266,13 @@ class CRUDBase:
|
||||
await session.refresh(instance)
|
||||
# 注意:commit在get_db_session的context manager退出时自动执行
|
||||
# 但为了明确性,这里不需要显式commit
|
||||
|
||||
|
||||
# 注意:create不清除缓存,因为:
|
||||
# 1. 新记录不会影响已有的单条查询缓存(get/get_by)
|
||||
# 2. get_multi的缓存会自然过期(TTL机制)
|
||||
# 3. 清除所有缓存代价太大,影响性能
|
||||
# 如果需要强一致性,应该在查询时设置use_cache=False
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
async def update(
|
||||
@@ -397,7 +397,7 @@ class CRUDBase:
|
||||
# 应用过滤条件
|
||||
for key, value in filters.items():
|
||||
if hasattr(self.model, key):
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(getattr(self.model, key).in_(value))
|
||||
else:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
@@ -466,14 +466,14 @@ class CRUDBase:
|
||||
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
|
||||
|
||||
# 批量创建的缓存策略:
|
||||
# bulk_create通常用于批量导入场景,此时清除缓存是合理的
|
||||
# 因为可能创建大量记录,缓存的列表查询会明显过期
|
||||
cache = await get_cache()
|
||||
await cache.clear()
|
||||
logger.info(f"批量创建{len(instances)}条{self.model_name}记录后已清除缓存")
|
||||
|
||||
|
||||
return instances
|
||||
|
||||
async def bulk_update(
|
||||
|
||||
@@ -207,12 +207,12 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
# ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
|
||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
||||
|
||||
@@ -241,12 +241,12 @@ class QueryBuilder(Generic[T]):
|
||||
if instance is not None:
|
||||
# ✅ 在 session 内部转换为字典,此时所有字段都可安全访问
|
||||
instance_dict = _model_to_dict(instance)
|
||||
|
||||
|
||||
# 写入缓存
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
|
||||
|
||||
# 从字典重建对象返回(detached状态,所有字段已加载)
|
||||
return _dict_to_model(self.model, instance_dict)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
@@ -42,11 +42,11 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: dict | None = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
@@ -55,7 +55,7 @@ async def store_action_info(
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
保存的记录数据或None
|
||||
"""
|
||||
@@ -71,7 +71,7 @@ async def store_action_info(
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update(
|
||||
@@ -89,20 +89,20 @@ async def store_action_info(
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 使用get_or_create保存记录
|
||||
saved_record, created = await _action_records_crud.get_or_create(
|
||||
defaults=record_data,
|
||||
action_id=action_id,
|
||||
)
|
||||
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"成功存储动作信息: {action_name} (ID: {action_id})")
|
||||
return {col.name: getattr(saved_record, col.name) for col in saved_record.__table__.columns}
|
||||
else:
|
||||
logger.error(f"存储动作信息失败: {action_name}")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储动作信息时发生错误: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -113,11 +113,11 @@ async def get_recent_actions(
|
||||
limit: int = 10,
|
||||
) -> list[ActionRecords]:
|
||||
"""获取最近的动作记录
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
limit: 限制数量
|
||||
|
||||
|
||||
Returns:
|
||||
动作记录列表
|
||||
"""
|
||||
@@ -132,12 +132,12 @@ async def get_chat_history(
|
||||
offset: int = 0,
|
||||
) -> list[Messages]:
|
||||
"""获取聊天历史
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
@@ -153,10 +153,10 @@ async def get_chat_history(
|
||||
|
||||
async def get_message_count(stream_id: str) -> int:
|
||||
"""获取消息数量
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
|
||||
|
||||
Returns:
|
||||
消息数量
|
||||
"""
|
||||
@@ -167,13 +167,13 @@ async def get_message_count(stream_id: str) -> int:
|
||||
async def save_message(
|
||||
message_data: dict[str, Any],
|
||||
use_batch: bool = True,
|
||||
) -> Optional[Messages]:
|
||||
) -> Messages | None:
|
||||
"""保存消息
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 消息数据
|
||||
use_batch: 是否使用批处理
|
||||
|
||||
|
||||
Returns:
|
||||
保存的消息实例
|
||||
"""
|
||||
@@ -185,15 +185,15 @@ async def save_message(
|
||||
async def get_or_create_person(
|
||||
platform: str,
|
||||
person_id: str,
|
||||
defaults: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[Optional[PersonInfo], bool]:
|
||||
defaults: dict[str, Any] | None = None,
|
||||
) -> tuple[PersonInfo | None, bool]:
|
||||
"""获取或创建人员信息
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
person_id: 人员ID
|
||||
defaults: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
(人员信息实例, 是否新创建)
|
||||
"""
|
||||
@@ -210,12 +210,12 @@ async def update_person_affinity(
|
||||
affinity_delta: float,
|
||||
) -> bool:
|
||||
"""更新人员好感度
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
person_id: 人员ID
|
||||
affinity_delta: 好感度变化值
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
@@ -225,26 +225,26 @@ async def update_person_affinity(
|
||||
platform=platform,
|
||||
person_id=person_id,
|
||||
)
|
||||
|
||||
|
||||
if not person:
|
||||
logger.warning(f"人员不存在: {platform}/{person_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 更新好感度
|
||||
new_affinity = (person.affinity or 0.0) + affinity_delta
|
||||
await _person_info_crud.update(
|
||||
person.id,
|
||||
{"affinity": new_affinity},
|
||||
)
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
cache = await get_cache()
|
||||
cache_key = generate_cache_key("person_info", platform, person_id)
|
||||
await cache.delete(cache_key)
|
||||
|
||||
|
||||
logger.debug(f"更新好感度: {platform}/{person_id} {affinity_delta:+.2f} -> {new_affinity:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新好感度失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -255,15 +255,15 @@ async def update_person_affinity(
|
||||
async def get_or_create_chat_stream(
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
defaults: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[Optional[ChatStreams], bool]:
|
||||
defaults: dict[str, Any] | None = None,
|
||||
) -> tuple[ChatStreams | None, bool]:
|
||||
"""获取或创建聊天流
|
||||
|
||||
|
||||
Args:
|
||||
stream_id: 流ID
|
||||
platform: 平台
|
||||
defaults: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
(聊天流实例, 是否新创建)
|
||||
"""
|
||||
@@ -275,23 +275,23 @@ async def get_or_create_chat_stream(
|
||||
|
||||
|
||||
async def get_active_streams(
|
||||
platform: Optional[str] = None,
|
||||
platform: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[ChatStreams]:
|
||||
"""获取活跃的聊天流
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台(可选)
|
||||
limit: 限制数量
|
||||
|
||||
|
||||
Returns:
|
||||
聊天流列表
|
||||
"""
|
||||
query = QueryBuilder(ChatStreams)
|
||||
|
||||
|
||||
if platform:
|
||||
query = query.filter(platform=platform)
|
||||
|
||||
|
||||
return await query.order_by("-last_message_time").limit(limit).all()
|
||||
|
||||
|
||||
@@ -300,20 +300,20 @@ async def record_llm_usage(
|
||||
model_name: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
stream_id: Optional[str] = None,
|
||||
platform: Optional[str] = None,
|
||||
stream_id: str | None = None,
|
||||
platform: str | None = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat",
|
||||
model_assign_name: Optional[str] = None,
|
||||
model_api_provider: Optional[str] = None,
|
||||
model_assign_name: str | None = None,
|
||||
model_api_provider: str | None = None,
|
||||
endpoint: str = "/v1/chat/completions",
|
||||
cost: float = 0.0,
|
||||
status: str = "success",
|
||||
time_cost: Optional[float] = None,
|
||||
time_cost: float | None = None,
|
||||
use_batch: bool = True,
|
||||
) -> Optional[LLMUsage]:
|
||||
) -> LLMUsage | None:
|
||||
"""记录LLM使用情况
|
||||
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
input_tokens: 输入token数
|
||||
@@ -329,7 +329,7 @@ async def record_llm_usage(
|
||||
status: 状态
|
||||
time_cost: 时间成本
|
||||
use_batch: 是否使用批处理
|
||||
|
||||
|
||||
Returns:
|
||||
LLM使用记录实例
|
||||
"""
|
||||
@@ -346,37 +346,36 @@ async def record_llm_usage(
|
||||
"model_assign_name": model_assign_name or model_name,
|
||||
"model_api_provider": model_api_provider or "unknown",
|
||||
}
|
||||
|
||||
|
||||
if time_cost is not None:
|
||||
usage_data["time_cost"] = time_cost
|
||||
|
||||
|
||||
return await _llm_usage_crud.create(usage_data, use_batch=use_batch)
|
||||
|
||||
|
||||
async def get_usage_statistics(
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
model_name: Optional[str] = None,
|
||||
start_time: float | None = None,
|
||||
end_time: float | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""获取使用统计
|
||||
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
model_name: 模型名称
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据字典
|
||||
"""
|
||||
from src.common.database.api.query import AggregateQuery
|
||||
|
||||
|
||||
query = AggregateQuery(LLMUsage)
|
||||
|
||||
|
||||
# 添加时间过滤
|
||||
if start_time:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import and_
|
||||
|
||||
async with get_db_session():
|
||||
|
||||
conditions = []
|
||||
if start_time:
|
||||
conditions.append(LLMUsage.timestamp >= start_time)
|
||||
@@ -384,15 +383,15 @@ async def get_usage_statistics(
|
||||
conditions.append(LLMUsage.timestamp <= end_time)
|
||||
if model_name:
|
||||
conditions.append(LLMUsage.model_name == model_name)
|
||||
|
||||
|
||||
if conditions:
|
||||
query._conditions = conditions
|
||||
|
||||
|
||||
# 聚合统计
|
||||
total_input = await query.sum("input_tokens")
|
||||
total_output = await query.sum("output_tokens")
|
||||
total_count = await query.filter().count() if hasattr(query, "count") else 0
|
||||
|
||||
|
||||
return {
|
||||
"total_input_tokens": int(total_input),
|
||||
"total_output_tokens": int(total_output),
|
||||
@@ -407,14 +406,14 @@ async def get_user_relationship(
|
||||
platform: str,
|
||||
user_id: str,
|
||||
target_id: str,
|
||||
) -> Optional[UserRelationships]:
|
||||
) -> UserRelationships | None:
|
||||
"""获取用户关系
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
target_id: 目标用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
用户关系实例
|
||||
"""
|
||||
@@ -432,13 +431,13 @@ async def update_relationship_affinity(
|
||||
affinity_delta: float,
|
||||
) -> bool:
|
||||
"""更新关系好感度
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台
|
||||
user_id: 用户ID
|
||||
target_id: 目标用户ID
|
||||
affinity_delta: 好感度变化值
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
@@ -450,15 +449,15 @@ async def update_relationship_affinity(
|
||||
user_id=user_id,
|
||||
target_id=target_id,
|
||||
)
|
||||
|
||||
|
||||
if not relationship:
|
||||
logger.error(f"无法创建关系: {platform}/{user_id}->{target_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 更新好感度和互动次数
|
||||
new_affinity = (relationship.affinity or 0.0) + affinity_delta
|
||||
new_count = (relationship.interaction_count or 0) + 1
|
||||
|
||||
|
||||
await _user_relationships_crud.update(
|
||||
relationship.id,
|
||||
{
|
||||
@@ -467,19 +466,19 @@ async def update_relationship_affinity(
|
||||
"last_interaction_time": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
cache = await get_cache()
|
||||
cache_key = generate_cache_key("user_relationship", platform, user_id, target_id)
|
||||
await cache.delete(cache_key)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"更新关系: {platform}/{user_id}->{target_id} "
|
||||
f"好感度{affinity_delta:+.2f}->{new_affinity:.2f} "
|
||||
f"互动{new_count}次"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新关系好感度失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -14,14 +14,14 @@ from .adapter import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 从 core 重新导出的函数
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
# 兼容层适配器
|
||||
"MODEL_MAPPING",
|
||||
"build_filters",
|
||||
"db_get",
|
||||
"db_query",
|
||||
"db_save",
|
||||
"db_get",
|
||||
# 从 core 重新导出的函数
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
"store_action_info",
|
||||
]
|
||||
|
||||
@@ -4,15 +4,13 @@
|
||||
保持原有函数签名和行为不变
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import and_, asc, desc, select
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.api import (
|
||||
CRUDBase,
|
||||
QueryBuilder,
|
||||
)
|
||||
from src.common.database.api import (
|
||||
store_action_info as new_store_action_info,
|
||||
)
|
||||
from src.common.database.core.models import (
|
||||
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
|
||||
Messages,
|
||||
MonthlyPlan,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
PermissionNodes,
|
||||
PersonInfo,
|
||||
Schedule,
|
||||
ThinkingLog,
|
||||
UserPermissions,
|
||||
UserRelationships,
|
||||
Videos,
|
||||
)
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.compatibility")
|
||||
@@ -82,11 +79,11 @@ _crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items(
|
||||
|
||||
async def build_filters(model_class, filters: dict[str, Any]):
|
||||
"""构建查询过滤条件(兼容MongoDB风格操作符)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件字典
|
||||
|
||||
|
||||
Returns:
|
||||
条件列表
|
||||
"""
|
||||
@@ -127,16 +124,16 @@ async def build_filters(model_class, filters: dict[str, Any]):
|
||||
|
||||
def _model_to_dict(instance) -> dict[str, Any]:
|
||||
"""将模型实例转换为字典
|
||||
|
||||
|
||||
Args:
|
||||
instance: 模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
字典表示
|
||||
"""
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
|
||||
result = {}
|
||||
for column in instance.__table__.columns:
|
||||
result[column.name] = getattr(instance, column.name)
|
||||
@@ -145,15 +142,15 @@ def _model_to_dict(instance) -> dict[str, Any]:
|
||||
|
||||
async def db_query(
|
||||
model_class,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[list[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
data: dict[str, Any] | None = None,
|
||||
query_type: str | None = "get",
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[str] | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""执行异步数据库查询操作(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 用于创建或更新的数据字典
|
||||
@@ -162,7 +159,7 @@ async def db_query(
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
|
||||
Returns:
|
||||
根据查询类型返回相应结果
|
||||
"""
|
||||
@@ -179,7 +176,7 @@ async def db_query(
|
||||
if query_type == "get":
|
||||
# 使用QueryBuilder
|
||||
query_builder = QueryBuilder(model_class)
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
# 将MongoDB风格过滤器转换为QueryBuilder格式
|
||||
@@ -202,15 +199,15 @@ async def db_query(
|
||||
query_builder = query_builder.filter(**{f"{field_name}__nin": op_value})
|
||||
else:
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
query_builder = query_builder.order_by(*order_by)
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query_builder = query_builder.limit(limit)
|
||||
|
||||
|
||||
# 执行查询
|
||||
if single_result:
|
||||
result = await query_builder.first()
|
||||
@@ -223,7 +220,7 @@ async def db_query(
|
||||
if not data:
|
||||
logger.error("创建操作需要提供data参数")
|
||||
return None
|
||||
|
||||
|
||||
instance = await crud.create(data)
|
||||
return _model_to_dict(instance)
|
||||
|
||||
@@ -231,17 +228,17 @@ async def db_query(
|
||||
if not filters or not data:
|
||||
logger.error("更新操作需要提供filters和data参数")
|
||||
return None
|
||||
|
||||
|
||||
# 先查找记录
|
||||
query_builder = QueryBuilder(model_class)
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
instance = await query_builder.first()
|
||||
if not instance:
|
||||
logger.warning(f"未找到匹配的记录: {filters}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新记录
|
||||
updated = await crud.update(instance.id, data)
|
||||
return _model_to_dict(updated)
|
||||
@@ -250,29 +247,29 @@ async def db_query(
|
||||
if not filters:
|
||||
logger.error("删除操作需要提供filters参数")
|
||||
return None
|
||||
|
||||
|
||||
# 先查找记录
|
||||
query_builder = QueryBuilder(model_class)
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
instance = await query_builder.first()
|
||||
if not instance:
|
||||
logger.warning(f"未找到匹配的记录: {filters}")
|
||||
return None
|
||||
|
||||
|
||||
# 删除记录
|
||||
success = await crud.delete(instance.id)
|
||||
return {"deleted": success}
|
||||
|
||||
elif query_type == "count":
|
||||
query_builder = QueryBuilder(model_class)
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
count = await query_builder.count()
|
||||
return {"count": count}
|
||||
|
||||
@@ -286,15 +283,15 @@ async def db_save(
|
||||
data: dict[str, Any],
|
||||
key_field: str,
|
||||
key_value: Any,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""保存或更新记录(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 数据字典
|
||||
key_field: 主键字段名
|
||||
key_value: 主键值
|
||||
|
||||
|
||||
Returns:
|
||||
保存的记录数据或None
|
||||
"""
|
||||
@@ -303,15 +300,15 @@ async def db_save(
|
||||
crud = _crud_instances.get(model_name)
|
||||
if not crud:
|
||||
crud = CRUDBase(model_class)
|
||||
|
||||
|
||||
# 使用get_or_create (返回tuple[T, bool])
|
||||
instance, created = await crud.get_or_create(
|
||||
defaults=data,
|
||||
**{key_field: key_value},
|
||||
)
|
||||
|
||||
|
||||
return _model_to_dict(instance)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存数据库记录出错: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -319,20 +316,20 @@ async def db_save(
|
||||
|
||||
async def db_get(
|
||||
model_class,
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: str | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""从数据库获取记录(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
|
||||
Returns:
|
||||
记录数据或None
|
||||
"""
|
||||
@@ -353,11 +350,11 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: dict | None = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""存储动作信息到数据库(兼容旧API)
|
||||
|
||||
|
||||
直接使用新的specialized API
|
||||
"""
|
||||
return await new_store_action_info(
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -16,50 +16,50 @@ logger = get_logger("database_config")
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""数据库配置"""
|
||||
|
||||
|
||||
# 基础配置
|
||||
db_type: str # "sqlite" 或 "mysql"
|
||||
url: str # 数据库连接URL
|
||||
|
||||
|
||||
# 引擎配置
|
||||
engine_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
# SQLite特定配置
|
||||
sqlite_path: Optional[str] = None
|
||||
|
||||
sqlite_path: str | None = None
|
||||
|
||||
# MySQL特定配置
|
||||
mysql_host: Optional[str] = None
|
||||
mysql_port: Optional[int] = None
|
||||
mysql_user: Optional[str] = None
|
||||
mysql_password: Optional[str] = None
|
||||
mysql_database: Optional[str] = None
|
||||
mysql_host: str | None = None
|
||||
mysql_port: int | None = None
|
||||
mysql_user: str | None = None
|
||||
mysql_password: str | None = None
|
||||
mysql_database: str | None = None
|
||||
mysql_charset: str = "utf8mb4"
|
||||
mysql_unix_socket: Optional[str] = None
|
||||
mysql_unix_socket: str | None = None
|
||||
|
||||
|
||||
_database_config: Optional[DatabaseConfig] = None
|
||||
_database_config: DatabaseConfig | None = None
|
||||
|
||||
|
||||
def get_database_config() -> DatabaseConfig:
|
||||
"""获取数据库配置
|
||||
|
||||
|
||||
从全局配置中读取数据库设置并构建配置对象
|
||||
"""
|
||||
global _database_config
|
||||
|
||||
|
||||
if _database_config is not None:
|
||||
return _database_config
|
||||
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
config = global_config.database
|
||||
|
||||
|
||||
# 构建数据库URL
|
||||
if config.database_type == "mysql":
|
||||
# MySQL配置
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
|
||||
if config.mysql_unix_socket:
|
||||
# Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
@@ -75,7 +75,7 @@ def get_database_config() -> DatabaseConfig:
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -90,7 +90,7 @@ def get_database_config() -> DatabaseConfig:
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
_database_config = DatabaseConfig(
|
||||
db_type="mysql",
|
||||
url=url,
|
||||
@@ -103,12 +103,12 @@ def get_database_config() -> DatabaseConfig:
|
||||
mysql_charset=config.mysql_charset,
|
||||
mysql_unix_socket=config.mysql_unix_socket,
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"MySQL配置已加载: "
|
||||
f"{config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# SQLite配置
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
@@ -116,12 +116,12 @@ def get_database_config() -> DatabaseConfig:
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
|
||||
url = f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -130,16 +130,16 @@ def get_database_config() -> DatabaseConfig:
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
_database_config = DatabaseConfig(
|
||||
db_type="sqlite",
|
||||
url=url,
|
||||
engine_kwargs=engine_kwargs,
|
||||
sqlite_path=db_path,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"SQLite配置已加载: {db_path}")
|
||||
|
||||
|
||||
return _database_config
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from .models import (
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Expression,
|
||||
get_string_field,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
@@ -37,30 +36,17 @@ from .models import (
|
||||
UserPermissions,
|
||||
UserRelationships,
|
||||
Videos,
|
||||
get_string_field,
|
||||
)
|
||||
from .session import get_db_session, get_db_session_direct, get_session_factory, reset_session_factory
|
||||
|
||||
__all__ = [
|
||||
# Engine
|
||||
"get_engine",
|
||||
"close_engine",
|
||||
"get_engine_info",
|
||||
# Session
|
||||
"get_db_session",
|
||||
"get_db_session_direct",
|
||||
"get_session_factory",
|
||||
"reset_session_factory",
|
||||
# Migration
|
||||
"check_and_migrate_database",
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
# Models - Base
|
||||
"Base",
|
||||
"get_string_field",
|
||||
# Models - Tables (按字母顺序)
|
||||
"ActionRecords",
|
||||
"AntiInjectionStats",
|
||||
"BanUser",
|
||||
# Models - Base
|
||||
"Base",
|
||||
"BotPersonalityInterests",
|
||||
"CacheEntries",
|
||||
"ChatStreams",
|
||||
@@ -83,4 +69,18 @@ __all__ = [
|
||||
"UserPermissions",
|
||||
"UserRelationships",
|
||||
"Videos",
|
||||
# Migration
|
||||
"check_and_migrate_database",
|
||||
"close_engine",
|
||||
"create_all_tables",
|
||||
"drop_all_tables",
|
||||
# Session
|
||||
"get_db_session",
|
||||
"get_db_session_direct",
|
||||
# Engine
|
||||
"get_engine",
|
||||
"get_engine_info",
|
||||
"get_session_factory",
|
||||
"get_string_field",
|
||||
"reset_session_factory",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy import text
|
||||
@@ -18,49 +17,49 @@ from ..utils.exceptions import DatabaseInitializationError
|
||||
logger = get_logger("database.engine")
|
||||
|
||||
# 全局引擎实例
|
||||
_engine: Optional[AsyncEngine] = None
|
||||
_engine_lock: Optional[asyncio.Lock] = None
|
||||
_engine: AsyncEngine | None = None
|
||||
_engine_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
async def get_engine() -> AsyncEngine:
|
||||
"""获取全局数据库引擎(单例模式)
|
||||
|
||||
|
||||
Returns:
|
||||
AsyncEngine: SQLAlchemy异步引擎
|
||||
|
||||
|
||||
Raises:
|
||||
DatabaseInitializationError: 引擎初始化失败
|
||||
"""
|
||||
global _engine, _engine_lock
|
||||
|
||||
|
||||
# 快速路径:引擎已初始化
|
||||
if _engine is not None:
|
||||
return _engine
|
||||
|
||||
|
||||
# 延迟创建锁(避免在导入时创建)
|
||||
if _engine_lock is None:
|
||||
_engine_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 使用锁保护初始化过程
|
||||
async with _engine_lock:
|
||||
# 双重检查锁定模式
|
||||
if _engine is not None:
|
||||
return _engine
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
config = global_config.database
|
||||
db_type = config.database_type
|
||||
|
||||
|
||||
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
|
||||
|
||||
|
||||
# 构建数据库URL和引擎参数
|
||||
if db_type == "mysql":
|
||||
# MySQL配置
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
|
||||
if config.mysql_unix_socket:
|
||||
# Unix socket连接
|
||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||
@@ -76,7 +75,7 @@ async def get_engine() -> AsyncEngine:
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -91,11 +90,11 @@ async def get_engine() -> AsyncEngine:
|
||||
"connect_timeout": config.connection_timeout,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
logger.info(
|
||||
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# SQLite配置
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
@@ -103,12 +102,12 @@ async def get_engine() -> AsyncEngine:
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
|
||||
url = f"sqlite+aiosqlite:///{db_path}"
|
||||
|
||||
|
||||
engine_kwargs = {
|
||||
"echo": False,
|
||||
"future": True,
|
||||
@@ -117,19 +116,19 @@ async def get_engine() -> AsyncEngine:
|
||||
"timeout": 60,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"SQLite配置: {db_path}")
|
||||
|
||||
|
||||
# 创建异步引擎
|
||||
_engine = create_async_engine(url, **engine_kwargs)
|
||||
|
||||
|
||||
# SQLite特定优化
|
||||
if db_type == "sqlite":
|
||||
await _enable_sqlite_optimizations(_engine)
|
||||
|
||||
|
||||
logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
|
||||
return _engine
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 数据库引擎初始化失败: {e}", exc_info=True)
|
||||
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
|
||||
@@ -137,11 +136,11 @@ async def get_engine() -> AsyncEngine:
|
||||
|
||||
async def close_engine():
|
||||
"""关闭数据库引擎
|
||||
|
||||
|
||||
释放所有连接池资源
|
||||
"""
|
||||
global _engine
|
||||
|
||||
|
||||
if _engine is not None:
|
||||
logger.info("正在关闭数据库引擎...")
|
||||
await _engine.dispose()
|
||||
@@ -151,13 +150,13 @@ async def close_engine():
|
||||
|
||||
async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
"""启用SQLite性能优化
|
||||
|
||||
|
||||
优化项:
|
||||
- WAL模式:提高并发性能
|
||||
- NORMAL同步:平衡性能和安全性
|
||||
- 启用外键约束
|
||||
- 设置busy_timeout:避免锁定错误
|
||||
|
||||
|
||||
Args:
|
||||
engine: SQLAlchemy异步引擎
|
||||
"""
|
||||
@@ -175,22 +174,22 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
|
||||
await conn.execute(text("PRAGMA cache_size = -10000"))
|
||||
# 临时存储使用内存
|
||||
await conn.execute(text("PRAGMA temp_store = MEMORY"))
|
||||
|
||||
|
||||
logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
|
||||
|
||||
|
||||
async def get_engine_info() -> dict:
|
||||
"""获取引擎信息(用于监控和调试)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 引擎信息字典
|
||||
"""
|
||||
try:
|
||||
engine = await get_engine()
|
||||
|
||||
|
||||
info = {
|
||||
"name": engine.name,
|
||||
"driver": engine.driver,
|
||||
@@ -199,9 +198,9 @@ async def get_engine_info() -> dict:
|
||||
"pool_checked_out": getattr(engine.pool, "checked_out", lambda: 0)(),
|
||||
"pool_overflow": getattr(engine.pool, "overflow", lambda: 0)(),
|
||||
}
|
||||
|
||||
|
||||
return info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取引擎信息失败: {e}")
|
||||
return {}
|
||||
|
||||
@@ -20,15 +20,15 @@ logger = get_logger("db_migration")
|
||||
|
||||
async def check_and_migrate_database(existing_engine=None):
|
||||
"""异步检查数据库结构并自动迁移
|
||||
|
||||
|
||||
自动执行以下操作:
|
||||
- 创建不存在的表
|
||||
- 为现有表添加缺失的列
|
||||
- 为现有表创建缺失的索引
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎
|
||||
|
||||
|
||||
Note:
|
||||
此函数是幂等的,可以安全地多次调用
|
||||
"""
|
||||
@@ -65,7 +65,7 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
for table in tables_to_create:
|
||||
logger.info(f"表 '{table.name}' 创建成功。")
|
||||
db_table_names.add(table.name) # 将新创建的表添加到集合中
|
||||
|
||||
|
||||
# 提交表创建事务
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
@@ -191,40 +191,40 @@ async def check_and_migrate_database(existing_engine=None):
|
||||
|
||||
async def create_all_tables(existing_engine=None):
|
||||
"""创建所有表(不进行迁移检查)
|
||||
|
||||
|
||||
直接创建所有在 Base.metadata 中定义的表。
|
||||
如果表已存在,将被跳过。
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎
|
||||
|
||||
|
||||
Note:
|
||||
生产环境建议使用 check_and_migrate_database()
|
||||
"""
|
||||
logger.info("正在创建所有数据库表...")
|
||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||
|
||||
|
||||
async with engine.begin() as connection:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
logger.info("数据库表创建完成。")
|
||||
|
||||
|
||||
async def drop_all_tables(existing_engine=None):
|
||||
"""删除所有表(危险操作!)
|
||||
|
||||
|
||||
删除所有在 Base.metadata 中定义的表。
|
||||
|
||||
|
||||
Args:
|
||||
existing_engine: 可选的已存在的数据库引擎
|
||||
|
||||
|
||||
Warning:
|
||||
此操作将删除所有数据,不可恢复!仅用于测试环境!
|
||||
"""
|
||||
logger.warning("⚠️ 正在删除所有数据库表...")
|
||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||
|
||||
|
||||
async with engine.begin() as connection:
|
||||
await connection.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
logger.warning("所有数据库表已删除。")
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -18,38 +17,38 @@ from .engine import get_engine
|
||||
logger = get_logger("database.session")
|
||||
|
||||
# 全局会话工厂
|
||||
_session_factory: Optional[async_sessionmaker] = None
|
||||
_factory_lock: Optional[asyncio.Lock] = None
|
||||
_session_factory: async_sessionmaker | None = None
|
||||
_factory_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
async def get_session_factory() -> async_sessionmaker:
|
||||
"""获取会话工厂(单例模式)
|
||||
|
||||
|
||||
Returns:
|
||||
async_sessionmaker: SQLAlchemy异步会话工厂
|
||||
"""
|
||||
global _session_factory, _factory_lock
|
||||
|
||||
|
||||
# 快速路径
|
||||
if _session_factory is not None:
|
||||
return _session_factory
|
||||
|
||||
|
||||
# 延迟创建锁
|
||||
if _factory_lock is None:
|
||||
_factory_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async with _factory_lock:
|
||||
# 双重检查
|
||||
if _session_factory is not None:
|
||||
return _session_factory
|
||||
|
||||
|
||||
engine = await get_engine()
|
||||
_session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False, # 避免在commit后访问属性时重新查询
|
||||
)
|
||||
|
||||
|
||||
logger.debug("会话工厂已创建")
|
||||
return _session_factory
|
||||
|
||||
@@ -57,28 +56,28 @@ async def get_session_factory() -> async_sessionmaker:
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话上下文管理器
|
||||
|
||||
|
||||
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
|
||||
|
||||
|
||||
使用示例:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(User))
|
||||
users = result.scalars().all()
|
||||
|
||||
|
||||
Yields:
|
||||
AsyncSession: SQLAlchemy异步会话对象
|
||||
"""
|
||||
# 延迟导入避免循环依赖
|
||||
from ..optimization.connection_pool import get_connection_pool_manager
|
||||
|
||||
|
||||
session_factory = await get_session_factory()
|
||||
pool_manager = get_connection_pool_manager()
|
||||
|
||||
|
||||
# 使用连接池管理器(透明复用连接)
|
||||
async with pool_manager.get_session(session_factory) as session:
|
||||
# 为SQLite设置特定的PRAGMA
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
if global_config.database.database_type == "sqlite":
|
||||
try:
|
||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||
@@ -86,22 +85,22 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
except Exception:
|
||||
# 复用连接时PRAGMA可能已设置,忽略错误
|
||||
pass
|
||||
|
||||
|
||||
yield session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话(直接模式,不使用连接池)
|
||||
|
||||
|
||||
用于特殊场景,如需要完全独立的连接时。
|
||||
一般情况下应使用 get_db_session()。
|
||||
|
||||
|
||||
Yields:
|
||||
AsyncSession: SQLAlchemy异步会话对象
|
||||
"""
|
||||
session_factory = await get_session_factory()
|
||||
|
||||
|
||||
async with session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
|
||||
@@ -11,17 +11,17 @@ from .batch_scheduler import (
|
||||
AdaptiveBatchScheduler,
|
||||
BatchOperation,
|
||||
BatchStats,
|
||||
Priority,
|
||||
close_batch_scheduler,
|
||||
get_batch_scheduler,
|
||||
Priority,
|
||||
)
|
||||
from .cache_manager import (
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
close_cache,
|
||||
get_cache,
|
||||
LRUCache,
|
||||
MultiLevelCache,
|
||||
close_cache,
|
||||
get_cache,
|
||||
)
|
||||
from .connection_pool import (
|
||||
ConnectionPoolManager,
|
||||
@@ -31,36 +31,36 @@ from .connection_pool import (
|
||||
)
|
||||
from .preloader import (
|
||||
AccessPattern,
|
||||
close_preloader,
|
||||
CommonDataPreloader,
|
||||
DataPreloader,
|
||||
close_preloader,
|
||||
get_preloader,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Connection Pool
|
||||
"ConnectionPoolManager",
|
||||
"get_connection_pool_manager",
|
||||
"start_connection_pool",
|
||||
"stop_connection_pool",
|
||||
# Cache
|
||||
"MultiLevelCache",
|
||||
"LRUCache",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"get_cache",
|
||||
"close_cache",
|
||||
# Preloader
|
||||
"DataPreloader",
|
||||
"CommonDataPreloader",
|
||||
"AccessPattern",
|
||||
"get_preloader",
|
||||
"close_preloader",
|
||||
# Batch Scheduler
|
||||
"AdaptiveBatchScheduler",
|
||||
"BatchOperation",
|
||||
"BatchStats",
|
||||
"CacheEntry",
|
||||
"CacheStats",
|
||||
"CommonDataPreloader",
|
||||
# Connection Pool
|
||||
"ConnectionPoolManager",
|
||||
# Preloader
|
||||
"DataPreloader",
|
||||
"LRUCache",
|
||||
# Cache
|
||||
"MultiLevelCache",
|
||||
"Priority",
|
||||
"get_batch_scheduler",
|
||||
"close_batch_scheduler",
|
||||
"close_cache",
|
||||
"close_preloader",
|
||||
"get_batch_scheduler",
|
||||
"get_cache",
|
||||
"get_connection_pool_manager",
|
||||
"get_preloader",
|
||||
"start_connection_pool",
|
||||
"stop_connection_pool",
|
||||
]
|
||||
|
||||
@@ -10,12 +10,12 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import delete, insert, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
@@ -36,22 +36,22 @@ class Priority(IntEnum):
|
||||
@dataclass
|
||||
class BatchOperation:
|
||||
"""批量操作"""
|
||||
|
||||
|
||||
operation_type: str # 'select', 'insert', 'update', 'delete'
|
||||
model_class: type
|
||||
conditions: dict[str, Any] = field(default_factory=dict)
|
||||
data: Optional[dict[str, Any]] = None
|
||||
callback: Optional[Callable] = None
|
||||
future: Optional[asyncio.Future] = None
|
||||
data: dict[str, Any] | None = None
|
||||
callback: Callable | None = None
|
||||
future: asyncio.Future | None = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
priority: Priority = Priority.NORMAL
|
||||
timeout: Optional[float] = None # 超时时间(秒)
|
||||
timeout: float | None = None # 超时时间(秒)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchStats:
|
||||
"""批处理统计"""
|
||||
|
||||
|
||||
total_operations: int = 0
|
||||
batched_operations: int = 0
|
||||
cache_hits: int = 0
|
||||
@@ -60,7 +60,7 @@ class BatchStats:
|
||||
avg_wait_time: float = 0.0
|
||||
timeout_count: int = 0
|
||||
error_count: int = 0
|
||||
|
||||
|
||||
# 自适应统计
|
||||
last_batch_duration: float = 0.0
|
||||
last_batch_size: int = 0
|
||||
@@ -69,7 +69,7 @@ class BatchStats:
|
||||
|
||||
class AdaptiveBatchScheduler:
|
||||
"""自适应批量调度器
|
||||
|
||||
|
||||
特性:
|
||||
- 动态批次大小:根据负载自动调整
|
||||
- 优先级队列:高优先级操作优先执行
|
||||
@@ -87,7 +87,7 @@ class AdaptiveBatchScheduler:
|
||||
cache_ttl: float = 5.0,
|
||||
):
|
||||
"""初始化调度器
|
||||
|
||||
|
||||
Args:
|
||||
min_batch_size: 最小批次大小
|
||||
max_batch_size: 最大批次大小
|
||||
@@ -104,23 +104,23 @@ class AdaptiveBatchScheduler:
|
||||
self.current_wait_time = base_wait_time
|
||||
self.max_queue_size = max_queue_size
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
|
||||
# 操作队列,按优先级分类
|
||||
self.operation_queues: dict[Priority, deque[BatchOperation]] = {
|
||||
priority: deque() for priority in Priority
|
||||
}
|
||||
|
||||
|
||||
# 调度控制
|
||||
self._scheduler_task: Optional[asyncio.Task] = None
|
||||
self._scheduler_task: asyncio.Task | None = None
|
||||
self._is_running = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
# 统计信息
|
||||
self.stats = BatchStats()
|
||||
|
||||
|
||||
# 简单的结果缓存
|
||||
self._result_cache: dict[str, tuple[Any, float]] = {}
|
||||
|
||||
|
||||
logger.info(
|
||||
f"自适应批量调度器初始化: "
|
||||
f"批次大小{min_batch_size}-{max_batch_size}, "
|
||||
@@ -132,7 +132,7 @@ class AdaptiveBatchScheduler:
|
||||
if self._is_running:
|
||||
logger.warning("调度器已在运行")
|
||||
return
|
||||
|
||||
|
||||
self._is_running = True
|
||||
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
|
||||
logger.info("批量调度器已启动")
|
||||
@@ -141,16 +141,16 @@ class AdaptiveBatchScheduler:
|
||||
"""停止调度器"""
|
||||
if not self._is_running:
|
||||
return
|
||||
|
||||
|
||||
self._is_running = False
|
||||
|
||||
|
||||
if self._scheduler_task:
|
||||
self._scheduler_task.cancel()
|
||||
try:
|
||||
await self._scheduler_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# 处理剩余操作
|
||||
await self._flush_all_queues()
|
||||
logger.info("批量调度器已停止")
|
||||
@@ -160,10 +160,10 @@ class AdaptiveBatchScheduler:
|
||||
operation: BatchOperation,
|
||||
) -> asyncio.Future:
|
||||
"""添加操作到队列
|
||||
|
||||
|
||||
Args:
|
||||
operation: 批量操作
|
||||
|
||||
|
||||
Returns:
|
||||
Future对象,可用于获取结果
|
||||
"""
|
||||
@@ -175,11 +175,11 @@ class AdaptiveBatchScheduler:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future.set_result(cached_result)
|
||||
return future
|
||||
|
||||
|
||||
# 创建future
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
operation.future = future
|
||||
|
||||
|
||||
async with self._lock:
|
||||
# 检查队列是否已满
|
||||
total_queued = sum(len(q) for q in self.operation_queues.values())
|
||||
@@ -191,7 +191,7 @@ class AdaptiveBatchScheduler:
|
||||
# 添加到优先级队列
|
||||
self.operation_queues[operation.priority].append(operation)
|
||||
self.stats.total_operations += 1
|
||||
|
||||
|
||||
return future
|
||||
|
||||
async def _scheduler_loop(self) -> None:
|
||||
@@ -217,10 +217,10 @@ class AdaptiveBatchScheduler:
|
||||
for _ in range(count):
|
||||
if queue:
|
||||
operations.append(queue.popleft())
|
||||
|
||||
|
||||
if not operations:
|
||||
return
|
||||
|
||||
|
||||
# 执行批量操作
|
||||
await self._execute_operations(operations)
|
||||
|
||||
@@ -231,10 +231,10 @@ class AdaptiveBatchScheduler:
|
||||
"""执行批量操作"""
|
||||
if not operations:
|
||||
return
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
batch_size = len(operations)
|
||||
|
||||
|
||||
try:
|
||||
# 检查超时
|
||||
valid_operations = []
|
||||
@@ -246,41 +246,41 @@ class AdaptiveBatchScheduler:
|
||||
self.stats.timeout_count += 1
|
||||
else:
|
||||
valid_operations.append(op)
|
||||
|
||||
|
||||
if not valid_operations:
|
||||
return
|
||||
|
||||
|
||||
# 按操作类型分组
|
||||
op_groups = defaultdict(list)
|
||||
for op in valid_operations:
|
||||
key = f"{op.operation_type}_{op.model_class.__name__}"
|
||||
op_groups[key].append(op)
|
||||
|
||||
|
||||
# 执行各组操作
|
||||
for group_key, ops in op_groups.items():
|
||||
for ops in op_groups.values():
|
||||
await self._execute_group(ops)
|
||||
|
||||
|
||||
# 更新统计
|
||||
duration = time.time() - start_time
|
||||
self.stats.batched_operations += batch_size
|
||||
self.stats.total_execution_time += duration
|
||||
self.stats.last_batch_duration = duration
|
||||
self.stats.last_batch_size = batch_size
|
||||
|
||||
|
||||
if self.stats.batched_operations > 0:
|
||||
self.stats.avg_batch_size = (
|
||||
self.stats.batched_operations /
|
||||
self.stats.batched_operations /
|
||||
(self.stats.total_execution_time / duration)
|
||||
)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"批量执行完成: {batch_size}个操作, 耗时{duration*1000:.2f}ms"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量操作执行失败: {e}", exc_info=True)
|
||||
self.stats.error_count += 1
|
||||
|
||||
|
||||
# 设置所有future的异常
|
||||
for op in operations:
|
||||
if op.future and not op.future.done():
|
||||
@@ -290,9 +290,9 @@ class AdaptiveBatchScheduler:
|
||||
"""执行同类操作组"""
|
||||
if not operations:
|
||||
return
|
||||
|
||||
|
||||
op_type = operations[0].operation_type
|
||||
|
||||
|
||||
try:
|
||||
if op_type == "select":
|
||||
await self._execute_select_batch(operations)
|
||||
@@ -304,7 +304,7 @@ class AdaptiveBatchScheduler:
|
||||
await self._execute_delete_batch(operations)
|
||||
else:
|
||||
raise ValueError(f"未知操作类型: {op_type}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行{op_type}操作组失败: {e}", exc_info=True)
|
||||
for op in operations:
|
||||
@@ -323,30 +323,30 @@ class AdaptiveBatchScheduler:
|
||||
stmt = select(op.model_class)
|
||||
for key, value in op.conditions.items():
|
||||
attr = getattr(op.model_class, key)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(attr.in_(value))
|
||||
else:
|
||||
stmt = stmt.where(attr == value)
|
||||
|
||||
|
||||
# 执行查询
|
||||
result = await session.execute(stmt)
|
||||
data = result.scalars().all()
|
||||
|
||||
|
||||
# 设置结果
|
||||
if op.future and not op.future.done():
|
||||
op.future.set_result(data)
|
||||
|
||||
|
||||
# 缓存结果
|
||||
cache_key = self._generate_cache_key(op)
|
||||
self._set_cache(cache_key, data)
|
||||
|
||||
|
||||
# 执行回调
|
||||
if op.callback:
|
||||
try:
|
||||
op.callback(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"回调执行失败: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询失败: {e}", exc_info=True)
|
||||
if op.future and not op.future.done():
|
||||
@@ -363,23 +363,23 @@ class AdaptiveBatchScheduler:
|
||||
all_data = [op.data for op in operations if op.data]
|
||||
if not all_data:
|
||||
return
|
||||
|
||||
|
||||
# 批量插入
|
||||
stmt = insert(operations[0].model_class).values(all_data)
|
||||
result = await session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 设置结果
|
||||
for op in operations:
|
||||
if op.future and not op.future.done():
|
||||
op.future.set_result(True)
|
||||
|
||||
|
||||
if op.callback:
|
||||
try:
|
||||
op.callback(True)
|
||||
except Exception as e:
|
||||
logger.warning(f"回调执行失败: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量插入失败: {e}", exc_info=True)
|
||||
await session.rollback()
|
||||
@@ -402,28 +402,28 @@ class AdaptiveBatchScheduler:
|
||||
for key, value in op.conditions.items():
|
||||
attr = getattr(op.model_class, key)
|
||||
stmt = stmt.where(attr == value)
|
||||
|
||||
|
||||
if op.data:
|
||||
stmt = stmt.values(**op.data)
|
||||
|
||||
|
||||
# 执行更新(但不commit)
|
||||
result = await session.execute(stmt)
|
||||
results.append((op, result.rowcount))
|
||||
|
||||
|
||||
# 所有操作成功后,一次性commit
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 设置所有操作的结果
|
||||
for op, rowcount in results:
|
||||
if op.future and not op.future.done():
|
||||
op.future.set_result(rowcount)
|
||||
|
||||
|
||||
if op.callback:
|
||||
try:
|
||||
op.callback(rowcount)
|
||||
except Exception as e:
|
||||
logger.warning(f"回调执行失败: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新失败: {e}", exc_info=True)
|
||||
await session.rollback()
|
||||
@@ -447,25 +447,25 @@ class AdaptiveBatchScheduler:
|
||||
for key, value in op.conditions.items():
|
||||
attr = getattr(op.model_class, key)
|
||||
stmt = stmt.where(attr == value)
|
||||
|
||||
|
||||
# 执行删除(但不commit)
|
||||
result = await session.execute(stmt)
|
||||
results.append((op, result.rowcount))
|
||||
|
||||
|
||||
# 所有操作成功后,一次性commit
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 设置所有操作的结果
|
||||
for op, rowcount in results:
|
||||
if op.future and not op.future.done():
|
||||
op.future.set_result(rowcount)
|
||||
|
||||
|
||||
if op.callback:
|
||||
try:
|
||||
op.callback(rowcount)
|
||||
except Exception as e:
|
||||
logger.warning(f"回调执行失败: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除失败: {e}", exc_info=True)
|
||||
await session.rollback()
|
||||
@@ -479,7 +479,7 @@ class AdaptiveBatchScheduler:
|
||||
# 计算拥塞评分
|
||||
total_queued = sum(len(q) for q in self.operation_queues.values())
|
||||
self.stats.congestion_score = min(1.0, total_queued / self.max_queue_size)
|
||||
|
||||
|
||||
# 根据拥塞情况调整批次大小
|
||||
if self.stats.congestion_score > 0.7:
|
||||
# 高拥塞,增加批次大小
|
||||
@@ -493,7 +493,7 @@ class AdaptiveBatchScheduler:
|
||||
self.min_batch_size,
|
||||
int(self.current_batch_size * 0.9),
|
||||
)
|
||||
|
||||
|
||||
# 根据批次执行时间调整等待时间
|
||||
if self.stats.last_batch_duration > 0:
|
||||
if self.stats.last_batch_duration > self.current_wait_time * 2:
|
||||
@@ -518,7 +518,7 @@ class AdaptiveBatchScheduler:
|
||||
]
|
||||
return "|".join(key_parts)
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
|
||||
def _get_from_cache(self, cache_key: str) -> Any | None:
|
||||
"""从缓存获取结果"""
|
||||
if cache_key in self._result_cache:
|
||||
result, timestamp = self._result_cache[cache_key]
|
||||
@@ -551,27 +551,27 @@ class AdaptiveBatchScheduler:
|
||||
|
||||
|
||||
# 全局调度器实例
|
||||
_global_scheduler: Optional[AdaptiveBatchScheduler] = None
|
||||
_global_scheduler: AdaptiveBatchScheduler | None = None
|
||||
_scheduler_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_batch_scheduler() -> AdaptiveBatchScheduler:
|
||||
"""获取全局批量调度器(单例)"""
|
||||
global _global_scheduler
|
||||
|
||||
|
||||
if _global_scheduler is None:
|
||||
async with _scheduler_lock:
|
||||
if _global_scheduler is None:
|
||||
_global_scheduler = AdaptiveBatchScheduler()
|
||||
await _global_scheduler.start()
|
||||
|
||||
|
||||
return _global_scheduler
|
||||
|
||||
|
||||
async def close_batch_scheduler() -> None:
|
||||
"""关闭全局批量调度器"""
|
||||
global _global_scheduler
|
||||
|
||||
|
||||
if _global_scheduler is not None:
|
||||
await _global_scheduler.stop()
|
||||
_global_scheduler = None
|
||||
|
||||
@@ -11,8 +11,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -24,7 +25,7 @@ T = TypeVar("T")
|
||||
@dataclass
|
||||
class CacheEntry(Generic[T]):
|
||||
"""缓存条目
|
||||
|
||||
|
||||
Attributes:
|
||||
value: 缓存的值
|
||||
created_at: 创建时间戳
|
||||
@@ -42,7 +43,7 @@ class CacheEntry(Generic[T]):
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""缓存统计信息
|
||||
|
||||
|
||||
Attributes:
|
||||
hits: 命中次数
|
||||
misses: 未命中次数
|
||||
@@ -70,7 +71,7 @@ class CacheStats:
|
||||
|
||||
class LRUCache(Generic[T]):
|
||||
"""LRU缓存实现
|
||||
|
||||
|
||||
使用OrderedDict实现O(1)的get/set操作
|
||||
"""
|
||||
|
||||
@@ -81,7 +82,7 @@ class LRUCache(Generic[T]):
|
||||
name: str = "cache",
|
||||
):
|
||||
"""初始化LRU缓存
|
||||
|
||||
|
||||
Args:
|
||||
max_size: 最大缓存条目数
|
||||
ttl: 过期时间(秒)
|
||||
@@ -94,18 +95,18 @@ class LRUCache(Generic[T]):
|
||||
self._lock = asyncio.Lock()
|
||||
self._stats = CacheStats()
|
||||
|
||||
async def get(self, key: str) -> Optional[T]:
|
||||
async def get(self, key: str) -> T | None:
|
||||
"""获取缓存值
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
|
||||
Returns:
|
||||
缓存值,如果不存在或已过期返回None
|
||||
"""
|
||||
async with self._lock:
|
||||
entry = self._cache.get(key)
|
||||
|
||||
|
||||
if entry is None:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
@@ -125,20 +126,20 @@ class LRUCache(Generic[T]):
|
||||
entry.last_accessed = now
|
||||
entry.access_count += 1
|
||||
self._stats.hits += 1
|
||||
|
||||
|
||||
# 移到末尾(最近使用)
|
||||
self._cache.move_to_end(key)
|
||||
|
||||
|
||||
return entry.value
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: T,
|
||||
size: Optional[int] = None,
|
||||
size: int | None = None,
|
||||
) -> None:
|
||||
"""设置缓存值
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
@@ -146,16 +147,16 @@ class LRUCache(Generic[T]):
|
||||
"""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
|
||||
# 如果键已存在,更新值
|
||||
if key in self._cache:
|
||||
old_entry = self._cache[key]
|
||||
self._stats.total_size -= old_entry.size
|
||||
|
||||
|
||||
# 估算大小
|
||||
if size is None:
|
||||
size = self._estimate_size(value)
|
||||
|
||||
|
||||
# 创建新条目
|
||||
entry = CacheEntry(
|
||||
value=value,
|
||||
@@ -164,7 +165,7 @@ class LRUCache(Generic[T]):
|
||||
access_count=0,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
# 如果缓存已满,淘汰最久未使用的条目
|
||||
while len(self._cache) >= self.max_size:
|
||||
oldest_key, oldest_entry = self._cache.popitem(last=False)
|
||||
@@ -175,7 +176,7 @@ class LRUCache(Generic[T]):
|
||||
f"[{self.name}] 淘汰缓存条目: {oldest_key} "
|
||||
f"(访问{oldest_entry.access_count}次)"
|
||||
)
|
||||
|
||||
|
||||
# 添加新条目
|
||||
self._cache[key] = entry
|
||||
self._stats.item_count += 1
|
||||
@@ -183,10 +184,10 @@ class LRUCache(Generic[T]):
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""删除缓存条目
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
@@ -217,7 +218,7 @@ class LRUCache(Generic[T]):
|
||||
|
||||
def _estimate_size(self, value: Any) -> int:
|
||||
"""估算数据大小(字节)
|
||||
|
||||
|
||||
这是一个简单的估算,实际大小可能不同
|
||||
"""
|
||||
import sys
|
||||
@@ -230,11 +231,11 @@ class LRUCache(Generic[T]):
|
||||
|
||||
class MultiLevelCache:
|
||||
"""多级缓存管理器
|
||||
|
||||
|
||||
实现两级缓存架构:
|
||||
- L1: 高速缓存,小容量,短TTL
|
||||
- L2: 扩展缓存,大容量,长TTL
|
||||
|
||||
|
||||
查询时先查L1,未命中再查L2,未命中再从数据源加载
|
||||
"""
|
||||
|
||||
@@ -246,7 +247,7 @@ class MultiLevelCache:
|
||||
l2_ttl: float = 300,
|
||||
):
|
||||
"""初始化多级缓存
|
||||
|
||||
|
||||
Args:
|
||||
l1_max_size: L1缓存最大条目数
|
||||
l1_ttl: L1缓存TTL(秒)
|
||||
@@ -255,8 +256,8 @@ class MultiLevelCache:
|
||||
"""
|
||||
self.l1_cache: LRUCache[Any] = LRUCache(l1_max_size, l1_ttl, "L1")
|
||||
self.l2_cache: LRUCache[Any] = LRUCache(l2_max_size, l2_ttl, "L2")
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
logger.info(
|
||||
f"多级缓存初始化: L1({l1_max_size}项/{l1_ttl}s) "
|
||||
f"L2({l2_max_size}项/{l2_ttl}s)"
|
||||
@@ -265,16 +266,16 @@ class MultiLevelCache:
|
||||
async def get(
|
||||
self,
|
||||
key: str,
|
||||
loader: Optional[Callable[[], Any]] = None,
|
||||
) -> Optional[Any]:
|
||||
loader: Callable[[], Any] | None = None,
|
||||
) -> Any | None:
|
||||
"""从缓存获取数据
|
||||
|
||||
|
||||
查询顺序:L1 -> L2 -> loader
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
loader: 数据加载函数,当缓存未命中时调用
|
||||
|
||||
|
||||
Returns:
|
||||
缓存值或加载的值,如果都不存在返回None
|
||||
"""
|
||||
@@ -307,12 +308,12 @@ class MultiLevelCache:
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
size: Optional[int] = None,
|
||||
size: int | None = None,
|
||||
) -> None:
|
||||
"""设置缓存值
|
||||
|
||||
|
||||
同时写入L1和L2
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
@@ -323,9 +324,9 @@ class MultiLevelCache:
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""删除缓存条目
|
||||
|
||||
|
||||
同时从L1和L2删除
|
||||
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
"""
|
||||
@@ -347,7 +348,7 @@ class MultiLevelCache:
|
||||
|
||||
async def start_cleanup_task(self, interval: float = 60) -> None:
|
||||
"""启动定期清理任务
|
||||
|
||||
|
||||
Args:
|
||||
interval: 清理间隔(秒)
|
||||
"""
|
||||
@@ -387,27 +388,27 @@ class MultiLevelCache:
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
_global_cache: Optional[MultiLevelCache] = None
|
||||
_global_cache: MultiLevelCache | None = None
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_cache() -> MultiLevelCache:
|
||||
"""获取全局缓存实例(单例)"""
|
||||
global _global_cache
|
||||
|
||||
|
||||
if _global_cache is None:
|
||||
async with _cache_lock:
|
||||
if _global_cache is None:
|
||||
_global_cache = MultiLevelCache()
|
||||
await _global_cache.start_cleanup_task()
|
||||
|
||||
|
||||
return _global_cache
|
||||
|
||||
|
||||
async def close_cache() -> None:
|
||||
"""关闭全局缓存"""
|
||||
global _global_cache
|
||||
|
||||
|
||||
if _global_cache is not None:
|
||||
await _global_cache.stop_cleanup_task()
|
||||
await _global_cache.clear()
|
||||
|
||||
@@ -150,7 +150,7 @@ class ConnectionPoolManager:
|
||||
logger.debug(f"🆕 创建连接 (池大小: {len(self._connections)})")
|
||||
|
||||
yield connection_info.session
|
||||
|
||||
|
||||
# 🔧 修复:正常退出时提交事务
|
||||
# 这对SQLite至关重要,因为SQLite没有autocommit
|
||||
if connection_info and connection_info.session:
|
||||
@@ -249,7 +249,7 @@ class ConnectionPoolManager:
|
||||
"""获取连接池统计信息"""
|
||||
total_requests = self._stats["pool_hits"] + self._stats["pool_misses"]
|
||||
pool_efficiency = (self._stats["pool_hits"] / max(1, total_requests)) * 100 if total_requests > 0 else 0
|
||||
|
||||
|
||||
return {
|
||||
**self._stats,
|
||||
"active_connections": len(self._connections),
|
||||
|
||||
@@ -10,8 +10,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -25,7 +26,7 @@ logger = get_logger("preloader")
|
||||
@dataclass
|
||||
class AccessPattern:
|
||||
"""访问模式统计
|
||||
|
||||
|
||||
Attributes:
|
||||
key: 数据键
|
||||
access_count: 访问次数
|
||||
@@ -42,7 +43,7 @@ class AccessPattern:
|
||||
|
||||
class DataPreloader:
|
||||
"""数据预加载器
|
||||
|
||||
|
||||
通过分析访问模式,预测并预加载可能需要的数据
|
||||
"""
|
||||
|
||||
@@ -53,7 +54,7 @@ class DataPreloader:
|
||||
max_patterns: int = 1000,
|
||||
):
|
||||
"""初始化预加载器
|
||||
|
||||
|
||||
Args:
|
||||
decay_factor: 时间衰减因子(0-1),越小衰减越快
|
||||
preload_threshold: 预加载阈值,score超过此值时预加载
|
||||
@@ -62,7 +63,7 @@ class DataPreloader:
|
||||
self.decay_factor = decay_factor
|
||||
self.preload_threshold = preload_threshold
|
||||
self.max_patterns = max_patterns
|
||||
|
||||
|
||||
# 访问模式跟踪
|
||||
self._patterns: dict[str, AccessPattern] = {}
|
||||
# 关联关系:key -> [related_keys]
|
||||
@@ -73,9 +74,9 @@ class DataPreloader:
|
||||
self._total_accesses = 0
|
||||
self._preload_count = 0
|
||||
self._preload_hits = 0
|
||||
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
logger.info(
|
||||
f"数据预加载器初始化: 衰减因子={decay_factor}, "
|
||||
f"预加载阈值={preload_threshold}"
|
||||
@@ -84,10 +85,10 @@ class DataPreloader:
|
||||
async def record_access(
|
||||
self,
|
||||
key: str,
|
||||
related_keys: Optional[list[str]] = None,
|
||||
related_keys: list[str] | None = None,
|
||||
) -> None:
|
||||
"""记录数据访问
|
||||
|
||||
|
||||
Args:
|
||||
key: 被访问的数据键
|
||||
related_keys: 关联访问的数据键列表
|
||||
@@ -95,7 +96,7 @@ class DataPreloader:
|
||||
async with self._lock:
|
||||
self._total_accesses += 1
|
||||
now = time.time()
|
||||
|
||||
|
||||
# 更新或创建访问模式
|
||||
if key in self._patterns:
|
||||
pattern = self._patterns[key]
|
||||
@@ -108,15 +109,15 @@ class DataPreloader:
|
||||
last_access=now,
|
||||
)
|
||||
self._patterns[key] = pattern
|
||||
|
||||
|
||||
# 更新热度评分(时间衰减)
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
# 记录关联关系
|
||||
if related_keys:
|
||||
self._associations[key].update(related_keys)
|
||||
pattern.related_keys = list(self._associations[key])
|
||||
|
||||
|
||||
# 如果模式过多,删除评分最低的
|
||||
if len(self._patterns) > self.max_patterns:
|
||||
min_key = min(self._patterns, key=lambda k: self._patterns[k].score)
|
||||
@@ -126,10 +127,10 @@ class DataPreloader:
|
||||
|
||||
async def should_preload(self, key: str) -> bool:
|
||||
"""判断是否应该预加载某个数据
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
|
||||
|
||||
Returns:
|
||||
是否应该预加载
|
||||
"""
|
||||
@@ -137,18 +138,18 @@ class DataPreloader:
|
||||
pattern = self._patterns.get(key)
|
||||
if pattern is None:
|
||||
return False
|
||||
|
||||
|
||||
# 更新评分
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
return pattern.score >= self.preload_threshold
|
||||
|
||||
async def get_preload_keys(self, limit: int = 100) -> list[str]:
|
||||
"""获取应该预加载的数据键列表
|
||||
|
||||
|
||||
Args:
|
||||
limit: 最大返回数量
|
||||
|
||||
|
||||
Returns:
|
||||
按评分排序的数据键列表
|
||||
"""
|
||||
@@ -156,14 +157,14 @@ class DataPreloader:
|
||||
# 更新所有评分
|
||||
for pattern in self._patterns.values():
|
||||
pattern.score = self._calculate_score(pattern)
|
||||
|
||||
|
||||
# 按评分排序
|
||||
sorted_patterns = sorted(
|
||||
self._patterns.values(),
|
||||
key=lambda p: p.score,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
# 返回超过阈值的键
|
||||
return [
|
||||
p.key for p in sorted_patterns[:limit]
|
||||
@@ -172,10 +173,10 @@ class DataPreloader:
|
||||
|
||||
async def get_related_keys(self, key: str) -> list[str]:
|
||||
"""获取关联数据键
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
|
||||
|
||||
Returns:
|
||||
关联数据键列表
|
||||
"""
|
||||
@@ -188,27 +189,27 @@ class DataPreloader:
|
||||
loader: Callable[[], Awaitable[Any]],
|
||||
) -> None:
|
||||
"""预加载数据
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
loader: 异步加载函数
|
||||
"""
|
||||
try:
|
||||
cache = await get_cache()
|
||||
|
||||
|
||||
# 检查缓存中是否已存在
|
||||
if await cache.l1_cache.get(key) is not None:
|
||||
return
|
||||
|
||||
|
||||
# 加载数据
|
||||
logger.debug(f"预加载数据: {key}")
|
||||
data = await loader()
|
||||
|
||||
|
||||
if data is not None:
|
||||
# 写入缓存
|
||||
await cache.set(key, data)
|
||||
self._preload_count += 1
|
||||
|
||||
|
||||
# 预加载关联数据
|
||||
related_keys = await self.get_related_keys(key)
|
||||
for related_key in related_keys[:5]: # 最多预加载5个关联项
|
||||
@@ -216,7 +217,7 @@ class DataPreloader:
|
||||
# 这里需要调用者提供关联数据的加载函数
|
||||
# 暂时只记录,不实际加载
|
||||
logger.debug(f"发现关联数据: {related_key}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预加载数据失败 {key}: {e}", exc_info=True)
|
||||
|
||||
@@ -226,13 +227,13 @@ class DataPreloader:
|
||||
loaders: dict[str, Callable[[], Awaitable[Any]]],
|
||||
) -> None:
|
||||
"""批量启动预加载任务
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
loaders: 数据键到加载函数的映射
|
||||
"""
|
||||
preload_keys = await self.get_preload_keys()
|
||||
|
||||
|
||||
for key in preload_keys:
|
||||
if key in loaders:
|
||||
loader = loaders[key]
|
||||
@@ -242,9 +243,9 @@ class DataPreloader:
|
||||
|
||||
async def record_hit(self, key: str) -> None:
|
||||
"""记录预加载命中
|
||||
|
||||
|
||||
当缓存命中的数据是预加载的,调用此方法统计
|
||||
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
"""
|
||||
@@ -259,7 +260,7 @@ class DataPreloader:
|
||||
if self._preload_count > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"total_accesses": self._total_accesses,
|
||||
"tracked_patterns": len(self._patterns),
|
||||
@@ -278,7 +279,7 @@ class DataPreloader:
|
||||
self._total_accesses = 0
|
||||
self._preload_count = 0
|
||||
self._preload_hits = 0
|
||||
|
||||
|
||||
# 取消所有预加载任务
|
||||
for task in self._preload_tasks:
|
||||
task.cancel()
|
||||
@@ -286,38 +287,38 @@ class DataPreloader:
|
||||
|
||||
def _calculate_score(self, pattern: AccessPattern) -> float:
|
||||
"""计算热度评分
|
||||
|
||||
|
||||
使用时间衰减的访问频率:
|
||||
score = access_count * decay_factor^(time_since_last_access)
|
||||
|
||||
|
||||
Args:
|
||||
pattern: 访问模式
|
||||
|
||||
|
||||
Returns:
|
||||
热度评分
|
||||
"""
|
||||
now = time.time()
|
||||
time_diff = now - pattern.last_access
|
||||
|
||||
|
||||
# 时间衰减(以小时为单位)
|
||||
hours_passed = time_diff / 3600
|
||||
decay = self.decay_factor ** hours_passed
|
||||
|
||||
|
||||
# 评分 = 访问次数 * 时间衰减
|
||||
score = pattern.access_count * decay
|
||||
|
||||
|
||||
return score
|
||||
|
||||
|
||||
class CommonDataPreloader:
|
||||
"""常见数据预加载器
|
||||
|
||||
|
||||
针对特定的数据类型提供预加载策略
|
||||
"""
|
||||
|
||||
def __init__(self, preloader: DataPreloader):
|
||||
"""初始化
|
||||
|
||||
|
||||
Args:
|
||||
preloader: 基础预加载器
|
||||
"""
|
||||
@@ -330,16 +331,16 @@ class CommonDataPreloader:
|
||||
platform: str,
|
||||
) -> None:
|
||||
"""预加载用户相关数据
|
||||
|
||||
|
||||
包括:个人信息、权限、关系等
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
user_id: 用户ID
|
||||
platform: 平台
|
||||
"""
|
||||
from src.common.database.core.models import PersonInfo, UserPermissions, UserRelationships
|
||||
|
||||
|
||||
# 预加载个人信息
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -347,7 +348,7 @@ class CommonDataPreloader:
|
||||
PersonInfo,
|
||||
{"platform": platform, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载用户权限
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -355,7 +356,7 @@ class CommonDataPreloader:
|
||||
UserPermissions,
|
||||
{"platform": platform, "user_id": user_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载用户关系
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -371,16 +372,16 @@ class CommonDataPreloader:
|
||||
limit: int = 50,
|
||||
) -> None:
|
||||
"""预加载聊天上下文
|
||||
|
||||
|
||||
包括:最近消息、聊天流信息等
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
stream_id: 聊天流ID
|
||||
limit: 消息数量限制
|
||||
"""
|
||||
from src.common.database.core.models import ChatStreams, Messages
|
||||
|
||||
from src.common.database.core.models import ChatStreams
|
||||
|
||||
# 预加载聊天流信息
|
||||
await self._preload_model(
|
||||
session,
|
||||
@@ -388,7 +389,7 @@ class CommonDataPreloader:
|
||||
ChatStreams,
|
||||
{"stream_id": stream_id},
|
||||
)
|
||||
|
||||
|
||||
# 预加载最近消息(这个比较复杂,暂时跳过)
|
||||
# TODO: 实现消息列表的预加载
|
||||
|
||||
@@ -400,7 +401,7 @@ class CommonDataPreloader:
|
||||
filters: dict[str, Any],
|
||||
) -> None:
|
||||
"""预加载模型数据
|
||||
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
cache_key: 缓存键
|
||||
@@ -413,31 +414,31 @@ class CommonDataPreloader:
|
||||
stmt = stmt.where(getattr(model_class, key) == value)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
await self.preloader.preload_data(cache_key, loader)
|
||||
|
||||
|
||||
# 全局预加载器实例
|
||||
_global_preloader: Optional[DataPreloader] = None
|
||||
_global_preloader: DataPreloader | None = None
|
||||
_preloader_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_preloader() -> DataPreloader:
|
||||
"""获取全局预加载器实例(单例)"""
|
||||
global _global_preloader
|
||||
|
||||
|
||||
if _global_preloader is None:
|
||||
async with _preloader_lock:
|
||||
if _global_preloader is None:
|
||||
_global_preloader = DataPreloader()
|
||||
|
||||
|
||||
return _global_preloader
|
||||
|
||||
|
||||
async def close_preloader() -> None:
|
||||
"""关闭全局预加载器"""
|
||||
global _global_preloader
|
||||
|
||||
|
||||
if _global_preloader is not None:
|
||||
await _global_preloader.clear()
|
||||
_global_preloader = None
|
||||
|
||||
@@ -37,29 +37,29 @@ from .monitoring import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BatchSchedulerError",
|
||||
"CacheError",
|
||||
"ConnectionPoolError",
|
||||
"DatabaseConnectionError",
|
||||
# 异常
|
||||
"DatabaseError",
|
||||
"DatabaseInitializationError",
|
||||
"DatabaseConnectionError",
|
||||
"DatabaseMigrationError",
|
||||
# 监控
|
||||
"DatabaseMonitor",
|
||||
"DatabaseQueryError",
|
||||
"DatabaseTransactionError",
|
||||
"DatabaseMigrationError",
|
||||
"CacheError",
|
||||
"BatchSchedulerError",
|
||||
"ConnectionPoolError",
|
||||
"cached",
|
||||
"db_operation",
|
||||
"get_monitor",
|
||||
"measure_time",
|
||||
"print_stats",
|
||||
"record_cache_hit",
|
||||
"record_cache_miss",
|
||||
"record_operation",
|
||||
"reset_stats",
|
||||
# 装饰器
|
||||
"retry",
|
||||
"timeout",
|
||||
"cached",
|
||||
"measure_time",
|
||||
"transactional",
|
||||
"db_operation",
|
||||
# 监控
|
||||
"DatabaseMonitor",
|
||||
"get_monitor",
|
||||
"record_operation",
|
||||
"record_cache_hit",
|
||||
"record_cache_miss",
|
||||
"print_stats",
|
||||
"reset_stats",
|
||||
]
|
||||
|
||||
@@ -10,9 +10,11 @@ import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError, TimeoutError as SQLTimeoutError
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,33 +27,33 @@ def generate_cache_key(
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""生成与@cached装饰器相同的缓存键
|
||||
|
||||
|
||||
用于手动缓存失效等操作
|
||||
|
||||
|
||||
Args:
|
||||
key_prefix: 缓存键前缀
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
|
||||
Returns:
|
||||
缓存键字符串
|
||||
|
||||
|
||||
Example:
|
||||
cache_key = generate_cache_key("person_info", platform, person_id)
|
||||
await cache.delete(cache_key)
|
||||
"""
|
||||
cache_key_parts = [key_prefix]
|
||||
|
||||
|
||||
if args:
|
||||
args_str = ",".join(str(arg) for arg in args)
|
||||
args_hash = hashlib.md5(args_str.encode()).hexdigest()[:8]
|
||||
cache_key_parts.append(f"args:{args_hash}")
|
||||
|
||||
|
||||
if kwargs:
|
||||
kwargs_str = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||
kwargs_hash = hashlib.md5(kwargs_str.encode()).hexdigest()[:8]
|
||||
cache_key_parts.append(f"kwargs:{kwargs_hash}")
|
||||
|
||||
|
||||
return ":".join(cache_key_parts)
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -65,15 +67,15 @@ def retry(
|
||||
exceptions: tuple[type[Exception], ...] = (OperationalError, DBAPIError, SQLTimeoutError),
|
||||
):
|
||||
"""重试装饰器
|
||||
|
||||
|
||||
自动重试失败的数据库操作,适用于临时性错误
|
||||
|
||||
|
||||
Args:
|
||||
max_attempts: 最大尝试次数
|
||||
delay: 初始延迟时间(秒)
|
||||
backoff: 延迟倍数(指数退避)
|
||||
exceptions: 需要重试的异常类型
|
||||
|
||||
|
||||
Example:
|
||||
@retry(max_attempts=3, delay=1.0)
|
||||
async def query_data():
|
||||
@@ -114,12 +116,12 @@ def retry(
|
||||
|
||||
def timeout(seconds: float):
|
||||
"""超时装饰器
|
||||
|
||||
|
||||
为数据库操作添加超时控制
|
||||
|
||||
|
||||
Args:
|
||||
seconds: 超时时间(秒)
|
||||
|
||||
|
||||
Example:
|
||||
@timeout(30.0)
|
||||
async def long_query():
|
||||
@@ -141,21 +143,21 @@ def timeout(seconds: float):
|
||||
|
||||
|
||||
def cached(
|
||||
ttl: Optional[int] = 300,
|
||||
key_prefix: Optional[str] = None,
|
||||
ttl: int | None = 300,
|
||||
key_prefix: str | None = None,
|
||||
use_args: bool = True,
|
||||
use_kwargs: bool = True,
|
||||
):
|
||||
"""缓存装饰器
|
||||
|
||||
|
||||
自动缓存函数返回值
|
||||
|
||||
|
||||
Args:
|
||||
ttl: 缓存过期时间(秒),None表示永不过期
|
||||
key_prefix: 缓存键前缀,默认使用函数名
|
||||
use_args: 是否将位置参数包含在缓存键中
|
||||
use_kwargs: 是否将关键字参数包含在缓存键中
|
||||
|
||||
|
||||
Example:
|
||||
@cached(ttl=60, key_prefix="user_data")
|
||||
async def get_user_info(user_id: str) -> dict:
|
||||
@@ -167,7 +169,7 @@ def cached(
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
# 延迟导入避免循环依赖
|
||||
from src.common.database.optimization import get_cache
|
||||
|
||||
|
||||
# 生成缓存键
|
||||
cache_key_parts = [key_prefix or func.__name__]
|
||||
|
||||
@@ -207,14 +209,14 @@ def cached(
|
||||
return decorator
|
||||
|
||||
|
||||
def measure_time(log_slow: Optional[float] = None):
|
||||
def measure_time(log_slow: float | None = None):
|
||||
"""性能测量装饰器
|
||||
|
||||
|
||||
测量函数执行时间,可选择性记录慢查询
|
||||
|
||||
|
||||
Args:
|
||||
log_slow: 慢查询阈值(秒),超过此时间会记录warning日志
|
||||
|
||||
|
||||
Example:
|
||||
@measure_time(log_slow=1.0)
|
||||
async def complex_query():
|
||||
@@ -246,19 +248,19 @@ def measure_time(log_slow: Optional[float] = None):
|
||||
|
||||
def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
"""事务装饰器
|
||||
|
||||
|
||||
自动管理事务的提交和回滚
|
||||
|
||||
|
||||
Args:
|
||||
auto_commit: 是否自动提交
|
||||
auto_rollback: 发生异常时是否自动回滚
|
||||
|
||||
|
||||
Example:
|
||||
@transactional()
|
||||
async def update_multiple_records(session):
|
||||
await session.execute(stmt1)
|
||||
await session.execute(stmt2)
|
||||
|
||||
|
||||
Note:
|
||||
函数需要接受session参数
|
||||
"""
|
||||
@@ -306,20 +308,20 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
# 组合装饰器示例
|
||||
def db_operation(
|
||||
retry_attempts: int = 3,
|
||||
timeout_seconds: Optional[float] = None,
|
||||
cache_ttl: Optional[int] = None,
|
||||
timeout_seconds: float | None = None,
|
||||
cache_ttl: int | None = None,
|
||||
measure: bool = True,
|
||||
):
|
||||
"""组合装饰器
|
||||
|
||||
|
||||
组合多个装饰器,提供完整的数据库操作保护
|
||||
|
||||
|
||||
Args:
|
||||
retry_attempts: 重试次数
|
||||
timeout_seconds: 超时时间
|
||||
cache_ttl: 缓存时间
|
||||
measure: 是否测量性能
|
||||
|
||||
|
||||
Example:
|
||||
@db_operation(retry_attempts=3, timeout_seconds=30, cache_ttl=60)
|
||||
async def important_query():
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -22,7 +21,7 @@ class OperationMetrics:
|
||||
min_time: float = float("inf")
|
||||
max_time: float = 0.0
|
||||
error_count: int = 0
|
||||
last_execution_time: Optional[float] = None
|
||||
last_execution_time: float | None = None
|
||||
|
||||
@property
|
||||
def avg_time(self) -> float:
|
||||
@@ -91,7 +90,7 @@ class DatabaseMetrics:
|
||||
|
||||
class DatabaseMonitor:
|
||||
"""数据库监控器
|
||||
|
||||
|
||||
单例模式,收集和报告数据库性能指标
|
||||
"""
|
||||
|
||||
@@ -285,7 +284,7 @@ class DatabaseMonitor:
|
||||
|
||||
|
||||
# 全局监控器实例
|
||||
_monitor: Optional[DatabaseMonitor] = None
|
||||
_monitor: DatabaseMonitor | None = None
|
||||
|
||||
|
||||
def get_monitor() -> DatabaseMonitor:
|
||||
|
||||
@@ -4,8 +4,8 @@ from datetime import datetime
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from src.common.database.core.models import LLMUsage
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import LLMUsage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
|
||||
|
||||
@@ -224,7 +224,7 @@ class MainSystem:
|
||||
|
||||
storage_batcher = get_message_storage_batcher()
|
||||
cleanup_tasks.append(("消息存储批处理器", storage_batcher.stop()))
|
||||
|
||||
|
||||
update_batcher = get_message_update_batcher()
|
||||
cleanup_tasks.append(("消息更新批处理器", update_batcher.stop()))
|
||||
except Exception as e:
|
||||
@@ -502,7 +502,7 @@ MoFox_Bot(第三方修改版)
|
||||
storage_batcher = get_message_storage_batcher()
|
||||
await storage_batcher.start()
|
||||
logger.info("消息存储批处理器已启动")
|
||||
|
||||
|
||||
update_batcher = get_message_update_batcher()
|
||||
await update_batcher.start()
|
||||
logger.info("消息更新批处理器已启动")
|
||||
|
||||
@@ -256,9 +256,10 @@ class RelationshipFetcher:
|
||||
str: 格式化后的聊天流印象字符串
|
||||
"""
|
||||
try:
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
import time
|
||||
|
||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||
|
||||
# 使用优化后的API(带缓存)
|
||||
# 从stream_id解析platform,或使用默认值
|
||||
platform = stream_id.split("_")[0] if "_" in stream_id else "unknown"
|
||||
@@ -289,7 +290,7 @@ class RelationshipFetcher:
|
||||
except Exception as e:
|
||||
logger.warning(f"访问stream对象属性失败: {e}")
|
||||
stream_data = {}
|
||||
|
||||
|
||||
impression_parts = []
|
||||
|
||||
# 1. 聊天环境基本信息
|
||||
|
||||
@@ -52,8 +52,8 @@ from typing import Any
|
||||
import orjson
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.database import get_active_plans_for_month
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.plugin_system.base.component_types import ChatType, CommandInfo, Compon
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
pass
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from src.common.database.core.models import PermissionNodes, UserPermissions
|
||||
from src.common.database.core import get_engine
|
||||
from src.common.database.core.models import PermissionNodes, UserPermissions
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
|
||||
import time
|
||||
|
||||
from src.common.database.core.models import UserRelationships
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import UserRelationships
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -7,12 +7,8 @@
|
||||
import json
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -358,14 +354,14 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
"stream_interest_score": impression.get("stream_interest_score", 0.5),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 使缓存失效
|
||||
from src.common.database.optimization.cache_manager import get_cache
|
||||
from src.common.database.utils.decorators import generate_cache_key
|
||||
cache = await get_cache()
|
||||
await cache.delete(generate_cache_key("stream_impression", stream_id))
|
||||
await cache.delete(generate_cache_key("chat_stream", stream_id))
|
||||
|
||||
|
||||
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
|
||||
else:
|
||||
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
|
||||
|
||||
@@ -64,7 +64,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if chat_stream:
|
||||
stream_config = chat_stream.get_raw_id()
|
||||
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
|
||||
|
||||
@@ -7,13 +7,10 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.utils.prompt import Prompt
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.database.api.crud import CRUDBase
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.database.utils.decorators import cached
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -208,7 +205,7 @@ class ProactiveThinkingPlanner:
|
||||
# 3. 获取bot人设和时间信息
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
|
||||
# 构建时间信息块
|
||||
time_block = f"当前时间是 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
@@ -554,11 +551,11 @@ async def execute_proactive_thinking(stream_id: str):
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if chat_stream:
|
||||
# 使用 ChatStream 的 get_raw_id() 方法获取配置字符串
|
||||
stream_config = chat_stream.get_raw_id()
|
||||
|
||||
|
||||
# 执行白名单/黑名单检查
|
||||
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
|
||||
logger.debug(f"聊天流 {stream_id} ({stream_config}) 未通过白名单/黑名单检查,跳过主动思考")
|
||||
@@ -567,7 +564,7 @@ async def execute_proactive_thinking(stream_id: str):
|
||||
logger.warning(f"无法获取聊天流 {stream_id} 的信息,跳过白名单检查")
|
||||
except Exception as e:
|
||||
logger.warning(f"白名单检查时出错: {e},继续执行")
|
||||
|
||||
|
||||
# 0.2 检查安静时段
|
||||
if proactive_thinking_scheduler._is_in_quiet_hours():
|
||||
logger.debug("安静时段,跳过")
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
|
||||
from src.common.database.core.models import MonthlyPlan
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import MonthlyPlan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -312,7 +312,7 @@ async def delete_plans_older_than(month: str):
|
||||
logger.info(f"没有找到比 {month} 更早的月度计划需要删除。")
|
||||
return 0
|
||||
|
||||
plan_months = sorted(list(set(p.target_month for p in plans_to_delete)))
|
||||
plan_months = sorted({p.target_month for p in plans_to_delete})
|
||||
logger.info(f"将删除 {len(plans_to_delete)} 条早于 {month} 的月度计划 (涉及月份: {', '.join(plan_months)})。")
|
||||
|
||||
# 然后,执行删除操作
|
||||
|
||||
@@ -100,7 +100,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
|
||||
next_month = datetime(now.year + 1, 1, 1)
|
||||
else:
|
||||
next_month = datetime(now.year, now.month + 1, 1)
|
||||
|
||||
|
||||
sleep_seconds = (next_month - now).total_seconds()
|
||||
logger.info(
|
||||
f" 下一次月度计划生成任务将在 {sleep_seconds:.2f} 秒后运行 (北京时间 {next_month.strftime('%Y-%m-%d %H:%M:%S')})"
|
||||
@@ -110,7 +110,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
|
||||
# 到达月初,先归档上个月的计划
|
||||
last_month = (next_month - timedelta(days=1)).strftime("%Y-%m")
|
||||
await self.monthly_plan_manager.plan_manager.archive_current_month_plans(last_month)
|
||||
|
||||
|
||||
# 为当前月生成新计划
|
||||
current_month = next_month.strftime("%Y-%m")
|
||||
logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...")
|
||||
|
||||
@@ -97,4 +97,4 @@ MONTHLY_PLAN_GENERATION_PROMPT = Prompt(
|
||||
|
||||
请你扮演我,以我的身份和兴趣,为 {target_month} 制定合适的月度计划。
|
||||
""",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||
from src.common.database.core import get_db_session
|
||||
from src.common.database.core.models import MonthlyPlan, Schedule
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
|
||||
Reference in New Issue
Block a user