修复代码格式和文件名大小写问题
This commit is contained in:
@@ -352,4 +352,3 @@ class CacheManager:
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
|
||||
@@ -16,30 +16,30 @@ _sql_engine = None
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
|
||||
# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy
|
||||
class DatabaseProxy:
|
||||
"""数据库代理类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._engine = None
|
||||
self._session = None
|
||||
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
"""初始化数据库连接"""
|
||||
return initialize_database_compat()
|
||||
|
||||
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
"""SQLAlchemy事务上下文管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
self.session = get_db_session()
|
||||
return self.session
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
@@ -47,9 +47,11 @@ class SQLAlchemyTransaction:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
|
||||
|
||||
# 创建全局数据库代理实例
|
||||
db = DatabaseProxy()
|
||||
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
@@ -95,10 +97,10 @@ def initialize_sql_database(database_config):
|
||||
database_config: DatabaseConfig对象
|
||||
"""
|
||||
global _sql_engine
|
||||
|
||||
|
||||
try:
|
||||
logger.info("使用SQLAlchemy初始化SQL数据库...")
|
||||
|
||||
|
||||
# 记录数据库配置信息
|
||||
if database_config.database_type == "mysql":
|
||||
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
|
||||
@@ -113,7 +115,7 @@ def initialize_sql_database(database_config):
|
||||
db_path = database_config.sqlite_path
|
||||
logger.info("SQLite数据库连接配置:")
|
||||
logger.info(f" 数据库文件: {db_path}")
|
||||
|
||||
|
||||
# 使用SQLAlchemy初始化
|
||||
success = initialize_database_compat()
|
||||
if success:
|
||||
@@ -121,13 +123,14 @@ def initialize_sql_database(database_config):
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
else:
|
||||
logger.error("SQLAlchemy数据库初始化失败")
|
||||
|
||||
|
||||
return _sql_engine
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化SQL数据库失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
@@ -140,4 +143,3 @@ class DBWrapper:
|
||||
|
||||
# 全局MongoDB数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@
|
||||
from typing import List
|
||||
from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 需要导入全局配置
|
||||
from src.config.config import global_config # 需要导入全局配置
|
||||
|
||||
logger = get_logger("monthly_plan_db")
|
||||
|
||||
|
||||
def add_new_plans(plans: List[str], month: str):
|
||||
"""
|
||||
批量添加新生成的月度计划到数据库,并确保不超过上限。
|
||||
@@ -17,10 +18,11 @@ def add_new_plans(plans: List[str], month: str):
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
# 1. 获取当前有效计划数量(状态为 'active')
|
||||
current_plan_count = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'active'
|
||||
).count()
|
||||
current_plan_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
)
|
||||
|
||||
# 2. 从配置获取上限
|
||||
max_plans = global_config.monthly_plan_system.max_plans_per_month
|
||||
@@ -36,12 +38,11 @@ def add_new_plans(plans: List[str], month: str):
|
||||
plans_to_add = plans[:remaining_slots]
|
||||
|
||||
new_plan_objects = [
|
||||
MonthlyPlan(plan_text=plan, target_month=month, status='active')
|
||||
for plan in plans_to_add
|
||||
MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add
|
||||
]
|
||||
session.add_all(new_plan_objects)
|
||||
session.commit()
|
||||
|
||||
|
||||
logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。")
|
||||
if len(plans) > len(plans_to_add):
|
||||
logger.info(f"由于达到月度计划上限,有 {len(plans) - len(plans_to_add)} 条计划未被添加。")
|
||||
@@ -51,6 +52,7 @@ def add_new_plans(plans: List[str], month: str):
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'active' 的计划。
|
||||
@@ -60,15 +62,17 @@ def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
plans = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'active'
|
||||
).all()
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.all()
|
||||
)
|
||||
return plans
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def mark_plans_completed(plan_ids: List[int]):
|
||||
"""
|
||||
将指定ID的计划标记为已完成。
|
||||
@@ -85,18 +89,19 @@ def mark_plans_completed(plan_ids: List[int]):
|
||||
logger.info("没有需要标记为完成的月度计划。")
|
||||
return
|
||||
|
||||
plan_details = "\n".join([f" {i+1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)])
|
||||
plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)])
|
||||
logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}")
|
||||
|
||||
session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.id.in_(plan_ids)
|
||||
).update({"status": "completed"}, synchronize_session=False)
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"标记月度计划为完成时发生错误: {e}")
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def delete_plans_by_ids(plan_ids: List[int]):
|
||||
"""
|
||||
根据ID列表从数据库中物理删除月度计划。
|
||||
@@ -114,20 +119,19 @@ def delete_plans_by_ids(plan_ids: List[int]):
|
||||
logger.info("没有找到需要删除的月度计划。")
|
||||
return
|
||||
|
||||
plan_details = "\n".join([f" {i+1}. {plan.plan_text}" for i, plan in enumerate(plans_to_delete)])
|
||||
plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_delete)])
|
||||
logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}")
|
||||
|
||||
# 执行删除
|
||||
session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.id.in_(plan_ids)
|
||||
).delete(synchronize_session=False)
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def soft_delete_plans(plan_ids: List[int]):
|
||||
"""
|
||||
将指定ID的计划标记为软删除(兼容旧接口)。
|
||||
@@ -138,6 +142,7 @@ def soft_delete_plans(plan_ids: List[int]):
|
||||
logger.warning("soft_delete_plans 已弃用,请使用 mark_plans_completed")
|
||||
mark_plans_completed(plan_ids)
|
||||
|
||||
|
||||
def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
"""
|
||||
更新计划的使用统计信息。
|
||||
@@ -151,33 +156,32 @@ def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
# 获取完成阈值配置,如果不存在则使用默认值
|
||||
completion_threshold = getattr(global_config.monthly_plan_system, 'completion_threshold', 3)
|
||||
|
||||
completion_threshold = getattr(global_config.monthly_plan_system, "completion_threshold", 3)
|
||||
|
||||
# 批量更新使用次数和最后使用日期
|
||||
session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.id.in_(plan_ids)
|
||||
).update({
|
||||
"usage_count": MonthlyPlan.usage_count + 1,
|
||||
"last_used_date": used_date
|
||||
}, synchronize_session=False)
|
||||
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False
|
||||
)
|
||||
|
||||
# 检查是否有计划达到完成阈值
|
||||
plans_to_complete = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.id.in_(plan_ids),
|
||||
MonthlyPlan.usage_count >= completion_threshold,
|
||||
MonthlyPlan.status == 'active'
|
||||
).all()
|
||||
|
||||
plans_to_complete = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(
|
||||
MonthlyPlan.id.in_(plan_ids),
|
||||
MonthlyPlan.usage_count >= completion_threshold,
|
||||
MonthlyPlan.status == "active",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if plans_to_complete:
|
||||
completed_ids = [plan.id for plan in plans_to_complete]
|
||||
session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.id.in_(completed_ids)
|
||||
).update({
|
||||
"status": "completed"
|
||||
}, synchronize_session=False)
|
||||
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
)
|
||||
|
||||
logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。")
|
||||
|
||||
|
||||
session.commit()
|
||||
logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。")
|
||||
except Exception as e:
|
||||
@@ -185,10 +189,11 @@ def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]:
|
||||
"""
|
||||
智能抽取月度计划用于每日日程生成。
|
||||
|
||||
|
||||
抽取规则:
|
||||
1. 避免短期内重复(avoid_days 天内不重复抽取同一个计划)
|
||||
2. 优先抽取使用次数较少的计划
|
||||
@@ -200,43 +205,39 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
|
||||
:return: MonthlyPlan 对象列表。
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
# 计算避免重复的日期阈值
|
||||
avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
# 查询符合条件的计划
|
||||
query = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'active'
|
||||
)
|
||||
|
||||
query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
|
||||
# 排除最近使用过的计划
|
||||
query = query.filter(
|
||||
(MonthlyPlan.last_used_date.is_(None)) |
|
||||
(MonthlyPlan.last_used_date < avoid_date)
|
||||
)
|
||||
|
||||
query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date))
|
||||
|
||||
# 按使用次数升序排列,优先选择使用次数少的
|
||||
plans = query.order_by(MonthlyPlan.usage_count.asc()).all()
|
||||
|
||||
|
||||
if not plans:
|
||||
logger.info(f"没有找到符合条件的 {month} 月度计划。")
|
||||
return []
|
||||
|
||||
|
||||
# 如果计划数量超过需要的数量,进行随机抽取
|
||||
if len(plans) > max_count:
|
||||
import random
|
||||
|
||||
plans = random.sample(plans, max_count)
|
||||
|
||||
|
||||
logger.info(f"智能抽取了 {len(plans)} 条 {month} 的月度计划用于每日日程生成。")
|
||||
return plans
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能抽取 {month} 的月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def archive_active_plans_for_month(month: str):
|
||||
"""
|
||||
将指定月份所有状态为 'active' 的计划归档为 'archived'。
|
||||
@@ -246,11 +247,12 @@ def archive_active_plans_for_month(month: str):
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
updated_count = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'active'
|
||||
).update({"status": "archived"}, synchronize_session=False)
|
||||
|
||||
updated_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.update({"status": "archived"}, synchronize_session=False)
|
||||
)
|
||||
|
||||
session.commit()
|
||||
logger.info(f"成功将 {updated_count} 条 {month} 的活跃月度计划归档。")
|
||||
return updated_count
|
||||
@@ -259,6 +261,7 @@ def archive_active_plans_for_month(month: str):
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'archived' 的计划。
|
||||
@@ -269,15 +272,17 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
plans = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'archived'
|
||||
).all()
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived")
|
||||
.all()
|
||||
)
|
||||
return plans
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def has_active_plans(month: str) -> bool:
|
||||
"""
|
||||
检查指定月份是否存在任何状态为 'active' 的计划。
|
||||
@@ -287,11 +292,12 @@ def has_active_plans(month: str) -> bool:
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
count = session.query(MonthlyPlan).filter(
|
||||
MonthlyPlan.target_month == month,
|
||||
MonthlyPlan.status == 'active'
|
||||
).count()
|
||||
count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
)
|
||||
return count > 0
|
||||
except Exception as e:
|
||||
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -11,38 +11,51 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy import desc, asc, func, and_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus,
|
||||
CacheEntries
|
||||
Base,
|
||||
get_db_session,
|
||||
Messages,
|
||||
ActionRecords,
|
||||
PersonInfo,
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
Memory,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# 模型映射表,用于通过名称获取模型类
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
'Schedule': Schedule,
|
||||
'MaiZoneScheduleStatus': MaiZoneScheduleStatus,
|
||||
'CacheEntries': CacheEntries,
|
||||
"Messages": Messages,
|
||||
"ActionRecords": ActionRecords,
|
||||
"PersonInfo": PersonInfo,
|
||||
"ChatStreams": ChatStreams,
|
||||
"LLMUsage": LLMUsage,
|
||||
"Emoji": Emoji,
|
||||
"Images": Images,
|
||||
"ImageDescriptions": ImageDescriptions,
|
||||
"OnlineTime": OnlineTime,
|
||||
"Memory": Memory,
|
||||
"Expression": Expression,
|
||||
"ThinkingLog": ThinkingLog,
|
||||
"GraphNodes": GraphNodes,
|
||||
"GraphEdges": GraphEdges,
|
||||
"Schedule": Schedule,
|
||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||
"CacheEntries": CacheEntries,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
@@ -225,10 +238,7 @@ async def db_query(
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Base],
|
||||
data: Dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None
|
||||
model_class: Type[Base], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
@@ -246,9 +256,9 @@ async def db_save(
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if hasattr(model_class, key_field):
|
||||
existing_record = session.query(model_class).filter(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).first()
|
||||
existing_record = (
|
||||
session.query(model_class).filter(getattr(model_class, key_field) == key_value).first()
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
@@ -312,7 +322,7 @@ async def db_get(
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
order_by=order_by_list,
|
||||
single_result=single_result
|
||||
single_result=single_result,
|
||||
)
|
||||
|
||||
|
||||
@@ -347,7 +357,7 @@ async def store_action_info(
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": orjson.dumps(action_data or {}).decode('utf-8'),
|
||||
"action_data": orjson.dumps(action_data or {}).decode("utf-8"),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
@@ -355,24 +365,25 @@ async def store_action_info(
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
})
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
@@ -386,4 +397,3 @@ async def store_action_info(
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_engine, initialize_database
|
||||
)
|
||||
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
@@ -18,23 +16,23 @@ def initialize_sqlalchemy_database() -> bool:
|
||||
"""
|
||||
初始化SQLAlchemy数据库
|
||||
创建所有表结构
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化SQLAlchemy数据库...")
|
||||
|
||||
|
||||
# 初始化数据库引擎和会话
|
||||
engine, session_local = initialize_database()
|
||||
|
||||
|
||||
if engine is None:
|
||||
logger.error("数据库引擎初始化失败")
|
||||
return False
|
||||
|
||||
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
|
||||
return False
|
||||
@@ -46,24 +44,24 @@ def initialize_sqlalchemy_database() -> bool:
|
||||
def create_all_tables() -> bool:
|
||||
"""
|
||||
创建所有数据库表
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 创建是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始创建数据库表...")
|
||||
|
||||
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
logger.error("无法获取数据库引擎")
|
||||
return False
|
||||
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
logger.info("数据库表创建成功")
|
||||
return True
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"创建数据库表失败: {e}")
|
||||
return False
|
||||
@@ -72,11 +70,10 @@ def create_all_tables() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def get_database_info() -> Optional[dict]:
|
||||
"""
|
||||
获取数据库信息
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 数据库信息字典,包含引擎信息等
|
||||
"""
|
||||
@@ -84,17 +81,17 @@ def get_database_info() -> Optional[dict]:
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
return None
|
||||
|
||||
|
||||
info = {
|
||||
'engine_name': engine.name,
|
||||
'driver': engine.driver,
|
||||
'url': str(engine.url).replace(engine.url.password or '', '***'), # 隐藏密码
|
||||
'pool_size': getattr(engine.pool, 'size', None),
|
||||
'max_overflow': getattr(engine.pool, 'max_overflow', None),
|
||||
"engine_name": engine.name,
|
||||
"driver": engine.driver,
|
||||
"url": str(engine.url).replace(engine.url.password or "", "***"), # 隐藏密码
|
||||
"pool_size": getattr(engine.pool, "size", None),
|
||||
"max_overflow": getattr(engine.pool, "max_overflow", None),
|
||||
}
|
||||
|
||||
|
||||
return info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库信息失败: {e}")
|
||||
return None
|
||||
@@ -102,24 +99,25 @@ def get_database_info() -> Optional[dict]:
|
||||
|
||||
_database_initialized = False
|
||||
|
||||
|
||||
def initialize_database_compat() -> bool:
|
||||
"""
|
||||
兼容性数据库初始化函数
|
||||
用于替换原有的Peewee初始化代码
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
global _database_initialized
|
||||
|
||||
|
||||
if _database_initialized:
|
||||
return True
|
||||
|
||||
|
||||
success = initialize_sqlalchemy_database()
|
||||
if success:
|
||||
success = create_all_tables()
|
||||
|
||||
|
||||
if success:
|
||||
_database_initialized = True
|
||||
|
||||
return success
|
||||
|
||||
return success
|
||||
|
||||
@@ -19,6 +19,7 @@ logger = get_logger("sqlalchemy_models")
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# MySQL兼容的字段类型辅助函数
|
||||
def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
@@ -26,6 +27,7 @@ def get_string_field(max_length=255, **kwargs):
|
||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
if global_config.database.database_type == "mysql":
|
||||
return String(max_length, **kwargs)
|
||||
else:
|
||||
@@ -34,7 +36,8 @@ def get_string_field(max_length=255, **kwargs):
|
||||
|
||||
class ChatStreams(Base):
|
||||
"""聊天流模型"""
|
||||
__tablename__ = 'chat_streams'
|
||||
|
||||
__tablename__ = "chat_streams"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||
@@ -50,15 +53,16 @@ class ChatStreams(Base):
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_chatstreams_stream_id', 'stream_id'),
|
||||
Index('idx_chatstreams_user_id', 'user_id'),
|
||||
Index('idx_chatstreams_group_id', 'group_id'),
|
||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||
Index("idx_chatstreams_user_id", "user_id"),
|
||||
Index("idx_chatstreams_group_id", "group_id"),
|
||||
)
|
||||
|
||||
|
||||
class LLMUsage(Base):
|
||||
"""LLM使用记录模型"""
|
||||
__tablename__ = 'llm_usage'
|
||||
|
||||
__tablename__ = "llm_usage"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
||||
@@ -76,19 +80,20 @@ class LLMUsage(Base):
|
||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_llmusage_model_name', 'model_name'),
|
||||
Index('idx_llmusage_model_assign_name', 'model_assign_name'),
|
||||
Index('idx_llmusage_model_api_provider', 'model_api_provider'),
|
||||
Index('idx_llmusage_time_cost', 'time_cost'),
|
||||
Index('idx_llmusage_user_id', 'user_id'),
|
||||
Index('idx_llmusage_request_type', 'request_type'),
|
||||
Index('idx_llmusage_timestamp', 'timestamp'),
|
||||
Index("idx_llmusage_model_name", "model_name"),
|
||||
Index("idx_llmusage_model_assign_name", "model_assign_name"),
|
||||
Index("idx_llmusage_model_api_provider", "model_api_provider"),
|
||||
Index("idx_llmusage_time_cost", "time_cost"),
|
||||
Index("idx_llmusage_user_id", "user_id"),
|
||||
Index("idx_llmusage_request_type", "request_type"),
|
||||
Index("idx_llmusage_timestamp", "timestamp"),
|
||||
)
|
||||
|
||||
|
||||
class Emoji(Base):
|
||||
"""表情包模型"""
|
||||
__tablename__ = 'emoji'
|
||||
|
||||
__tablename__ = "emoji"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||
@@ -105,14 +110,15 @@ class Emoji(Base):
|
||||
last_used_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_emoji_full_path', 'full_path'),
|
||||
Index('idx_emoji_hash', 'emoji_hash'),
|
||||
Index("idx_emoji_full_path", "full_path"),
|
||||
Index("idx_emoji_hash", "emoji_hash"),
|
||||
)
|
||||
|
||||
|
||||
class Messages(Base):
|
||||
"""消息模型"""
|
||||
__tablename__ = 'messages'
|
||||
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
@@ -153,16 +159,17 @@ class Messages(Base):
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_messages_message_id', 'message_id'),
|
||||
Index('idx_messages_chat_id', 'chat_id'),
|
||||
Index('idx_messages_time', 'time'),
|
||||
Index('idx_messages_user_id', 'user_id'),
|
||||
Index("idx_messages_message_id", "message_id"),
|
||||
Index("idx_messages_chat_id", "chat_id"),
|
||||
Index("idx_messages_time", "time"),
|
||||
Index("idx_messages_user_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
class ActionRecords(Base):
|
||||
"""动作记录模型"""
|
||||
__tablename__ = 'action_records'
|
||||
|
||||
__tablename__ = "action_records"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
@@ -177,15 +184,16 @@ class ActionRecords(Base):
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_actionrecords_action_id', 'action_id'),
|
||||
Index('idx_actionrecords_chat_id', 'chat_id'),
|
||||
Index('idx_actionrecords_time', 'time'),
|
||||
Index("idx_actionrecords_action_id", "action_id"),
|
||||
Index("idx_actionrecords_chat_id", "chat_id"),
|
||||
Index("idx_actionrecords_time", "time"),
|
||||
)
|
||||
|
||||
|
||||
class Images(Base):
|
||||
"""图像信息模型"""
|
||||
__tablename__ = 'images'
|
||||
|
||||
__tablename__ = "images"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
image_id = Column(Text, nullable=False, default="")
|
||||
@@ -198,14 +206,15 @@ class Images(Base):
|
||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_images_emoji_hash', 'emoji_hash'),
|
||||
Index('idx_images_path', 'path'),
|
||||
Index("idx_images_emoji_hash", "emoji_hash"),
|
||||
Index("idx_images_path", "path"),
|
||||
)
|
||||
|
||||
|
||||
class ImageDescriptions(Base):
|
||||
"""图像描述信息模型"""
|
||||
__tablename__ = 'image_descriptions'
|
||||
|
||||
__tablename__ = "image_descriptions"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(Text, nullable=False)
|
||||
@@ -213,14 +222,13 @@ class ImageDescriptions(Base):
|
||||
description = Column(Text, nullable=False)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_imagedesc_hash', 'image_description_hash'),
|
||||
)
|
||||
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
||||
|
||||
|
||||
class Videos(Base):
|
||||
"""视频信息模型"""
|
||||
__tablename__ = 'videos'
|
||||
|
||||
__tablename__ = "videos"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
video_id = Column(Text, nullable=False, default="")
|
||||
@@ -229,7 +237,7 @@ class Videos(Base):
|
||||
count = Column(Integer, nullable=False, default=1)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
# 视频特有属性
|
||||
duration = Column(Float, nullable=True) # 视频时长(秒)
|
||||
frame_count = Column(Integer, nullable=True) # 总帧数
|
||||
@@ -238,14 +246,15 @@ class Videos(Base):
|
||||
file_size = Column(Integer, nullable=True) # 文件大小(字节)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_videos_video_hash', 'video_hash'),
|
||||
Index('idx_videos_timestamp', 'timestamp'),
|
||||
Index("idx_videos_video_hash", "video_hash"),
|
||||
Index("idx_videos_timestamp", "timestamp"),
|
||||
)
|
||||
|
||||
|
||||
class OnlineTime(Base):
|
||||
"""在线时长记录模型"""
|
||||
__tablename__ = 'online_time'
|
||||
|
||||
__tablename__ = "online_time"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
||||
@@ -253,14 +262,13 @@ class OnlineTime(Base):
|
||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
|
||||
)
|
||||
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
||||
|
||||
|
||||
class PersonInfo(Base):
|
||||
"""人物信息模型"""
|
||||
__tablename__ = 'person_info'
|
||||
|
||||
__tablename__ = "person_info"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||
@@ -280,14 +288,15 @@ class PersonInfo(Base):
|
||||
attitude = Column(Integer, nullable=True, default=50)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_personinfo_person_id', 'person_id'),
|
||||
Index('idx_personinfo_user_id', 'user_id'),
|
||||
Index("idx_personinfo_person_id", "person_id"),
|
||||
Index("idx_personinfo_user_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""记忆模型"""
|
||||
__tablename__ = 'memory'
|
||||
|
||||
__tablename__ = "memory"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
@@ -297,14 +306,13 @@ class Memory(Base):
|
||||
create_time = Column(Float, nullable=True)
|
||||
last_view_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_memory_memory_id', 'memory_id'),
|
||||
)
|
||||
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
||||
|
||||
|
||||
class Expression(Base):
|
||||
"""表达风格模型"""
|
||||
__tablename__ = 'expression'
|
||||
|
||||
__tablename__ = "expression"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
situation: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
@@ -315,14 +323,13 @@ class Expression(Base):
|
||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_expression_chat_id', 'chat_id'),
|
||||
)
|
||||
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
||||
|
||||
|
||||
class ThinkingLog(Base):
|
||||
"""思考日志模型"""
|
||||
__tablename__ = 'thinking_logs'
|
||||
|
||||
__tablename__ = "thinking_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
@@ -338,14 +345,13 @@ class ThinkingLog(Base):
|
||||
reasoning_data_json = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_thinkinglog_chat_id', 'chat_id'),
|
||||
)
|
||||
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
||||
|
||||
|
||||
class GraphNodes(Base):
|
||||
"""记忆图节点模型"""
|
||||
__tablename__ = 'graph_nodes'
|
||||
|
||||
__tablename__ = "graph_nodes"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||
@@ -354,14 +360,13 @@ class GraphNodes(Base):
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphnodes_concept', 'concept'),
|
||||
)
|
||||
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
||||
|
||||
|
||||
class GraphEdges(Base):
|
||||
"""记忆图边模型"""
|
||||
__tablename__ = 'graph_edges'
|
||||
|
||||
__tablename__ = "graph_edges"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
source = Column(get_string_field(255), nullable=False, index=True)
|
||||
@@ -372,14 +377,15 @@ class GraphEdges(Base):
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphedges_source', 'source'),
|
||||
Index('idx_graphedges_target', 'target'),
|
||||
Index("idx_graphedges_source", "source"),
|
||||
Index("idx_graphedges_target", "target"),
|
||||
)
|
||||
|
||||
|
||||
class Schedule(Base):
|
||||
"""日程模型"""
|
||||
__tablename__ = 'schedule'
|
||||
|
||||
__tablename__ = "schedule"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式
|
||||
@@ -387,17 +393,18 @@ class Schedule(Base):
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_schedule_date', 'date'),
|
||||
)
|
||||
__table_args__ = (Index("idx_schedule_date", "date"),)
|
||||
|
||||
|
||||
class MaiZoneScheduleStatus(Base):
|
||||
class MaiZoneScheduleStatus(Base):
|
||||
"""麦麦空间日程处理状态模型"""
|
||||
__tablename__ = 'maizone_schedule_status'
|
||||
|
||||
__tablename__ = "maizone_schedule_status"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
datetime_hour = Column(get_string_field(13), nullable=False, unique=True, index=True) # YYYY-MM-DD HH格式,精确到小时
|
||||
datetime_hour = Column(
|
||||
get_string_field(13), nullable=False, unique=True, index=True
|
||||
) # YYYY-MM-DD HH格式,精确到小时
|
||||
activity = Column(Text, nullable=False) # 该小时的活动内容
|
||||
is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理
|
||||
processed_at = Column(DateTime, nullable=True) # 处理时间
|
||||
@@ -407,14 +414,15 @@ class MaiZoneScheduleStatus(Base):
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_maizone_datetime_hour', 'datetime_hour'),
|
||||
Index('idx_maizone_is_processed', 'is_processed'),
|
||||
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
||||
Index("idx_maizone_is_processed", "is_processed"),
|
||||
)
|
||||
|
||||
|
||||
class BanUser(Base):
|
||||
"""被禁用用户模型"""
|
||||
__tablename__ = 'ban_users'
|
||||
|
||||
__tablename__ = "ban_users"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
platform = Column(Text, nullable=False)
|
||||
@@ -424,113 +432,120 @@ class BanUser(Base):
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_violation_num', 'violation_num'),
|
||||
Index('idx_banuser_user_id', 'user_id'),
|
||||
Index('idx_banuser_platform', 'platform'),
|
||||
Index('idx_banuser_platform_user_id', 'platform', 'user_id'),
|
||||
Index("idx_violation_num", "violation_num"),
|
||||
Index("idx_banuser_user_id", "user_id"),
|
||||
Index("idx_banuser_platform", "platform"),
|
||||
Index("idx_banuser_platform_user_id", "platform", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
class AntiInjectionStats(Base):
|
||||
"""反注入系统统计模型"""
|
||||
__tablename__ = 'anti_injection_stats'
|
||||
|
||||
__tablename__ = "anti_injection_stats"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
total_messages = Column(Integer, nullable=False, default=0)
|
||||
"""总处理消息数"""
|
||||
|
||||
|
||||
detected_injections = Column(Integer, nullable=False, default=0)
|
||||
"""检测到的注入攻击数"""
|
||||
|
||||
|
||||
blocked_messages = Column(Integer, nullable=False, default=0)
|
||||
"""被阻止的消息数"""
|
||||
|
||||
|
||||
shielded_messages = Column(Integer, nullable=False, default=0)
|
||||
"""被加盾的消息数"""
|
||||
|
||||
|
||||
processing_time_total = Column(Float, nullable=False, default=0.0)
|
||||
"""总处理时间"""
|
||||
|
||||
|
||||
total_process_time = Column(Float, nullable=False, default=0.0)
|
||||
"""累计总处理时间"""
|
||||
|
||||
|
||||
last_process_time = Column(Float, nullable=False, default=0.0)
|
||||
"""最近一次处理时间"""
|
||||
|
||||
|
||||
error_count = Column(Integer, nullable=False, default=0)
|
||||
"""错误计数"""
|
||||
|
||||
|
||||
start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
"""统计开始时间"""
|
||||
|
||||
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
"""记录创建时间"""
|
||||
|
||||
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
"""记录更新时间"""
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_anti_injection_stats_created_at', 'created_at'),
|
||||
Index('idx_anti_injection_stats_updated_at', 'updated_at'),
|
||||
Index("idx_anti_injection_stats_created_at", "created_at"),
|
||||
Index("idx_anti_injection_stats_updated_at", "updated_at"),
|
||||
)
|
||||
|
||||
|
||||
class CacheEntries(Base):
|
||||
"""工具缓存条目模型"""
|
||||
__tablename__ = 'cache_entries'
|
||||
|
||||
__tablename__ = "cache_entries"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||
"""缓存键,包含工具名、参数和代码哈希"""
|
||||
|
||||
|
||||
cache_value = Column(Text, nullable=False)
|
||||
"""缓存的数据,JSON格式"""
|
||||
|
||||
|
||||
expires_at = Column(Float, nullable=False, index=True)
|
||||
"""过期时间戳"""
|
||||
|
||||
|
||||
tool_name = Column(get_string_field(100), nullable=False, index=True)
|
||||
"""工具名称"""
|
||||
|
||||
|
||||
created_at = Column(Float, nullable=False, default=lambda: time.time())
|
||||
"""创建时间戳"""
|
||||
|
||||
|
||||
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
|
||||
"""最后访问时间戳"""
|
||||
|
||||
|
||||
access_count = Column(Integer, nullable=False, default=0)
|
||||
"""访问次数"""
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_cache_entries_key', 'cache_key'),
|
||||
Index('idx_cache_entries_expires_at', 'expires_at'),
|
||||
Index('idx_cache_entries_tool_name', 'tool_name'),
|
||||
Index('idx_cache_entries_created_at', 'created_at'),
|
||||
Index("idx_cache_entries_key", "cache_key"),
|
||||
Index("idx_cache_entries_expires_at", "expires_at"),
|
||||
Index("idx_cache_entries_tool_name", "tool_name"),
|
||||
Index("idx_cache_entries_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class MonthlyPlan(Base):
|
||||
"""月度计划模型"""
|
||||
__tablename__ = 'monthly_plans'
|
||||
|
||||
__tablename__ = "monthly_plans"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
plan_text = Column(Text, nullable=False)
|
||||
target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM"
|
||||
status = Column(get_string_field(20), nullable=False, default='active', index=True) # 'active', 'completed', 'archived'
|
||||
status = Column(
|
||||
get_string_field(20), nullable=False, default="active", index=True
|
||||
) # 'active', 'completed', 'archived'
|
||||
usage_count = Column(Integer, nullable=False, default=0)
|
||||
last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
|
||||
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
||||
is_deleted = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_monthlyplan_target_month_status', 'target_month', 'status'),
|
||||
Index('idx_monthlyplan_last_used_date', 'last_used_date'),
|
||||
Index('idx_monthlyplan_usage_count', 'usage_count'),
|
||||
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
||||
Index("idx_monthlyplan_last_used_date", "last_used_date"),
|
||||
Index("idx_monthlyplan_usage_count", "usage_count"),
|
||||
# 保留旧索引以兼容
|
||||
Index('idx_monthlyplan_target_month_is_deleted', 'target_month', 'is_deleted'),
|
||||
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
@@ -539,14 +554,16 @@ _SessionLocal = None
|
||||
def get_database_url():
|
||||
"""获取数据库连接URL"""
|
||||
from src.config.config import global_config
|
||||
|
||||
config = global_config.database
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# 对用户名和密码进行URL编码,处理特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
|
||||
# 检查是否配置了Unix socket连接
|
||||
if config.mysql_unix_socket:
|
||||
# 使用Unix socket连接
|
||||
@@ -586,51 +603,57 @@ def initialize_database():
|
||||
|
||||
database_url = get_database_url()
|
||||
from src.config.config import global_config
|
||||
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs: Dict[str, Any] = {
|
||||
'echo': False, # 生产环境关闭SQL日志
|
||||
'future': True,
|
||||
"echo": False, # 生产环境关闭SQL日志
|
||||
"future": True,
|
||||
}
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# MySQL连接池配置
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': config.connection_pool_size,
|
||||
'max_overflow': config.connection_pool_size * 2,
|
||||
'pool_timeout': config.connection_timeout,
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'autocommit': config.mysql_autocommit,
|
||||
'charset': config.mysql_charset,
|
||||
'connect_timeout': config.connection_timeout,
|
||||
'read_timeout': 30,
|
||||
'write_timeout': 30,
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"poolclass": QueuePool,
|
||||
"pool_size": config.connection_pool_size,
|
||||
"max_overflow": config.connection_pool_size * 2,
|
||||
"pool_timeout": config.connection_timeout,
|
||||
"pool_recycle": 3600, # 1小时回收连接
|
||||
"pool_pre_ping": True, # 连接前ping检查
|
||||
"connect_args": {
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
"read_timeout": 30,
|
||||
"write_timeout": 30,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
else:
|
||||
# SQLite配置 - 添加连接池设置以避免连接耗尽
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': 20, # 增加池大小
|
||||
'max_overflow': 30, # 增加溢出连接数
|
||||
'pool_timeout': 60, # 增加超时时间
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'check_same_thread': False,
|
||||
'timeout': 30,
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"poolclass": QueuePool,
|
||||
"pool_size": 20, # 增加池大小
|
||||
"max_overflow": 30, # 增加溢出连接数
|
||||
"pool_timeout": 60, # 增加超时时间
|
||||
"pool_recycle": 3600, # 1小时回收连接
|
||||
"pool_pre_ping": True, # 连接前ping检查
|
||||
"connect_args": {
|
||||
"check_same_thread": False,
|
||||
"timeout": 30,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
_engine = create_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
|
||||
# 调用新的迁移函数,它会处理表的创建和列的添加
|
||||
from src.common.database.db_migration import check_and_migrate_database
|
||||
|
||||
check_and_migrate_database()
|
||||
|
||||
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
|
||||
@@ -647,7 +670,7 @@ def get_db_session() -> Iterator[Session]:
|
||||
raise RuntimeError("Database session not initialized")
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
#session.commit()
|
||||
# session.commit()
|
||||
except Exception:
|
||||
if session:
|
||||
session.rollback()
|
||||
@@ -655,7 +678,6 @@ def get_db_session() -> Iterator[Session]:
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
|
||||
def get_engine():
|
||||
@@ -666,7 +688,8 @@ def get_engine():
|
||||
|
||||
class PermissionNodes(Base):
|
||||
"""权限节点模型"""
|
||||
__tablename__ = 'permission_nodes'
|
||||
|
||||
__tablename__ = "permission_nodes"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
|
||||
@@ -674,16 +697,17 @@ class PermissionNodes(Base):
|
||||
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
|
||||
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_permission_plugin', 'plugin_name'),
|
||||
Index('idx_permission_node', 'node_name'),
|
||||
Index("idx_permission_plugin", "plugin_name"),
|
||||
Index("idx_permission_node", "node_name"),
|
||||
)
|
||||
|
||||
|
||||
class UserPermissions(Base):
|
||||
"""用户权限模型"""
|
||||
__tablename__ = 'user_permissions'
|
||||
|
||||
__tablename__ = "user_permissions"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
|
||||
@@ -692,9 +716,9 @@ class UserPermissions(Base):
|
||||
granted = Column(Boolean, default=True, nullable=False) # 是否授权
|
||||
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
|
||||
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
|
||||
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_user_platform_id', 'platform', 'user_id'),
|
||||
Index('idx_user_permission', 'platform', 'user_id', 'permission_node'),
|
||||
Index('idx_permission_granted', 'permission_node', 'granted'),
|
||||
Index("idx_user_platform_id", "platform", "user_id"),
|
||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||
Index("idx_permission_granted", "permission_node", "granted"),
|
||||
)
|
||||
|
||||
@@ -194,8 +194,20 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
"log_level": "INFO", # 全局日志级别(向下兼容)
|
||||
"console_log_level": "INFO", # 控制台日志级别
|
||||
"file_log_level": "DEBUG", # 文件日志级别
|
||||
"suppress_libraries": ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"],
|
||||
"library_log_levels": { "aiohttp": "WARNING"},
|
||||
"suppress_libraries": [
|
||||
"faiss",
|
||||
"httpx",
|
||||
"urllib3",
|
||||
"asyncio",
|
||||
"websockets",
|
||||
"httpcore",
|
||||
"requests",
|
||||
"peewee",
|
||||
"openai",
|
||||
"uvicorn",
|
||||
"jieba",
|
||||
],
|
||||
"library_log_levels": {"aiohttp": "WARNING"},
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -363,7 +375,7 @@ MODULE_COLORS = {
|
||||
"base_command": "\033[38;5;208m", # 橙色
|
||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||
"stream_api": "\033[38;5;220m", # 黄色
|
||||
"plugin_hot_reload": "\033[38;5;226m", #品红色
|
||||
"plugin_hot_reload": "\033[38;5;226m", # 品红色
|
||||
"config_api": "\033[38;5;226m", # 亮黄色
|
||||
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"action_apis": "\033[38;5;118m", # 绿色
|
||||
@@ -406,14 +418,12 @@ MODULE_COLORS = {
|
||||
"model_utils": "\033[38;5;164m", # 紫红色
|
||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||
"sqlalchemy_init": "\033[38;5;105m", #
|
||||
"sqlalchemy_init": "\033[38;5;105m", #
|
||||
"sqlalchemy_models": "\033[38;5;105m",
|
||||
"sqlalchemy_database_api": "\033[38;5;105m",
|
||||
|
||||
#s4u
|
||||
# s4u
|
||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||
"S4U_chat": "\033[92m", # 亮绿色
|
||||
|
||||
# API相关扩展
|
||||
"chat_api": "\033[38;5;34m", # 深绿色
|
||||
"emoji_api": "\033[38;5;40m", # 亮绿色
|
||||
@@ -422,20 +432,17 @@ MODULE_COLORS = {
|
||||
"tool_api": "\033[38;5;76m", # 绿色
|
||||
"OpenAI客户端": "\033[38;5;81m",
|
||||
"Gemini客户端": "\033[38;5;81m",
|
||||
|
||||
# 插件系统扩展
|
||||
"plugin_base": "\033[38;5;196m", # 红色
|
||||
"base_event_handler": "\033[38;5;203m", # 粉红色
|
||||
"events_manager": "\033[38;5;209m", # 橙红色
|
||||
"global_announcement_manager": "\033[38;5;215m", # 浅橙色
|
||||
|
||||
# 工具和依赖管理
|
||||
"dependency_config": "\033[38;5;24m", # 深蓝色
|
||||
"dependency_manager": "\033[38;5;30m", # 深青色
|
||||
"manifest_utils": "\033[38;5;39m", # 蓝色
|
||||
"schedule_manager": "\033[38;5;27m", # 深蓝色
|
||||
"monthly_plan_manager": "\033[38;5;171m",
|
||||
|
||||
# 聊天和多媒体扩展
|
||||
"chat_voice": "\033[38;5;87m", # 浅青色
|
||||
"typo_gen": "\033[38;5;123m", # 天蓝色
|
||||
@@ -444,14 +451,12 @@ MODULE_COLORS = {
|
||||
"relationship_builder_manager": "\033[38;5;176m", # 浅紫色
|
||||
"expression_selector": "\033[38;5;176m",
|
||||
"chat_message_builder": "\033[38;5;176m",
|
||||
|
||||
# MaiZone QQ空间相关
|
||||
"MaiZone": "\033[38;5;98m", # 紫色
|
||||
"MaiZone-Monitor": "\033[38;5;104m", # 深紫色
|
||||
"MaiZone.ConfigLoader": "\033[38;5;110m", # 蓝紫色
|
||||
"MaiZone-Scheduler": "\033[38;5;134m", # 紫红色
|
||||
"MaiZone-Utils": "\033[38;5;140m", # 浅紫色
|
||||
|
||||
# MaiZone Refactored
|
||||
"MaiZone.HistoryUtils": "\033[38;5;140m",
|
||||
"MaiZone.SchedulerService": "\033[38;5;134m",
|
||||
@@ -464,13 +469,11 @@ MODULE_COLORS = {
|
||||
"MaiZone.SendFeedCommand": "\033[38;5;134m",
|
||||
"MaiZone.SendFeedAction": "\033[38;5;134m",
|
||||
"MaiZone.ReadFeedAction": "\033[38;5;134m",
|
||||
|
||||
# 网络工具
|
||||
"web_surfing_tool": "\033[38;5;130m", # 棕色
|
||||
"tts": "\033[38;5;136m", # 浅棕色
|
||||
"poke_plugin": "\033[38;5;136m",
|
||||
"set_emoji_like_plugin": "\033[38;5;136m",
|
||||
|
||||
# mais4u系统扩展
|
||||
"s4u_config": "\033[38;5;18m", # 深蓝色
|
||||
"action": "\033[38;5;52m", # 深红色(mais4u的action)
|
||||
@@ -481,7 +484,6 @@ MODULE_COLORS = {
|
||||
"watching": "\033[38;5;131m", # 深橙色
|
||||
"offline_llm": "\033[38;5;236m", # 深灰色
|
||||
"s4u_stream_generator": "\033[38;5;60m", # 深紫色
|
||||
|
||||
# 其他工具
|
||||
"消息压缩工具": "\033[38;5;244m", # 灰色
|
||||
"lpmm_get_knowledge_tool": "\033[38;5;102m", # 绿色
|
||||
@@ -545,42 +547,36 @@ MODULE_ALIASES = {
|
||||
"replyer": "言语",
|
||||
"config": "配置",
|
||||
"main": "主程序",
|
||||
|
||||
# API相关扩展
|
||||
"chat_api": "聊天接口",
|
||||
"emoji_api": "表情接口",
|
||||
"generator_api": "生成接口",
|
||||
"person_api": "人物接口",
|
||||
"tool_api": "工具接口",
|
||||
|
||||
# 插件系统扩展
|
||||
"plugin_base": "插件基类",
|
||||
"base_event_handler": "事件处理",
|
||||
"events_manager": "事件管理",
|
||||
"global_announcement_manager": "全局通知",
|
||||
"event_manager"
|
||||
|
||||
# 工具和依赖管理
|
||||
"dependency_config": "依赖配置",
|
||||
"dependency_manager": "依赖管理",
|
||||
"manifest_utils": "清单工具",
|
||||
"schedule_manager": "计划管理",
|
||||
"monthly_plan_manager": "月度计划",
|
||||
|
||||
# 聊天和多媒体扩展
|
||||
"chat_voice": "语音处理",
|
||||
"typo_gen": "错字生成",
|
||||
"src.chat.utils.utils_video": "视频分析",
|
||||
"ReplyerManager": "回复管理",
|
||||
"relationship_builder_manager": "关系管理",
|
||||
|
||||
# MaiZone QQ空间相关
|
||||
"MaiZone": "Mai空间",
|
||||
"MaiZone-Monitor": "Mai空间监控",
|
||||
"MaiZone.ConfigLoader": "Mai空间配置",
|
||||
"MaiZone-Scheduler": "Mai空间调度",
|
||||
"MaiZone-Utils": "Mai空间工具",
|
||||
|
||||
# MaiZone Refactored
|
||||
"MaiZone.HistoryUtils": "Mai空间历史",
|
||||
"MaiZone.SchedulerService": "Mai空间调度",
|
||||
@@ -593,12 +589,9 @@ MODULE_ALIASES = {
|
||||
"MaiZone.SendFeedCommand": "Mai空间发说说",
|
||||
"MaiZone.SendFeedAction": "Mai空间发说说",
|
||||
"MaiZone.ReadFeedAction": "Mai空间读说说",
|
||||
|
||||
# 网络工具
|
||||
"web_surfing_tool": "网络搜索",
|
||||
"tts": "语音合成",
|
||||
|
||||
|
||||
# mais4u系统扩展
|
||||
"s4u_config": "直播配置",
|
||||
"action": "直播动作",
|
||||
@@ -609,7 +602,6 @@ MODULE_ALIASES = {
|
||||
"watching": "观看状态",
|
||||
"offline_llm": "离线模型",
|
||||
"s4u_stream_generator": "直播生成",
|
||||
|
||||
# 其他工具
|
||||
"消息压缩工具": "消息压缩",
|
||||
"lpmm_get_knowledge_tool": "知识获取",
|
||||
@@ -640,7 +632,7 @@ MODULE_ALIASES = {
|
||||
"db_migration": "数据库迁移",
|
||||
"小彩蛋": "小彩蛋",
|
||||
"AioHTTP-Gemini客户端": "AioHTTP-Gemini客户端",
|
||||
"event_manager" : "事件管理器"
|
||||
"event_manager": "事件管理器",
|
||||
}
|
||||
|
||||
RESET_COLOR = "\033[0m"
|
||||
@@ -735,7 +727,7 @@ class ModuleColoredConsoleRenderer:
|
||||
if logger_name:
|
||||
# 获取别名,如果没有别名则使用原名称
|
||||
display_name = MODULE_ALIASES.get(logger_name, logger_name)
|
||||
|
||||
|
||||
if self._colors and self._enable_module_colors:
|
||||
if module_color:
|
||||
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
|
||||
|
||||
@@ -13,9 +13,11 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
@@ -193,9 +195,9 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
from .base import VectorDBBase
|
||||
from .chromadb_impl import ChromaDBImpl
|
||||
|
||||
|
||||
def get_vector_db_service() -> VectorDBBase:
|
||||
"""
|
||||
工厂函数,初始化并返回向量数据库服务实例。
|
||||
|
||||
|
||||
目前硬编码为 ChromaDB,未来可以从配置中读取。
|
||||
"""
|
||||
# TODO: 从全局配置中读取数据库类型和路径
|
||||
db_path = "data/chroma_db"
|
||||
|
||||
|
||||
# ChromaDBImpl 是一个单例,所以这里每次调用都会返回同一个实例
|
||||
return ChromaDBImpl(path=db_path)
|
||||
|
||||
|
||||
# 全局向量数据库服务实例
|
||||
vector_db_service: VectorDBBase = get_vector_db_service()
|
||||
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class VectorDBBase(ABC):
|
||||
"""
|
||||
向量数据库的抽象基类 (ABC),定义了所有向量数据库实现必须遵循的接口。
|
||||
@@ -133,7 +134,7 @@ class VectorDBBase(ABC):
|
||||
int: 条目总数。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, name: str) -> None:
|
||||
"""
|
||||
@@ -142,4 +143,4 @@ class VectorDBBase(ABC):
|
||||
Args:
|
||||
name (str): 要删除的集合的名称。
|
||||
"""
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -9,11 +9,13 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
|
||||
class ChromaDBImpl(VectorDBBase):
|
||||
"""
|
||||
ChromaDB 的具体实现,遵循 VectorDBBase 接口。
|
||||
采用单例模式,确保全局只有一个 ChromaDB 客户端实例。
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
@@ -29,13 +31,12 @@ class ChromaDBImpl(VectorDBBase):
|
||||
初始化 ChromaDB 客户端。
|
||||
由于是单例,这个初始化只会执行一次。
|
||||
"""
|
||||
if not hasattr(self, '_initialized'):
|
||||
if not hasattr(self, "_initialized"):
|
||||
with self._lock:
|
||||
if not hasattr(self, '_initialized'):
|
||||
if not hasattr(self, "_initialized"):
|
||||
try:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
path=path, settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
@@ -48,10 +49,10 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def get_or_create_collection(self, name: str, **kwargs: Any) -> Any:
|
||||
if not self.client:
|
||||
raise ConnectionError("ChromaDB 客户端未初始化")
|
||||
|
||||
|
||||
if name in self._collections:
|
||||
return self._collections[name]
|
||||
|
||||
|
||||
try:
|
||||
collection = self.client.get_or_create_collection(name=name, **kwargs)
|
||||
self._collections[name] = collection
|
||||
@@ -151,15 +152,15 @@ class ChromaDBImpl(VectorDBBase):
|
||||
except Exception as e:
|
||||
logger.error(f"获取集合 '{collection_name}' 计数失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def delete_collection(self, name: str) -> None:
|
||||
if not self.client:
|
||||
raise ConnectionError("ChromaDB 客户端未初始化")
|
||||
|
||||
|
||||
try:
|
||||
self.client.delete_collection(name=name)
|
||||
if name in self._collections:
|
||||
del self._collections[name]
|
||||
logger.info(f"集合 '{name}' 已被删除")
|
||||
except Exception as e:
|
||||
logger.error(f"删除集合 '{name}' 失败: {e}")
|
||||
logger.error(f"删除集合 '{name}' 失败: {e}")
|
||||
|
||||
Reference in New Issue
Block a user