数据库重构

This commit is contained in:
雅诺狐
2025-08-16 23:43:45 +08:00
parent 0f0619762b
commit d46d689c43
21 changed files with 834 additions and 1007 deletions

View File

@@ -11,10 +11,9 @@ from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.sqlalchemy_models import PersonInfo
from src.common.database.sqlalchemy_database_api import get_session
from src.common.database.sqlalchemy_database_api import get_db_session
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
session = get_session()
"""
PersonInfoManager 类方法功能摘要:
@@ -52,56 +51,37 @@ person_info_default = {
"attitude": 50,
}
# 统一的会话管理函数
def with_session(func):
"""装饰器为函数自动注入session参数"""
if asyncio.iscoroutinefunction(func):
async def async_wrapper(*args, **kwargs):
return await func(session, *args, **kwargs)
return async_wrapper
else:
def sync_wrapper(*args, **kwargs):
return func(session, *args, **kwargs)
return sync_wrapper
# 全局会话获取函数用于替换所有裸露的session使用
def _get_session():
"""获取数据库会话的统一函数"""
return get_session()
class PersonInfoManager:
def __init__(self):
"""初始化PersonInfoManager"""
from src.common.database.sqlalchemy_models import PersonInfo
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数仅对SQLite有效
if hasattr(db, "execute_sql"):
# 检查数据库类型只对SQLite执行PRAGMA语句
if global_config.database.database_type == "sqlite":
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# try:
# with get_db_session() as session:
# db.connect(reuse_if_open=True)
# # 设置连接池参数仅对SQLite有效
# if hasattr(db, "execute_sql"):
# # 检查数据库类型只对SQLite执行PRAGMA语句
# if global_config.database.database_type == "sqlite":
# # 设置SQLite优化参数
# db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
# db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
# db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
# db.create_tables([PersonInfo], safe=True)
# except Exception as e:
# logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
# # 初始化时读取所有person_name
try:
from src.common.database.sqlalchemy_models import PersonInfo
# 在这里获取会话
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_not(None)
)).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
# 在这里获取会话
with get_db_session() as session:
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_not(None)
)).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e:
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
@@ -121,7 +101,8 @@ class PersonInfoManager:
def _db_check_known_sync(p_id: str):
# 在需要时获取会话
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
@@ -133,7 +114,8 @@ class PersonInfoManager:
"""根据用户名获取用户ID"""
try:
# 在需要时获取会话
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -176,15 +158,16 @@ class PersonInfoManager:
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
with get_db_session() as session:
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
@@ -223,25 +206,26 @@ class PersonInfoManager:
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(p_data: dict):
try:
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
with get_db_session() as session:
try:
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
@@ -263,32 +247,33 @@ class PersonInfoManager:
def _db_update_sync(p_id: str, f_name: str, val_to_set):
start_time = time.time()
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
query_time = time.time()
with get_db_session() as session:
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
session.commit()
save_time = time.time()
if record:
setattr(record, f_name, val_to_set)
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
session.commit()
return True, False # Found and updated, no creation needed
else:
return True, False # Found and updated, no creation needed
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
session.rollback()
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
@@ -320,7 +305,8 @@ class PersonInfoManager:
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
return bool(record)
try:
@@ -430,7 +416,8 @@ class PersonInfoManager:
else:
def _db_check_name_exists_sync(name_to_check):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
@@ -471,14 +458,14 @@ class PersonInfoManager:
def _db_delete_sync(p_id: str):
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
session.delete(record)
session.commit()
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
session.delete(record)
session.commit()
return 1
return 0
except Exception as e:
session.rollback()
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0
@@ -497,7 +484,8 @@ class PersonInfoManager:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
@@ -531,27 +519,28 @@ class PersonInfoManager:
def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
with get_db_session() as session:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return []
elif val is None:
return []
return val
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
@@ -563,7 +552,8 @@ class PersonInfoManager:
result = {}
def _db_get_record_sync(p_id: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id)
@@ -608,10 +598,11 @@ class PersonInfoManager:
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
with get_db_session() as session:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
except Exception as e_query:
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
@@ -634,19 +625,20 @@ class PersonInfoManager:
def _db_get_or_create_sync(p_id: str, init_data: dict):
"""原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 记录存在,未创建
with get_db_session() as session:
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
session.commit()
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
except Exception as e:
session.rollback()
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
@@ -709,7 +701,8 @@ class PersonInfoManager:
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record: