将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 将同步的 SQLAlchemy 查询方法改为异步执行 - 更新相关的方法签名,添加 async/await 关键字 - 修复由于异步化导致的并发问题和性能问题 这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。 BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
781 lines
29 KiB
Python
781 lines
29 KiB
Python
"""SQLAlchemy数据库模型定义
|
||
|
||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||
"""
|
||
|
||
import datetime
import os
import time
from contextlib import asynccontextmanager
from typing import Optional, Any, Dict, AsyncGenerator

from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import declarative_base as orm_declarative_base

from src.common.logger import get_logger
|
||
|
||
logger = get_logger("sqlalchemy_models")

# Declarative base class shared by every ORM model in this module.
# ``sqlalchemy.ext.declarative.declarative_base`` is deprecated (moved to
# ``sqlalchemy.orm`` in SQLAlchemy 2.0). This module already depends on 2.0-only
# APIs such as ``async_sessionmaker``, so the ``sqlalchemy.orm`` version is used;
# it builds an equivalent base class.
Base = orm_declarative_base()
|
||
|
||
|
||
# MySQL-compatible string column type helper
def get_string_field(max_length=255, **kwargs):
    """Return the SQLAlchemy column type to use for a string column.

    MySQL needs a bounded ``VARCHAR`` for columns that are indexed, so it gets
    ``String(max_length)``. Any other configured backend (SQLite) gets ``Text``,
    and ``max_length`` is then ignored. Extra keyword arguments are passed to
    whichever type constructor is chosen.

    The models in this module call this at class-definition time. The resulting
    column types therefore follow the database type configured when the module
    is first imported.
    """
    from src.config.config import global_config

    use_varchar = global_config.database.database_type == "mysql"
    if not use_varchar:
        return Text(**kwargs)
    return String(max_length, **kwargs)
|
||
|
||
|
||
class ChatStreams(Base):
    """Chat stream model: one row per ``stream_id`` (the column is unique).

    Stores the stream's group/user identity plus per-stream counters used by the
    energy/interest and message-interruption systems.

    NOTE(review): ``stream_id``, ``user_id`` and ``group_id`` are declared with
    ``index=True`` and are also listed in explicit ``Index(...)`` entries in
    ``__table_args__``. SQLAlchemy creates an index for each, so these columns
    get two indexes. Confirm both are intended.
    """

    __tablename__ = "chat_streams"

    id = Column(Integer, primary_key=True, autoincrement=True)
    stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
    # Float timestamps throughout; presumably Unix epoch seconds — confirm with writers
    create_time = Column(Float, nullable=False)
    group_platform = Column(Text, nullable=True)
    group_id = Column(get_string_field(100), nullable=True, index=True)
    group_name = Column(Text, nullable=True)
    last_active_time = Column(Float, nullable=False)
    platform = Column(Text, nullable=False)
    user_platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(100), nullable=False, index=True)
    user_nickname = Column(Text, nullable=False)
    user_cardname = Column(Text, nullable=True)
    energy_value = Column(Float, nullable=True, default=5.0)
    sleep_pressure = Column(Float, nullable=True, default=0.0)
    focus_energy = Column(Float, nullable=True, default=0.5)
    # Dynamic interest-level system fields
    base_interest_energy = Column(Float, nullable=True, default=0.5)
    message_interest_total = Column(Float, nullable=True, default=0.0)
    message_count = Column(Integer, nullable=True, default=0)
    action_count = Column(Integer, nullable=True, default=0)
    reply_count = Column(Integer, nullable=True, default=0)
    last_interaction_time = Column(Float, nullable=True, default=None)
    consecutive_no_reply = Column(Integer, nullable=True, default=0)
    # Message-interruption system fields
    interruption_count = Column(Integer, nullable=True, default=0)

    __table_args__ = (
        Index("idx_chatstreams_stream_id", "stream_id"),
        Index("idx_chatstreams_user_id", "user_id"),
        Index("idx_chatstreams_group_id", "group_id"),
    )
|
||
|
||
|
||
class LLMUsage(Base):
    """LLM usage record model: model, caller, token counts, cost, latency and status.

    NOTE(review): ``model_name``, ``model_assign_name``, ``model_api_provider``,
    ``user_id``, ``request_type`` and ``timestamp`` are declared with
    ``index=True`` and also appear in explicit ``Index(...)`` entries below, so
    each gets two indexes. Confirm both are intended.
    """

    __tablename__ = "llm_usage"

    id = Column(Integer, primary_key=True, autoincrement=True)
    model_name = Column(get_string_field(100), nullable=False, index=True)
    model_assign_name = Column(get_string_field(100), index=True)  # indexed
    model_api_provider = Column(get_string_field(100), index=True)  # indexed
    user_id = Column(get_string_field(50), nullable=False, index=True)
    request_type = Column(get_string_field(50), nullable=False, index=True)
    endpoint = Column(Text, nullable=False)
    prompt_tokens = Column(Integer, nullable=False)
    completion_tokens = Column(Integer, nullable=False)
    # Request latency; unit not visible here (presumably seconds — confirm)
    time_cost = Column(Float, nullable=True)
    total_tokens = Column(Integer, nullable=False)
    cost = Column(Float, nullable=False)
    status = Column(Text, nullable=False)
    # Naive local time from datetime.datetime.now, evaluated per insert
    timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)

    __table_args__ = (
        Index("idx_llmusage_model_name", "model_name"),
        Index("idx_llmusage_model_assign_name", "model_assign_name"),
        Index("idx_llmusage_model_api_provider", "model_api_provider"),
        Index("idx_llmusage_time_cost", "time_cost"),
        Index("idx_llmusage_user_id", "user_id"),
        Index("idx_llmusage_request_type", "request_type"),
        Index("idx_llmusage_timestamp", "timestamp"),
    )
|
||
|
||
|
||
class Emoji(Base):
    """Emoji (sticker) model: one row per unique ``full_path``.

    Tracks description/emotion metadata, registration/ban flags and usage
    counters.

    NOTE(review): ``full_path`` and ``emoji_hash`` are declared with
    ``index=True`` and also have explicit ``Index(...)`` entries below, so each
    gets two indexes. Confirm both are intended.
    """

    __tablename__ = "emoji"

    id = Column(Integer, primary_key=True, autoincrement=True)
    full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
    format = Column(Text, nullable=False)
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    query_count = Column(Integer, nullable=False, default=0)
    is_registered = Column(Boolean, nullable=False, default=False)
    is_banned = Column(Boolean, nullable=False, default=False)
    emotion = Column(Text, nullable=True)
    # Float timestamps; presumably Unix epoch seconds — confirm with writers
    record_time = Column(Float, nullable=False)
    register_time = Column(Float, nullable=True)
    usage_count = Column(Integer, nullable=False, default=0)
    last_used_time = Column(Float, nullable=True)

    __table_args__ = (
        Index("idx_emoji_full_path", "full_path"),
        Index("idx_emoji_hash", "emoji_hash"),
    )
|
||
|
||
|
||
class Messages(Base):
    """Message model: one row per stored message.

    The nested ``chat_info`` and top-level ``user_info`` structures are
    flattened into ``chat_info_*`` and ``user_*`` columns.

    NOTE(review): ``message_id``, ``chat_id`` and ``user_id`` are declared with
    ``index=True`` and also have explicit ``Index(...)`` entries below, so each
    gets two indexes. Confirm both are intended.
    """

    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Not unique at the schema level
    message_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    reply_to = Column(Text, nullable=True)
    interest_value = Column(Float, nullable=True)
    key_words = Column(Text, nullable=True)
    key_words_lite = Column(Text, nullable=True)
    is_mentioned = Column(Boolean, nullable=True)

    # Fields flattened from chat_info
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)
    chat_info_user_platform = Column(Text, nullable=False)
    chat_info_user_id = Column(Text, nullable=False)
    chat_info_user_nickname = Column(Text, nullable=False)
    chat_info_user_cardname = Column(Text, nullable=True)
    chat_info_group_platform = Column(Text, nullable=True)
    chat_info_group_id = Column(Text, nullable=True)
    chat_info_group_name = Column(Text, nullable=True)
    chat_info_create_time = Column(Float, nullable=False)
    chat_info_last_active_time = Column(Float, nullable=False)

    # Fields flattened from the top-level user_info
    user_platform = Column(Text, nullable=True)
    user_id = Column(get_string_field(100), nullable=True, index=True)
    user_nickname = Column(Text, nullable=True)
    user_cardname = Column(Text, nullable=True)

    processed_plain_text = Column(Text, nullable=True)
    display_message = Column(Text, nullable=True)
    memorized_times = Column(Integer, nullable=False, default=0)
    priority_mode = Column(Text, nullable=True)
    priority_info = Column(Text, nullable=True)
    additional_config = Column(Text, nullable=True)
    is_emoji = Column(Boolean, nullable=False, default=False)
    is_picid = Column(Boolean, nullable=False, default=False)
    is_command = Column(Boolean, nullable=False, default=False)
    is_notify = Column(Boolean, nullable=False, default=False)

    # Interest-system fields
    actions = Column(Text, nullable=True)  # action list stored as JSON text
    should_reply = Column(Boolean, nullable=True, default=False)

    __table_args__ = (
        Index("idx_messages_message_id", "message_id"),
        Index("idx_messages_chat_id", "chat_id"),
        Index("idx_messages_time", "time"),
        Index("idx_messages_user_id", "user_id"),
        Index("idx_messages_should_reply", "should_reply"),
    )
|
||
|
||
|
||
class ActionRecords(Base):
    """Action record model: one row per recorded action, keyed by ``action_id``/``chat_id``.

    NOTE(review): ``action_id`` and ``chat_id`` are declared with ``index=True``
    and also have explicit ``Index(...)`` entries below, so each gets two
    indexes. Confirm both are intended.
    """

    __tablename__ = "action_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    action_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    action_name = Column(Text, nullable=False)
    # Serialized action payload (format not visible here; presumably JSON — confirm)
    action_data = Column(Text, nullable=False)
    action_done = Column(Boolean, nullable=False, default=False)
    action_build_into_prompt = Column(Boolean, nullable=False, default=False)
    action_prompt_display = Column(Text, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)

    __table_args__ = (
        Index("idx_actionrecords_action_id", "action_id"),
        Index("idx_actionrecords_chat_id", "chat_id"),
        Index("idx_actionrecords_time", "time"),
    )
|
||
|
||
|
||
class Images(Base):
    """Image metadata model: one row per unique file ``path``.

    NOTE(review): ``emoji_hash`` is declared with ``index=True`` and also has an
    explicit ``Index(...)`` below, so it gets two indexes. ``path`` likewise has
    both a UNIQUE constraint and an explicit index. Confirm both are intended.
    """

    __tablename__ = "images"

    id = Column(Integer, primary_key=True, autoincrement=True)
    image_id = Column(Text, nullable=False, default="")
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=True)
    path = Column(get_string_field(500), nullable=False, unique=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    type = Column(Text, nullable=False)
    # Whether a vision-language model has processed this image (by name — confirm)
    vlm_processed = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index("idx_images_emoji_hash", "emoji_hash"),
        Index("idx_images_path", "path"),
    )
|
||
|
||
|
||
class ImageDescriptions(Base):
    """Image description model: a cached description per image hash and ``type``.

    NOTE(review): ``image_description_hash`` is declared with ``index=True`` and
    also has an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "image_descriptions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    type = Column(Text, nullable=False)
    image_description_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    timestamp = Column(Float, nullable=False)

    __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
||
|
||
|
||
class Videos(Base):
    """Video metadata model: one row per unique ``video_hash``.

    NOTE(review): ``video_hash`` is declared with ``index=True``/``unique=True``
    and also has an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "videos"

    id = Column(Integer, primary_key=True, autoincrement=True)
    video_id = Column(Text, nullable=False, default="")
    video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
    description = Column(Text, nullable=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)

    # Video-specific attributes
    duration = Column(Float, nullable=True)  # video duration (seconds)
    frame_count = Column(Integer, nullable=True)  # total number of frames
    fps = Column(Float, nullable=True)  # frame rate
    resolution = Column(Text, nullable=True)  # resolution
    file_size = Column(Integer, nullable=True)  # file size (bytes)

    __table_args__ = (
        Index("idx_videos_video_hash", "video_hash"),
        Index("idx_videos_timestamp", "timestamp"),
    )
|
||
|
||
|
||
class OnlineTime(Base):
    """Online-duration record model: one row per online interval.

    NOTE(review): ``end_timestamp`` is declared with ``index=True`` and also has
    an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "online_time"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Text timestamp, produced per insert. The previous default
    # ``str(datetime.datetime.now)`` stringified the *method object* once at import
    # time, so every row received "<built-in method now of type object ...>"
    # instead of an actual time. A callable default is evaluated on each insert.
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
    duration = Column(Integer, nullable=False)
    start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
    end_timestamp = Column(DateTime, nullable=False, index=True)

    __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
||
|
||
|
||
class PersonInfo(Base):
    """Person information model: one row per unique ``person_id``.

    NOTE(review): ``person_id`` and ``user_id`` are declared with ``index=True``
    and also have explicit ``Index(...)`` entries below, so each gets two
    indexes. Confirm both are intended.
    """

    __tablename__ = "person_info"

    id = Column(Integer, primary_key=True, autoincrement=True)
    person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
    person_name = Column(Text, nullable=True)
    name_reason = Column(Text, nullable=True)
    platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    nickname = Column(Text, nullable=True)
    impression = Column(Text, nullable=True)
    short_impression = Column(Text, nullable=True)
    # Serialized text fields; encoding not visible here (presumably JSON — confirm)
    points = Column(Text, nullable=True)
    forgotten_points = Column(Text, nullable=True)
    info_list = Column(Text, nullable=True)
    know_times = Column(Float, nullable=True)
    know_since = Column(Float, nullable=True)
    last_know = Column(Float, nullable=True)
    attitude = Column(Integer, nullable=True, default=50)

    __table_args__ = (
        Index("idx_personinfo_person_id", "person_id"),
        Index("idx_personinfo_user_id", "user_id"),
    )
|
||
|
||
|
||
class BotPersonalityInterests(Base):
    """Bot personality interest-tag model: versioned interest tags per ``personality_id``.

    NOTE(review): ``personality_id`` and ``last_updated`` are declared with
    ``index=True`` and also have explicit ``Index(...)`` entries below, so each
    gets two indexes.
    """

    __tablename__ = "bot_personality_interests"

    id = Column(Integer, primary_key=True, autoincrement=True)
    personality_id = Column(get_string_field(100), nullable=False, index=True)
    personality_description = Column(Text, nullable=False)
    interest_tags = Column(Text, nullable=False)  # interest-tag list stored as JSON text
    embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
    version = Column(Integer, nullable=False, default=1)
    last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)

    __table_args__ = (
        Index("idx_botpersonality_personality_id", "personality_id"),
        Index("idx_botpersonality_version", "version"),
        Index("idx_botpersonality_last_updated", "last_updated"),
    )
|
||
|
||
|
||
class Memory(Base):
    """Memory model: a memory text with keywords, looked up by ``memory_id``.

    NOTE(review): ``memory_id`` is declared with ``index=True`` and also has an
    explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "memory"

    id = Column(Integer, primary_key=True, autoincrement=True)
    memory_id = Column(get_string_field(64), nullable=False, index=True)
    chat_id = Column(Text, nullable=True)
    memory_text = Column(Text, nullable=True)
    keywords = Column(Text, nullable=True)
    create_time = Column(Float, nullable=True)
    last_view_time = Column(Float, nullable=True)

    __table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
||
|
||
|
||
class Expression(Base):
    """Expression-style model: a (situation, style) pair learned per ``chat_id``.

    Unlike most models in this module, it uses typed ``Mapped[...]`` /
    ``mapped_column`` declarations.

    NOTE(review): ``chat_id`` is declared with ``index=True`` and also has an
    explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "expression"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    situation: Mapped[str] = mapped_column(Text, nullable=False)
    style: Mapped[str] = mapped_column(Text, nullable=False)
    # Stored as Float (not Integer)
    count: Mapped[float] = mapped_column(Float, nullable=False)
    last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
    chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
    type: Mapped[str] = mapped_column(Text, nullable=False)
    create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    __table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
||
|
||
|
||
class ThinkingLog(Base):
    """Thinking log model: trigger/response text plus serialized context snapshots per chat.

    NOTE(review): ``chat_id`` is declared with ``index=True`` and also has an
    explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "thinking_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    trigger_text = Column(Text, nullable=True)
    response_text = Column(Text, nullable=True)
    # The *_json columns hold text; by name, presumably JSON-encoded — confirm with writers
    trigger_info_json = Column(Text, nullable=True)
    response_info_json = Column(Text, nullable=True)
    timing_results_json = Column(Text, nullable=True)
    chat_history_json = Column(Text, nullable=True)
    chat_history_in_thinking_json = Column(Text, nullable=True)
    chat_history_after_response_json = Column(Text, nullable=True)
    heartflow_data_json = Column(Text, nullable=True)
    reasoning_data_json = Column(Text, nullable=True)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
||
|
||
|
||
class GraphNodes(Base):
    """Memory-graph node model: one row per unique ``concept``.

    NOTE(review): ``concept`` is declared with ``index=True``/``unique=True`` and
    also has an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "graph_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
    memory_items = Column(Text, nullable=False)
    hash = Column(Text, nullable=False)
    weight = Column(Float, nullable=False, default=1.0)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
||
|
||
|
||
class GraphEdges(Base):
    """Memory-graph edge model: a weighted ``source`` -> ``target`` link between concepts.

    NOTE(review): ``source`` and ``target`` are declared with ``index=True`` and
    also have explicit ``Index(...)`` entries below, so each gets two indexes.
    """

    __tablename__ = "graph_edges"

    id = Column(Integer, primary_key=True, autoincrement=True)
    source = Column(get_string_field(255), nullable=False, index=True)
    target = Column(get_string_field(255), nullable=False, index=True)
    strength = Column(Integer, nullable=False)
    hash = Column(Text, nullable=False)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (
        Index("idx_graphedges_source", "source"),
        Index("idx_graphedges_target", "target"),
    )
|
||
|
||
|
||
class Schedule(Base):
    """Daily schedule model: one row per unique ``date``.

    NOTE(review): ``date`` is declared with ``index=True``/``unique=True`` and
    also has an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "schedule"

    id = Column(Integer, primary_key=True, autoincrement=True)
    date = Column(get_string_field(10), nullable=False, unique=True, index=True)  # "YYYY-MM-DD" string
    schedule_data = Column(Text, nullable=False)  # schedule data as JSON text
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # onupdate refreshes this on every ORM UPDATE of the row
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (Index("idx_schedule_date", "date"),)
|
||
|
||
|
||
class MaiZoneScheduleStatus(Base):
    """MaiZone schedule-processing status model: one row per schedule hour.

    Records the hour's activity, whether it has been processed, the generated
    post ("shuoshuo") content and whether sending succeeded.

    NOTE(review): ``datetime_hour`` is declared with ``index=True``/``unique=True``
    and also has an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "maizone_schedule_status"

    id = Column(Integer, primary_key=True, autoincrement=True)
    datetime_hour = Column(
        get_string_field(13), nullable=False, unique=True, index=True
    )  # "YYYY-MM-DD HH" string, hour precision
    activity = Column(Text, nullable=False)  # activity content for that hour
    is_processed = Column(Boolean, nullable=False, default=False)  # whether already processed
    processed_at = Column(DateTime, nullable=True)  # processing time
    story_content = Column(Text, nullable=True)  # generated post content
    send_success = Column(Boolean, nullable=False, default=False)  # whether sending succeeded
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (
        Index("idx_maizone_datetime_hour", "datetime_hour"),
        Index("idx_maizone_is_processed", "is_processed"),
    )
|
||
|
||
|
||
class BanUser(Base):
    """Banned-user model: a violation count and reason for a platform user.

    NOTE(review): ``user_id`` is declared with ``index=True`` and also has an
    explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "ban_users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # ``platform`` is indexed below, both alone and in a composite index. MySQL
    # rejects TEXT columns in an index without a prefix length (error 1170,
    # "BLOB/TEXT column used in key specification without a key length"), so a
    # bounded type is required there. get_string_field yields VARCHAR(50) on MySQL
    # (matching UserPermissions.platform) and still yields Text on SQLite, so the
    # SQLite schema is unchanged.
    # NOTE(review): an existing MySQL table keeps its current column type until it
    # is migrated — confirm the migration step handles this.
    platform = Column(get_string_field(50), nullable=False)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    violation_num = Column(Integer, nullable=False, default=0)
    reason = Column(Text, nullable=False)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    __table_args__ = (
        Index("idx_violation_num", "violation_num"),
        Index("idx_banuser_user_id", "user_id"),
        Index("idx_banuser_platform", "platform"),
        Index("idx_banuser_platform_user_id", "platform", "user_id"),
    )
|
||
|
||
|
||
class AntiInjectionStats(Base):
    """Anti-injection system statistics model: message counters, timing totals and error count."""

    __tablename__ = "anti_injection_stats"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Total number of processed messages
    total_messages = Column(Integer, nullable=False, default=0)
    """总处理消息数"""

    # Number of detected injection attempts
    detected_injections = Column(Integer, nullable=False, default=0)
    """检测到的注入攻击数"""

    # Number of blocked messages
    blocked_messages = Column(Integer, nullable=False, default=0)
    """被阻止的消息数"""

    # Number of shielded messages
    shielded_messages = Column(Integer, nullable=False, default=0)
    """被加盾的消息数"""

    # Total processing time
    processing_time_total = Column(Float, nullable=False, default=0.0)
    """总处理时间"""

    # Cumulative total processing time.
    # NOTE(review): described almost identically to processing_time_total above;
    # confirm which of the two callers actually update.
    total_process_time = Column(Float, nullable=False, default=0.0)
    """累计总处理时间"""

    # Processing time of the most recent message
    last_process_time = Column(Float, nullable=False, default=0.0)
    """最近一次处理时间"""

    # Error count
    error_count = Column(Integer, nullable=False, default=0)
    """错误计数"""

    # Start of the statistics window
    start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
    """统计开始时间"""

    # Row creation time
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    """记录创建时间"""

    # Row update time; onupdate refreshes it on every ORM UPDATE
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    """记录更新时间"""

    __table_args__ = (
        Index("idx_anti_injection_stats_created_at", "created_at"),
        Index("idx_anti_injection_stats_updated_at", "updated_at"),
    )
|
||
|
||
|
||
class CacheEntries(Base):
    """Tool cache entry model: one cached tool result per unique ``cache_key``.

    NOTE(review): ``cache_key``, ``expires_at`` and ``tool_name`` are declared with
    ``index=True`` and also have explicit ``Index(...)`` entries below, so each
    gets two indexes. Confirm both are intended.
    """

    __tablename__ = "cache_entries"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Cache key, combining tool name, arguments and code hash
    cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
    """缓存键,包含工具名、参数和代码哈希"""

    # Cached payload as JSON text
    cache_value = Column(Text, nullable=False)
    """缓存的数据,JSON格式"""

    # Expiry timestamp (Float)
    expires_at = Column(Float, nullable=False, index=True)
    """过期时间戳"""

    # Tool name
    tool_name = Column(get_string_field(100), nullable=False, index=True)
    """工具名称"""

    # Creation timestamp; time.time() evaluated per insert
    created_at = Column(Float, nullable=False, default=lambda: time.time())
    """创建时间戳"""

    # Last-access timestamp; time.time() evaluated per insert
    last_accessed = Column(Float, nullable=False, default=lambda: time.time())
    """最后访问时间戳"""

    # Access count
    access_count = Column(Integer, nullable=False, default=0)
    """访问次数"""

    __table_args__ = (
        Index("idx_cache_entries_key", "cache_key"),
        Index("idx_cache_entries_expires_at", "expires_at"),
        Index("idx_cache_entries_tool_name", "tool_name"),
        Index("idx_cache_entries_created_at", "created_at"),
    )
|
||
|
||
|
||
class MonthlyPlan(Base):
    """Monthly plan model: a plan text targeted at a month, with status and usage tracking.

    NOTE(review): ``last_used_date`` is declared with ``index=True`` and also has
    an explicit ``Index(...)`` below, so it gets two indexes.
    """

    __tablename__ = "monthly_plans"

    id = Column(Integer, primary_key=True, autoincrement=True)
    plan_text = Column(Text, nullable=False)
    target_month = Column(String(7), nullable=False, index=True)  # "YYYY-MM"
    status = Column(
        get_string_field(20), nullable=False, default="active", index=True
    )  # one of 'active', 'completed', 'archived'
    usage_count = Column(Integer, nullable=False, default=0)
    last_used_date = Column(String(10), nullable=True, index=True)  # "YYYY-MM-DD" string
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    # is_deleted is kept for compatibility with existing data but is deprecated
    is_deleted = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index("idx_monthlyplan_target_month_status", "target_month", "status"),
        Index("idx_monthlyplan_last_used_date", "last_used_date"),
        Index("idx_monthlyplan_usage_count", "usage_count"),
        # Old index kept for compatibility
        Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
    )
|
||
|
||
|
||
# Database engine and session management.
# Both globals stay None until initialize_database() populates them on first use.
_engine, _SessionLocal = None, None
|
||
|
||
|
||
def get_database_url():
    """Build the SQLAlchemy async connection URL from ``global_config.database``.

    - MySQL: an ``aiomysql`` URL with URL-encoded user and password. A Unix
      socket is used when ``mysql_unix_socket`` is set; otherwise host/port TCP
      is used.
    - Anything else (SQLite): an ``aiosqlite`` URL. A relative ``sqlite_path`` is
      resolved against the directory three levels above this file (the project
      root), and the database's parent directory is created if missing.
    """
    from src.config.config import global_config

    db_cfg = global_config.database

    if db_cfg.database_type != "mysql":
        # SQLite: anchor relative paths at the project root
        sqlite_path = db_cfg.sqlite_path
        if os.path.isabs(sqlite_path):
            db_path = sqlite_path
        else:
            project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_path = os.path.join(project_root, sqlite_path)

        # The directory must exist before SQLite can create the file
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        return f"sqlite+aiosqlite:///{db_path}"

    # MySQL: percent-encode credentials so special characters survive URL parsing
    from urllib.parse import quote_plus

    credentials = f"{quote_plus(db_cfg.mysql_user)}:{quote_plus(db_cfg.mysql_password)}"

    if not db_cfg.mysql_unix_socket:
        # Standard TCP connection
        return (
            f"mysql+aiomysql://{credentials}"
            f"@{db_cfg.mysql_host}:{db_cfg.mysql_port}/{db_cfg.mysql_database}"
            f"?charset={db_cfg.mysql_charset}"
        )

    # Unix-socket connection: no host in the authority part
    socket_param = quote_plus(db_cfg.mysql_unix_socket)
    return (
        f"mysql+aiomysql://{credentials}"
        f"@/{db_cfg.mysql_database}"
        f"?unix_socket={socket_param}&charset={db_cfg.mysql_charset}"
    )
|
||
|
||
|
||
async def initialize_database():
    """Create the async engine and session factory once, then run schema migration.

    Later calls return the cached pair immediately.

    Returns:
        tuple: ``(engine, session_factory)``. The engine comes from
        ``create_async_engine``; the factory is an ``async_sessionmaker`` bound
        to it with ``expire_on_commit=False``.
    """
    global _engine, _SessionLocal

    # Fast path: already initialized. The globals are assigned *before* the
    # migration below runs, so this early return also fires while migration is
    # still in progress.
    # NOTE(review): presumably intentional, so that check_and_migrate_database()
    # can call back into this function (e.g. via get_engine()) without recursing.
    # Side effects: a concurrent task may receive the engine before migration has
    # finished, and a failed migration is not retried on later calls.
    if _engine is not None:
        return _engine, _SessionLocal

    database_url = get_database_url()
    from src.config.config import global_config

    config = global_config.database

    # Engine options shared by every backend
    engine_kwargs: Dict[str, Any] = {
        "echo": False,  # keep SQL logging off in production
        "future": True,
    }

    if config.database_type == "mysql":
        # MySQL (aiomysql): tune the async engine's default connection pool
        engine_kwargs.update(
            {
                "pool_size": config.connection_pool_size,
                "max_overflow": config.connection_pool_size * 2,
                "pool_timeout": config.connection_timeout,
                "pool_recycle": 3600,  # recycle connections after 1 hour
                "pool_pre_ping": True,  # ping before handing out a pooled connection
                # connect_args are forwarded verbatim to aiomysql.connect(). aiomysql
                # does not accept PyMySQL's ``read_timeout``/``write_timeout``
                # keywords; passing them raised TypeError on every connection
                # attempt. Only options aiomysql supports are set here.
                "connect_args": {
                    "autocommit": config.mysql_autocommit,
                    "charset": config.mysql_charset,
                    "connect_timeout": config.connection_timeout,
                },
            }
        )
    else:
        # SQLite (aiosqlite): no pool-size options; connect_args reach sqlite3.connect()
        engine_kwargs.update(
            {
                "connect_args": {
                    "check_same_thread": False,
                    "timeout": 30,  # seconds to wait on a locked database
                },
            }
        )

    _engine = create_async_engine(database_url, **engine_kwargs)
    _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)

    # The migration routine creates missing tables and adds missing columns
    from src.common.database.db_migration import check_and_migrate_database

    await check_and_migrate_database()

    logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
    return _engine, _SessionLocal
|
||
|
||
|
||
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager that yields a fresh ``AsyncSession``.

    The database is initialized lazily on first use. The session is *not*
    committed automatically; callers must ``await session.commit()`` themselves.

    If an exception escapes the ``async with`` body, or occurs while opening the
    session, it is logged, the session (if created) is rolled back, and the
    original exception is re-raised. The session is always closed on exit.

    Raises:
        RuntimeError: if the session factory is unavailable after initialization.
    """
    session: Optional[AsyncSession] = None
    try:
        _, session_factory = await initialize_database()
        if not session_factory:
            raise RuntimeError("Database session not initialized")
        session = session_factory()
        yield session
    except Exception as e:
        logger.error(f"数据库会话错误: {e}")
        if session is not None:
            try:
                await session.rollback()
            except Exception as rollback_error:
                # A failed rollback (e.g. broken connection) must not replace the
                # original error that callers need to see.
                logger.error(f"数据库会话回滚失败: {rollback_error}")
        # Bare raise re-raises the original exception ``e``
        raise
    finally:
        if session is not None:
            await session.close()
|
||
|
||
|
||
async def get_engine():
    """Return the shared async engine, initializing the database on first call."""
    engine_and_factory = await initialize_database()
    return engine_and_factory[0]
|
||
|
||
|
||
class PermissionNodes(Base):
    """Permission node model: one row per unique ``node_name``, owned by a plugin.

    NOTE(review): ``node_name`` and ``plugin_name`` are declared with
    ``index=True`` and also have explicit ``Index(...)`` entries below, so each
    gets two indexes.
    NOTE(review): ``created_at`` defaults to ``datetime.datetime.utcnow`` (naive
    UTC, deprecated since Python 3.12). Most other models use
    ``datetime.datetime.now`` (naive local time), so timestamps across tables are
    not directly comparable. Confirm which is intended.
    """

    __tablename__ = "permission_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    node_name = Column(get_string_field(255), nullable=False, unique=True, index=True)  # permission node name
    description = Column(Text, nullable=False)  # permission description
    plugin_name = Column(get_string_field(100), nullable=False, index=True)  # owning plugin
    default_granted = Column(Boolean, default=False, nullable=False)  # granted by default?
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # creation time

    __table_args__ = (
        Index("idx_permission_plugin", "plugin_name"),
        Index("idx_permission_node", "node_name"),
    )
|
||
|
||
|
||
class UserPermissions(Base):
    """User permission model: grant state of a permission node for a (platform, user_id).

    NOTE(review): ``idx_user_platform_id`` (platform, user_id) is a leftmost
    prefix of ``idx_user_permission`` (platform, user_id, permission_node), so it
    is likely redundant. The single-column ``index=True`` flags add further
    indexes on the same columns.
    NOTE(review): ``granted_at`` defaults to ``datetime.datetime.utcnow`` (naive
    UTC, deprecated since Python 3.12), unlike most models, which use
    ``datetime.datetime.now``.
    """

    __tablename__ = "user_permissions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    platform = Column(get_string_field(50), nullable=False, index=True)  # platform type
    user_id = Column(get_string_field(100), nullable=False, index=True)  # user ID
    permission_node = Column(get_string_field(255), nullable=False, index=True)  # permission node name
    granted = Column(Boolean, default=True, nullable=False)  # whether granted
    granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # grant time
    granted_by = Column(get_string_field(100), nullable=True)  # who granted it

    __table_args__ = (
        Index("idx_user_platform_id", "platform", "user_id"),
        Index("idx_user_permission", "platform", "user_id", "permission_node"),
        Index("idx_permission_granted", "permission_node", "granted"),
    )
|
||
|
||
|
||
class UserRelationships(Base):
    """User relationship model: the bot's relationship data for one unique ``user_id``.

    NOTE(review): ``user_id`` is declared with ``index=True``/``unique=True`` and
    also has an explicit ``Index(...)`` below, so it gets two indexes.
    NOTE(review): ``created_at`` defaults to ``datetime.datetime.utcnow`` (naive
    UTC, deprecated since Python 3.12), unlike most models, which use
    ``datetime.datetime.now``.
    """

    __tablename__ = "user_relationships"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(get_string_field(100), nullable=False, unique=True, index=True)  # user ID
    user_name = Column(get_string_field(100), nullable=True)  # user name
    relationship_text = Column(Text, nullable=True)  # relationship impression text
    relationship_score = Column(Float, nullable=False, default=0.3)  # relationship score (0-1)
    last_updated = Column(Float, nullable=False, default=time.time)  # last update, time.time() per insert
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # creation time

    __table_args__ = (
        Index("idx_user_relationship_id", "user_id"),
        Index("idx_relationship_score", "relationship_score"),
        Index("idx_relationship_updated", "last_updated"),
    )
|