Database refactoring (数据库重构)

雅诺狐
2025-08-16 23:43:45 +08:00
parent 0f0619762b
commit d46d689c43
21 changed files with 834 additions and 1007 deletions
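The common thread in this commit: module-level `session = get_session()` globals are removed, and each database touch is wrapped in a scoped `with get_db_session() as session:` block. A minimal sketch of what such a helper typically looks like, assuming `get_db_session` in src/common/database/sqlalchemy_database_api.py wraps a SQLAlchemy `sessionmaker` in a context manager; the engine URL and the commit-on-success behavior below are assumptions, not confirmed by this diff:

    from contextlib import contextmanager
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Assumed wiring; the real engine comes from
    # src.common.database.sqlalchemy_models.get_engine().
    engine = create_engine("sqlite:///maibot.db")
    SessionLocal = sessionmaker(bind=engine)

    @contextmanager
    def get_db_session():
        """Yield a session scoped to one unit of work: roll back on error, always close."""
        session = SessionLocal()
        try:
            yield session
            session.commit()  # flush anything the caller left uncommitted
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

Under that reading, the explicit `session.commit()` calls kept inside the blocks below are redundant but harmless; the context manager would commit or roll back on exit either way.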

View File

@@ -14,7 +14,7 @@ from PIL import Image
 from rich.traceback import install
 from sqlalchemy import select
 from src.common.database.database import db
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from src.common.database.sqlalchemy_models import Emoji, Images
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
@@ -30,8 +30,6 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji")  # 表情包存储目录
 EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed")  # 已注册的表情包注册目录
 MAX_EMOJI_FOR_PROMPT = 20  # 最大允许的表情包描述数量于图片替换的 prompt 中

-session = get_session()
-
 """
 还没经过测试,有些地方数据库和内存数据同步可能不完全
@@ -152,28 +150,29 @@ class MaiEmoji:
         # --- 数据库操作 ---
         try:
             # 准备数据库记录 for emoji collection
+            with get_db_session() as session:
                 emotion_str = ",".join(self.emotion) if self.emotion else ""
                 emoji = Emoji(
                     emoji_hash=self.hash,
                     full_path=self.full_path,
                     format=self.format,
                     description=self.description,
                     emotion=emotion_str,  # Store as comma-separated string
                     query_count=0,  # Default value
                     is_registered=True,
                     is_banned=False,  # Default value
                     record_time=self.register_time,  # Use MaiEmoji's register_time for DB record_time
                     register_time=self.register_time,
                     usage_count=self.usage_count,
                     last_used_time=self.last_used_time,
                 )
                 session.add(emoji)
                 session.commit()
                 logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
                 return True
         except Exception as db_error:
             logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
@@ -205,14 +204,15 @@ class MaiEmoji:
         # 2. 删除数据库记录
         try:
+            with get_db_session() as session:
                 will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
                 if will_delete_emoji is None:
                     logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
                     result = 0  # Indicate no DB record was deleted
                 else:
                     session.delete(will_delete_emoji)
-                    session.commit()
                     result = 1  # Successfully deleted one record
+                    session.commit()
         except Exception as e:
             logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
             result = 0
@@ -403,35 +403,36 @@ class EmojiManager:
     def initialize(self) -> None:
         """初始化数据库连接和表情目录"""
-        try:
-            db.connect(reuse_if_open=True)
-            if db.is_closed():
-                raise RuntimeError("数据库连接失败")
-            _ensure_emoji_dir()
-            self._initialized = True  # 标记为已初始化
-            logger.info("EmojiManager初始化成功")
-        except Exception as e:
-            logger.error(f"EmojiManager初始化失败: {e}")
-            self._initialized = False
-            raise
-
-    def _ensure_db(self) -> None:
-        """确保数据库已初始化"""
-        if not self._initialized:
-            self.initialize()
-        if not self._initialized:
-            raise RuntimeError("EmojiManager not initialized")
+        # try:
+        #     db.connect(reuse_if_open=True)
+        #     if db.is_closed():
+        #         raise RuntimeError("数据库连接失败")
+        #     _ensure_emoji_dir()
+        #     self._initialized = True  # 标记为已初始化
+        #     logger.info("EmojiManager初始化成功")
+        # except Exception as e:
+        #     logger.error(f"EmojiManager初始化失败: {e}")
+        #     self._initialized = False
+        #     raise
+
+    # def _ensure_db(self) -> None:
+    #     """确保数据库已初始化"""
+    #     if not self._initialized:
+    #         self.initialize()
+    #     if not self._initialized:
+    #         raise RuntimeError("EmojiManager not initialized")

     def record_usage(self, emoji_hash: str) -> None:
         """记录表情使用次数"""
         try:
-            emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
-            if emoji_update is None:
-                logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
-            else:
-                emoji_update.usage_count += 1
-                emoji_update.last_used_time = time.time()  # Update last used time
-                session.commit()  # Persist changes to DB
+            with get_db_session() as session:
+                emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
+                if emoji_update is None:
+                    logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
+                else:
+                    emoji_update.usage_count += 1
+                    emoji_update.last_used_time = time.time()  # Update last used time
+                    session.commit()
         except Exception as e:
             logger.error(f"记录表情使用失败: {str(e)}")
@@ -659,11 +660,12 @@ class EmojiManager:
     async def get_all_emoji_from_db(self) -> None:
         """获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
         try:
+            with get_db_session() as session:
                 self._ensure_db()
                 logger.debug("[数据库] 开始加载所有表情包记录 ...")
                 emoji_instances = session.execute(select(Emoji)).scalars().all()
                 emoji_objects, load_errors = _to_emoji_objects(emoji_instances)

             # 更新内存中的列表和数量
             self.emoji_objects = emoji_objects
@@ -672,6 +674,7 @@ class EmojiManager:
             logger.info(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。")
             if load_errors > 0:
                 logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。")
+
         except Exception as e:
             logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}")
@@ -688,7 +691,8 @@ class EmojiManager:
             list[MaiEmoji]: 表情包对象列表
         """
         try:
-            self._ensure_db()
+            with get_db_session() as session:
+                self._ensure_db()

                 if emoji_hash:
                     query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
@@ -703,7 +707,6 @@ class EmojiManager:
             if load_errors > 0:
                 logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
-
             return emoji_objects

         except Exception as e:
@@ -744,13 +747,15 @@ class EmojiManager:
         # 如果内存中没有,从数据库查找
         self._ensure_db()
         try:
+            with get_db_session() as session:
                 emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
                 if emoji_record and emoji_record.description:
                     logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
                     return emoji_record.description
         except Exception as e:
             logger.error(f"从数据库查询表情包描述时出错: {e}")

         return None
     except Exception as e:
@@ -905,13 +910,14 @@ class EmojiManager:
             # 尝试从Images表获取已有的详细描述可能在收到表情包时已生成
             existing_description = None
             try:
+                with get_db_session() as session:
                     # from src.common.database.database_model_compat import Images
                     stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
                     existing_image = session.execute(stmt).scalar_one_or_none()
                     if existing_image and existing_image.description:
                         existing_description = existing_image.description
                         logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
             except Exception as e:
                 logger.debug(f"查询已有描述时出错: {e}")

View File

@@ -7,7 +7,7 @@ from datetime import datetime
 from typing import List, Dict, Optional, Any, Tuple
 from src.common.logger import get_logger
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from sqlalchemy import select
 from src.common.database.sqlalchemy_models import Expression
 from src.llm_models.utils_model import LLMRequest
@@ -22,7 +22,6 @@ DECAY_DAYS = 30  # 30天衰减到0.01
 DECAY_MIN = 0.01  # 最小衰减值

 logger = get_logger("expressor")

-session = get_session()

 def format_create_date(timestamp: float) -> str:
     """
@@ -204,7 +203,8 @@ class ExpressionLearner:
         learnt_grammar_expressions = []

         # 直接从数据库查询
+        with get_db_session() as session:
             style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
             for expr in style_query.scalars():
                 # 确保create_date存在如果不存在则使用last_active_time
                 create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
@@ -247,8 +247,9 @@ class ExpressionLearner:
         对数据库中的所有表达方式应用全局衰减
         """
         try:
+            with get_db_session() as session:
                 # 获取所有表达方式
                 all_expressions = session.execute(select(Expression)).scalars()

                 updated_count = 0
                 deleted_count = 0
@@ -265,19 +266,19 @@ class ExpressionLearner:
                     if new_count <= 0.01:
                         # 如果count太小删除这个表达方式
                         session.delete(expr)
+                        session.commit()
                         deleted_count += 1
                     else:
                         # 更新count
                         expr.count = new_count
                         updated_count += 1

-            session.commit()
-
                 if updated_count > 0 or deleted_count > 0:
                     logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")

         except Exception as e:
-            session.rollback()
             logger.error(f"数据库全局衰减失败: {e}")

     def calculate_decay_factor(self, time_diff_days: float) -> float:
@@ -355,43 +356,46 @@ class ExpressionLearner:
         for chat_id, expr_list in chat_dict.items():
             for new_expr in expr_list:
                 # 查找是否已存在相似表达方式
+                with get_db_session() as session:
                     query = session.execute(select(Expression).where(
                         (Expression.chat_id == chat_id)
                         & (Expression.type == type)
                         & (Expression.situation == new_expr["situation"])
                         & (Expression.style == new_expr["style"])
                     )).scalar()
                     if query:
                         expr_obj = query
                         # 50%概率替换内容
                         if random.random() < 0.5:
                             expr_obj.situation = new_expr["situation"]
                             expr_obj.style = new_expr["style"]
                         expr_obj.count = expr_obj.count + 1
                         expr_obj.last_active_time = current_time
                     else:
                         new_expression = Expression(
                             situation=new_expr["situation"],
                             style=new_expr["style"],
                             count=1,
                             last_active_time=current_time,
                             chat_id=chat_id,
                             type=type,
                             create_date=current_time,  # 手动设置创建日期
                         )
                         session.add(new_expression)
+                    session.commit()
                     # 限制最大数量
                     exprs = list(
                         session.execute(select(Expression)
                         .where((Expression.chat_id == chat_id) & (Expression.type == type))
                         .order_by(Expression.count.asc())).scalars()
                     )
                     if len(exprs) > MAX_EXPRESSION_COUNT:
                         # 删除count最小的多余表达方式
                         for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
                             session.delete(expr)
-        session.commit()
+                        session.commit()
         return learnt_expressions

     async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
         """从指定聊天流学习表达方式
@@ -574,8 +578,8 @@ class ExpressionLearnerManager:
                     continue

                 # 查重同chat_id+type+situation+style
+                with get_db_session() as session:
                     query = session.execute(select(Expression).where(
                         (Expression.chat_id == chat_id)
                         & (Expression.type == type_str)
                         & (Expression.situation == situation)
@@ -596,6 +600,8 @@ class ExpressionLearnerManager:
                             create_date=last_active_time,  # 迁移时使用last_active_time作为创建时间
                         )
                         session.add(new_expression)
+                        session.commit()
                     migrated_count += 1

                 logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
             except json.JSONDecodeError as e:
@@ -627,21 +633,22 @@ class ExpressionLearnerManager:
         使用last_active_time作为create_date的默认值
         """
         try:
+            with get_db_session() as session:
                 # 查找所有create_date为空的表达方式
                 old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
                 updated_count = 0

                 for expr in old_expressions:
                     # 使用last_active_time作为create_date
                     expr.create_date = expr.last_active_time
                     updated_count += 1

+                session.commit()
                 if updated_count > 0:
                     logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
-            session.commit()

         except Exception as e:
-            session.rollback()
             logger.error(f"迁移老数据创建日期失败: {e}")

View File

@@ -12,8 +12,7 @@ from src.common.logger import get_logger
 from sqlalchemy import select
 from src.common.database.sqlalchemy_models import Expression
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from src.common.database.sqlalchemy_database_api import get_session
-
-session = get_session()
+from src.common.database.sqlalchemy_database_api import get_db_session

 logger = get_logger("expression_selector")
@@ -132,14 +131,14 @@ class ExpressionSelector:
         # sourcery skip: extract-duplicate-method, move-assign
         # 支持多chat_id合并抽选
         related_chat_ids = self.get_related_chat_ids(chat_id)
+        with get_db_session() as session:
             # 优化一次性查询所有相关chat_id的表达方式
             style_query = session.execute(select(Expression).where(
                 (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
             ))
             grammar_query = session.execute(select(Expression).where(
                 (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
             ))
             style_exprs = [
                 {
@@ -180,6 +179,7 @@ class ExpressionSelector:
             selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
         else:
             selected_grammar = []
+
         return selected_style, selected_grammar

     def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
@@ -199,7 +199,8 @@ class ExpressionSelector:
             if key not in updates_by_key:
                 updates_by_key[key] = expr

         for chat_id, expr_type, situation, style in updates_by_key:
+            with get_db_session() as session:
                 query = session.execute(select(Expression).where(
                     (Expression.chat_id == chat_id)
                     & (Expression.type == expr_type)
                     & (Expression.situation == situation)
@@ -211,10 +212,11 @@ class ExpressionSelector:
                 new_count = min(current_count + increment, 5.0)
                 expr_obj.count = new_count
                 expr_obj.last_active_time = time.time()
+                session.commit()
                 logger.debug(
                     f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
                 )
-        session.commit()

     async def select_suitable_expressions_llm(
         self,

View File

@@ -19,7 +19,7 @@ from src.config.config import global_config, model_config
 from sqlalchemy import select,insert,update,delete
 from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges  # SQLAlchemy Models导入
 from src.common.logger import get_logger
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  # 分布生成器
 from src.chat.utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
@@ -30,7 +30,6 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable

 install(extra_lines=3)

-session = get_session()

 def calculate_information_content(text):
     """计算文本的信息量(熵)"""
@@ -862,13 +861,13 @@ class EntorhinalCortex:
                 for message in messages:
                     # 确保在更新前获取最新的 memorized_times
                     current_memorized_times = message.get("memorized_times", 0)
-                    # 使用 SQLAlchemy 2.0 更新记录
+                    with get_db_session() as session:
                         session.execute(
                             update(Messages)
                             .where(Messages.message_id == message["message_id"])
                             .values(memorized_times=current_memorized_times + 1)
                         )
                         session.commit()
                 return messages  # 直接返回原始的消息列表

             target_timestamp -= 120  # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -882,253 +881,260 @@ class EntorhinalCortex:
         current_time = datetime.datetime.now().timestamp()

         # 获取数据库中所有节点和内存中所有节点
+        with get_db_session() as session:
             db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
             memory_nodes = list(self.memory_graph.G.nodes(data=True))

             # 批量准备节点数据
             nodes_to_create = []
             nodes_to_update = []
             nodes_to_delete = set()

             # 处理节点
             for concept, data in memory_nodes:
                 if not concept or not isinstance(concept, str):
                     self.memory_graph.G.remove_node(concept)
                     continue

                 memory_items = data.get("memory_items", [])
                 if not isinstance(memory_items, list):
                     memory_items = [memory_items] if memory_items else []

                 if not memory_items:
                     self.memory_graph.G.remove_node(concept)
                     continue

                 # 计算内存中节点的特征值
                 memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items)
                 created_time = data.get("created_time", current_time)
                 last_modified = data.get("last_modified", current_time)

                 # 将memory_items转换为JSON字符串
                 try:
                     memory_items = [str(item) for item in memory_items]
                     memory_items_json = json.dumps(memory_items, ensure_ascii=False)
                     if not memory_items_json:
                         continue
                 except Exception:
                     self.memory_graph.G.remove_node(concept)
                     continue

                 if concept not in db_nodes:
                     nodes_to_create.append(
                         {
                             "concept": concept,
                             "memory_items": memory_items_json,
                             "hash": memory_hash,
                             "created_time": created_time,
                             "last_modified": last_modified,
                         }
                     )
                 else:
                     db_node = db_nodes[concept]
                     if db_node.hash != memory_hash:
                         nodes_to_update.append(
                             {
                                 "concept": concept,
                                 "memory_items": memory_items_json,
                                 "hash": memory_hash,
                                 "last_modified": last_modified,
                             }
                         )

             # 计算需要删除的节点
             memory_concepts = {concept for concept, _ in memory_nodes}
             nodes_to_delete = set(db_nodes.keys()) - memory_concepts

             # 批量处理节点
             if nodes_to_create:
                 batch_size = 100
                 for i in range(0, len(nodes_to_create), batch_size):
                     batch = nodes_to_create[i : i + batch_size]
                     session.execute(insert(GraphNodes), batch)
             if nodes_to_update:
                 batch_size = 100
                 for i in range(0, len(nodes_to_update), batch_size):
                     batch = nodes_to_update[i : i + batch_size]
                     for node_data in batch:
                         session.execute(
                             update(GraphNodes)
                             .where(GraphNodes.concept == node_data["concept"])
                             .values(**{k: v for k, v in node_data.items() if k != "concept"})
                         )
             if nodes_to_delete:
                 session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
-        session.commit()

             # 处理边的信息
             db_edges = list(session.execute(select(GraphEdges)).scalars())
             memory_edges = list(self.memory_graph.G.edges(data=True))

             # 创建边的哈希值字典
             db_edge_dict = {}
             for edge in db_edges:
                 edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target)
                 db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength}

             # 批量准备边数据
             edges_to_create = []
             edges_to_update = []

             # 处理边
             for source, target, data in memory_edges:
                 edge_hash = self.hippocampus.calculate_edge_hash(source, target)
                 edge_key = (source, target)
                 strength = data.get("strength", 1)
                 created_time = data.get("created_time", current_time)
                 last_modified = data.get("last_modified", current_time)

                 if edge_key not in db_edge_dict:
                     edges_to_create.append(
                         {
                             "source": source,
                             "target": target,
                             "strength": strength,
                             "hash": edge_hash,
                             "created_time": created_time,
                             "last_modified": last_modified,
                         }
                     )
                 elif db_edge_dict[edge_key]["hash"] != edge_hash:
                     edges_to_update.append(
                         {
                             "source": source,
                             "target": target,
                             "strength": strength,
                             "hash": edge_hash,
                             "last_modified": last_modified,
                         }
                     )

             # 计算需要删除的边
             memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
             edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys

             # 批量处理边
             if edges_to_create:
                 batch_size = 100
                 for i in range(0, len(edges_to_create), batch_size):
                     batch = edges_to_create[i : i + batch_size]
                     session.execute(insert(GraphEdges), batch)
                 session.commit()
             if edges_to_update:
                 batch_size = 100
                 for i in range(0, len(edges_to_update), batch_size):
                     batch = edges_to_update[i : i + batch_size]
                     for edge_data in batch:
                         session.execute(
                             update(GraphEdges)
                             .where(
                                 (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
                             )
                             .values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
                         )
-            session.commit()
             if edges_to_delete:
                 for source, target in edges_to_delete:
                     session.execute(
                         delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
                     )
                     session.commit()

+            # 提交事务
             session.commit()

         end_time = time.time()
         logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}")
         logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")

     async def resync_memory_to_db(self):
         """清空数据库并重新同步所有记忆数据"""
         start_time = time.time()
         logger.info("[数据库] 开始重新同步所有记忆数据...")

         # 清空数据库
+        with get_db_session() as session:
             clear_start = time.time()
             session.execute(delete(GraphNodes))
             session.execute(delete(GraphEdges))
-        session.commit()

             clear_end = time.time()
             logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")

             # 获取所有节点和边
             memory_nodes = list(self.memory_graph.G.nodes(data=True))
             memory_edges = list(self.memory_graph.G.edges(data=True))
             current_time = datetime.datetime.now().timestamp()

             # 批量准备节点数据
             nodes_data = []
             for concept, data in memory_nodes:
                 memory_items = data.get("memory_items", [])
                 if not isinstance(memory_items, list):
                     memory_items = [memory_items] if memory_items else []

                 try:
                     memory_items = [str(item) for item in memory_items]
                     if memory_items_json := json.dumps(memory_items, ensure_ascii=False):
                         nodes_data.append(
                             {
                                 "concept": concept,
                                 "memory_items": memory_items_json,
                                 "hash": self.hippocampus.calculate_node_hash(concept, memory_items),
                                 "created_time": data.get("created_time", current_time),
                                 "last_modified": data.get("last_modified", current_time),
                             }
                         )
                 except Exception as e:
                     logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
                     continue

             # 批量准备边数据
             edges_data = []
             for source, target, data in memory_edges:
                 try:
                     edges_data.append(
                         {
                             "source": source,
                             "target": target,
                             "strength": data.get("strength", 1),
                             "hash": self.hippocampus.calculate_edge_hash(source, target),
                             "created_time": data.get("created_time", current_time),
                             "last_modified": data.get("last_modified", current_time),
                         }
                     )
                 except Exception as e:
                     logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
                     continue

             # 批量写入节点
             node_start = time.time()
             if nodes_data:
                 batch_size = 500  # 增加批量大小
                 for i in range(0, len(nodes_data), batch_size):
                     batch = nodes_data[i : i + batch_size]
                     session.execute(insert(GraphNodes), batch)
-            session.commit()
             node_end = time.time()
             logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")

             # 批量写入边
             edge_start = time.time()
             if edges_data:
                 batch_size = 500  # 增加批量大小
                 for i in range(0, len(edges_data), batch_size):
                     batch = edges_data[i : i + batch_size]
                     session.execute(insert(GraphEdges), batch)
             session.commit()
             edge_end = time.time()
             logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1145,77 +1151,79 @@ class EntorhinalCortex:
         self.memory_graph.G.clear()

         # 从数据库加载所有节点
+        with get_db_session() as session:
             nodes = list(session.execute(select(GraphNodes)).scalars())
             for node in nodes:
                 concept = node.concept
                 try:
                     memory_items = json.loads(node.memory_items)
                     if not isinstance(memory_items, list):
                         memory_items = [memory_items] if memory_items else []

                     # 检查时间字段是否存在
                     if not node.created_time or not node.last_modified:
                         need_update = True
                         # 更新数据库中的节点
                         update_data = {}
                         if not node.created_time:
                             update_data["created_time"] = current_time
                         if not node.last_modified:
                             update_data["last_modified"] = current_time
                         session.execute(
                             update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
                         )

                     # 获取时间信息(如果不存在则使用当前时间)
                     created_time = node.created_time or current_time
                     last_modified = node.last_modified or current_time

                     # 添加节点到图中
                     self.memory_graph.G.add_node(
                         concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified
                     )
                 except Exception as e:
                     logger.error(f"加载节点 {concept} 时发生错误: {e}")
                     continue

             # 从数据库加载所有边
             edges = list(session.execute(select(GraphEdges)).scalars())
             for edge in edges:
                 source = edge.source
                 target = edge.target
                 strength = edge.strength

                 # 检查时间字段是否存在
                 if not edge.created_time or not edge.last_modified:
                     need_update = True
                     # 更新数据库中的边
                     update_data = {}
                     if not edge.created_time:
                         update_data["created_time"] = current_time
                     if not edge.last_modified:
                         update_data["last_modified"] = current_time
                     session.execute(
                         update(GraphEdges)
                         .where((GraphEdges.source == source) & (GraphEdges.target == target))
                         .values(**update_data)
                     )
                     session.commit()

                 # 获取时间信息(如果不存在则使用当前时间)
                 created_time = edge.created_time or current_time
                 last_modified = edge.last_modified or current_time

                 # 只有当源节点和目标节点都存在时才添加边
                 if source in self.memory_graph.G and target in self.memory_graph.G:
                     self.memory_graph.G.add_edge(
                         source, target, strength=strength, created_time=created_time, last_modified=last_modified
                     )
+            session.commit()

         if need_update:
             logger.info("[数据库] 已为缺失的时间字段进行补充")

 # 负责整合,遗忘,合并记忆

View File

@@ -11,12 +11,11 @@ from datetime import datetime, timedelta
 from src.llm_models.utils_model import LLMRequest
 from src.common.logger import get_logger
 from src.common.database.sqlalchemy_models import Memory  # SQLAlchemy Models导入
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from src.config.config import model_config
 from sqlalchemy import select

 logger = get_logger(__name__)

-session = get_session()

 class MemoryItem:
     def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
@@ -113,7 +112,8 @@ class InstantMemory:
             logger.info(f"不需要记忆:{text}")

     async def store_memory(self, memory_item: MemoryItem):
+        with get_db_session() as session:
             memory = Memory(
                 memory_id=memory_item.memory_id,
                 chat_id=memory_item.chat_id,
                 memory_text=memory_item.memory_text,
@@ -121,8 +121,8 @@ class InstantMemory:
                 create_time=memory_item.create_time,
                 last_view_time=memory_item.last_view_time,
             )
             session.add(memory)
             session.commit()

     async def get_memory(self, target: str):
         from json_repair import repair_json
@@ -165,17 +165,18 @@ class InstantMemory:
         logger.info(f"start_time: {start_time}, end_time: {end_time}")
         # 检索包含关键词的记忆
         memories_set = set()
+        with get_db_session() as session:
             if start_time and end_time:
                 start_ts = start_time.timestamp()
                 end_ts = end_time.timestamp()

                 query = session.execute(select(Memory).where(
                     (Memory.chat_id == self.chat_id)
                     & (Memory.create_time >= start_ts)
                     & (Memory.create_time < end_ts)
                 )).scalars()
             else:
                 query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
             for mem in query:
                 # 对每条记忆
                 mem_keywords = mem.keywords or ""

View File

@@ -274,6 +274,7 @@ class ChatBot:
                 # logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
                 # return
+
             # 过滤检查
             if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(  # type: ignore
                 message.raw_message,  # type: ignore

View File

@@ -12,7 +12,7 @@ from sqlalchemy import select, text
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlalchemy.dialects.mysql import insert as mysql_insert
 from src.common.database.sqlalchemy_models import ChatStreams  # 新增导入
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from src.config.config import global_config  # 新增导入

 # 避免循环导入使用TYPE_CHECKING进行类型提示
 if TYPE_CHECKING:
@@ -23,7 +23,6 @@
 install(extra_lines=3)

 logger = get_logger("chat_stream")

-session = get_session()

 class ChatMessageContext:
     """聊天消息上下文,存储消息的上下文信息"""
@@ -132,13 +131,14 @@ class ChatManager:
         if not self._initialized:
             self.streams: Dict[str, ChatStream] = {}  # stream_id -> ChatStream
             self.last_messages: Dict[str, "MessageRecv"] = {}  # stream_id -> last_message
-            try:
-                db.connect(reuse_if_open=True)
-                # 确保 ChatStreams 表存在
-                session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
-                session.commit()
-            except Exception as e:
-                logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
+            # try:
+            #     with get_db_session() as session:
+            #         db.connect(reuse_if_open=True)
+            #         # 确保 ChatStreams 表存在
+            #         session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
+            #         session.commit()
+            # except Exception as e:
+            #     logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
             self._initialized = True

         # 在事件循环中启动初始化
@@ -236,7 +236,8 @@ class ChatManager:
         # 检查数据库中是否存在
         def _db_find_stream_sync(s_id: str):
-            return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
+            with get_db_session() as session:
+                return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()

         model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
@@ -331,44 +332,45 @@ class ChatManager:
         stream_data_dict = stream.to_dict()

         def _db_save_stream_sync(s_data_dict: dict):
+            with get_db_session() as session:
                 user_info_d = s_data_dict.get("user_info")
                 group_info_d = s_data_dict.get("group_info")

                 fields_to_save = {
                     "platform": s_data_dict["platform"],
                     "create_time": s_data_dict["create_time"],
                     "last_active_time": s_data_dict["last_active_time"],
                     "user_platform": user_info_d["platform"] if user_info_d else "",
                     "user_id": user_info_d["user_id"] if user_info_d else "",
                     "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
                     "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
                     "group_platform": group_info_d["platform"] if group_info_d else "",
                     "group_id": group_info_d["group_id"] if group_info_d else "",
                     "group_name": group_info_d["group_name"] if group_info_d else "",
                 }

                 # 根据数据库类型选择插入语句
                 if global_config.database.database_type == "sqlite":
                     stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                     stmt = stmt.on_conflict_do_update(
                         index_elements=['stream_id'],
                         set_=fields_to_save
                     )
                 elif global_config.database.database_type == "mysql":
                     stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                     stmt = stmt.on_duplicate_key_update(
                         **{key: value for key, value in fields_to_save.items() if key != "stream_id"}
                     )
                 else:
                     # 默认使用通用插入尝试SQLite语法
                     stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                     stmt = stmt.on_conflict_do_update(
                         index_elements=['stream_id'],
                         set_=fields_to_save
                     )
                 session.execute(stmt)
                 session.commit()

         try:
             await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
@@ -387,30 +389,32 @@ class ChatManager:
         def _db_load_all_streams_sync():
             loaded_streams_data = []
+            with get_db_session() as session:
                 for model_instance in session.execute(select(ChatStreams)).scalars():
                     user_info_data = {
                         "platform": model_instance.user_platform,
                         "user_id": model_instance.user_id,
                         "user_nickname": model_instance.user_nickname,
                         "user_cardname": model_instance.user_cardname or "",
                     }
                     group_info_data = None
                     if model_instance.group_id:
                         group_info_data = {
                             "platform": model_instance.group_platform,
                             "group_id": model_instance.group_id,
                             "group_name": model_instance.group_name,
                         }
                     data_for_from_dict = {
                         "stream_id": model_instance.stream_id,
                         "platform": model_instance.platform,
                         "user_info": user_info_data,
                         "group_info": group_info_data,
                         "create_time": model_instance.create_time,
                         "last_active_time": model_instance.last_active_time,
                     }
                     loaded_streams_data.append(data_for_from_dict)
+                session.commit()
             return loaded_streams_data

         try:

View File

@@ -7,7 +7,7 @@ from src.common.database.sqlalchemy_models import Messages, Images
 from src.common.logger import get_logger
 from .chat_stream import ChatStream
 from .message import MessageSending, MessageRecv
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from sqlalchemy import select, update, desc

 logger = get_logger("message_storage")
@@ -70,7 +70,6 @@ class MessageStorage:
                 priority_info_json = json.dumps(priority_info) if priority_info else None

                 # 获取数据库会话
-                session = get_session()

                 new_message = Messages(
                     message_id=msg_id,
@@ -104,8 +103,10 @@ class MessageStorage:
                     is_notify=is_notify,
                     is_command=is_command,
                 )
+                with get_db_session() as session:
                     session.add(new_message)
                     session.commit()
         except Exception:
             logger.exception("存储消息失败")
             logger.error(f"消息:{message}")
@@ -155,7 +156,8 @@ class MessageStorage:
                     session.execute(
                         update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
                     )
-                    # session.commit() 会在上下文管理器中自动调用
+                    session.commit()
+                    # 会在上下文管理器中自动调用
                     logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
                 else:
                     logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
@@ -184,6 +186,7 @@ class MessageStorage:
                     image_record = session.execute(
                         select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
                     ).scalar()
+                    session.commit()
                     return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
                 except Exception:
                     return match.group(0)

View File

@@ -10,7 +10,7 @@ from src.common.message_repository import find_messages, count_messages
 from src.common.database.sqlalchemy_models import ActionRecords, Images
 from src.person_info.person_info import PersonInfoManager, get_person_info_manager
 from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from sqlalchemy import select, and_

 install(extra_lines=3)

View File

@@ -420,7 +420,7 @@ class StatisticOutputTask(AsyncTask):
                     stats[period_key][COST_BY_MODULE][module_name] += cost

                     # 收集time_cost数据
-                    time_cost = record.time_cost or 0.0
+                    time_cost = record.get('time_cost') or 0.0
                     if time_cost > 0:  # 只记录有效的time_cost
                         stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
                         stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)

View File

@@ -41,12 +41,12 @@ class ImageManager:
         self._initialized = True
         self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")

-        try:
-            db.connect(reuse_if_open=True)
-            # 使用SQLAlchemy创建表已在初始化时完成
-            logger.debug("使用SQLAlchemy进行表管理")
-        except Exception as e:
-            logger.error(f"数据库连接失败: {e}")
+        # try:
+        #     db.connect(reuse_if_open=True)
+        #     # 使用SQLAlchemy创建表已在初始化时完成
+        #     logger.debug("使用SQLAlchemy进行表管理")
+        # except Exception as e:
+        #     logger.error(f"数据库连接失败: {e}")

         self._initialized = True
@@ -105,7 +105,8 @@ class ImageManager:
                     timestamp=current_timestamp
                 )
                 session.add(new_desc)
-                # session.commit() 会在上下文管理器中自动调用
+                session.commit()
+                # 会在上下文管理器中自动调用
         except Exception as e:
             logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
@@ -246,7 +247,8 @@ class ImageManager:
                         timestamp=current_timestamp,
                     )
                     session.add(new_img)
-                    # session.commit() 会在上下文管理器中自动调用
+                    session.commit()
+                    # 会在上下文管理器中自动调用
             except Exception as e:
                 logger.error(f"保存到Images表失败: {str(e)}")
@@ -323,7 +325,7 @@ class ImageManager:
                     existing_image.image_id = str(uuid.uuid4())
                 if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
                     existing_image.vlm_processed = True
-                session.commit()
+
                 logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
             else:
                 new_img = Images(
@@ -337,7 +339,7 @@ class ImageManager:
count=1, count=1,
) )
session.add(new_img) session.add(new_img)
session.commit()
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e: except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}") logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -511,35 +513,35 @@ class ImageManager:
existing_image.vlm_processed = False existing_image.vlm_processed = False
existing_image.count += 1 existing_image.count += 1
session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]" return existing_image.image_id, f"[picid:{existing_image.image_id}]"
# print(f"图片不存在: {image_hash}") # print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4()) image_id = str(uuid.uuid4())
# 保存新图片 # 保存新图片
current_timestamp = time.time() current_timestamp = time.time()
image_dir = os.path.join(self.IMAGE_DIR, "images") image_dir = os.path.join(self.IMAGE_DIR, "images")
os.makedirs(image_dir, exist_ok=True) os.makedirs(image_dir, exist_ok=True)
filename = f"{image_id}.png" filename = f"{image_id}.png"
file_path = os.path.join(image_dir, filename) file_path = os.path.join(image_dir, filename)
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库
new_img = Images( new_img = Images(
image_id=image_id, image_id=image_id,
emoji_hash=image_hash, emoji_hash=image_hash,
path=file_path, path=file_path,
type="image", type="image",
timestamp=current_timestamp, timestamp=current_timestamp,
vlm_processed=False, vlm_processed=False,
count=1, count=1,
) )
session.add(new_img) session.add(new_img)
session.commit() session.commit()
# 启动异步VLM处理 # 启动异步VLM处理
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64)) asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
@@ -581,7 +583,7 @@ class ImageManager:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description image.description = existing_with_description.description
image.vlm_processed = True image.vlm_processed = True
session.commit()
# 同时保存到ImageDescriptions表作为备用缓存 # 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image") self._save_description_to_db(image_hash, existing_with_description.description, "image")
return return
@@ -591,7 +593,7 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description image.description = cached_description
image.vlm_processed = True image.vlm_processed = True
session.commit()
return return
# 获取图片格式 # 获取图片格式
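The hunks above establish a lookup order for image descriptions: reuse a description from another Images row with the same hash, fall back to the ImageDescriptions cache, and only then schedule the VLM. Since the explicit commits on the cache-hit paths were removed, persisting those updates now depends on the caller committing. A condensed sketch of that order, assuming the module's imports and a read counterpart (_query_description_from_db) to the _save_description_to_db helper shown above:

from sqlalchemy import select

def _reuse_or_describe(self, session, image, image_hash: str) -> bool:
    # Returns True if a cached description was reused, False if VLM work is needed.
    # 1) Another Images row with the same hash may already carry a description.
    existing = session.execute(
        select(Images).where(
            Images.emoji_hash == image_hash,
            Images.description.is_not(None),
        )
    ).scalar()
    if existing:
        image.description = existing.description
        image.vlm_processed = True
        self._save_description_to_db(image_hash, existing.description, "image")
        return True
    # 2) Fall back to the ImageDescriptions cache table (helper name assumed).
    if cached := self._query_description_from_db(image_hash, "image"):
        image.description = cached
        image.vlm_processed = True
        return True
    # 3) Nothing cached: the caller schedules asynchronous VLM processing.
    return False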

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
# SQLAlchemy相关导入 # SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_session from src.common.database.sqlalchemy_models import get_engine, get_db_session
install(extra_lines=3) install(extra_lines=3)
@@ -18,7 +18,7 @@ logger = get_logger("database")
# 兼容性为了不破坏现有代码保留db变量但指向SQLAlchemy # 兼容性为了不破坏现有代码保留db变量但指向SQLAlchemy
class DatabaseProxy: class DatabaseProxy:
"""数据库代理类提供Peewee到SQLAlchemy的兼容性接口""" """数据库代理类"""
def __init__(self): def __init__(self):
self._engine = None self._engine = None
@@ -28,56 +28,7 @@ class DatabaseProxy:
"""初始化数据库连接""" """初始化数据库连接"""
return initialize_database_compat() return initialize_database_compat()
def connect(self, reuse_if_open=True):
"""连接数据库(兼容性方法)"""
try:
self._engine = get_engine()
return True
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return False
def is_closed(self):
"""检查数据库是否关闭(兼容性方法)"""
return self._engine is None
def create_tables(self, models, safe=True):
"""创建表(兼容性方法)"""
try:
from src.common.database.sqlalchemy_models import Base
engine = get_engine()
Base.metadata.create_all(bind=engine)
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def table_exists(self, model):
"""检查表是否存在(兼容性方法)"""
try:
from sqlalchemy import inspect
engine = get_engine()
inspector = inspect(engine)
table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower())
return table_name in inspector.get_table_names()
except Exception:
return False
def execute_sql(self, sql):
"""执行SQL兼容性方法"""
try:
from sqlalchemy import text
session = get_session()
result = session.execute(text(sql))
session.close()
return result
except Exception as e:
logger.error(f"执行SQL失败: {e}")
raise
def atomic(self):
"""事务上下文管理器(兼容性方法)"""
return SQLAlchemyTransaction()
class SQLAlchemyTransaction: class SQLAlchemyTransaction:
"""SQLAlchemy事务上下文管理器""" """SQLAlchemy事务上下文管理器"""
@@ -86,7 +37,7 @@ class SQLAlchemyTransaction:
self.session = None self.session = None
def __enter__(self): def __enter__(self):
self.session = get_session() self.session = get_db_session()
return self.session return self.session
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
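The hunk cuts off at the __exit__ signature. Note that after this refactor get_db_session() returns a context manager (it is a @contextmanager generator in sqlalchemy_models), not a bare session, so __enter__ as written stores the wrong object; a working variant would need to enter it explicitly. A sketch under that assumption:

def __enter__(self):
    # get_db_session() is itself a context manager, so enter it and keep
    # both handles: the raw session for the caller, the context for cleanup.
    self._ctx = get_db_session()
    self.session = self._ctx.__enter__()
    return self.session

def __exit__(self, exc_type, exc_val, exc_tb):
    if exc_type is None:
        self.session.commit()  # commit is now the caller's job (see the models hunk)
    # Delegating returns False on error, so exceptions still propagate;
    # the inner context manager handles rollback and close.
    return self._ctx.__exit__(exc_type, exc_val, exc_tb)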

View File

@@ -15,7 +15,7 @@ from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import ( from src.common.database.sqlalchemy_models import (
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams, Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory, LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus, get_session Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus
) )
logger = get_logger("sqlalchemy_database_api") logger = get_logger("sqlalchemy_database_api")
@@ -41,38 +41,9 @@ MODEL_MAPPING = {
} }
@contextmanager
def get_db_session():
"""数据库会话上下文管理器,自动处理事务和连接错误"""
session = None
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
session = get_session()
yield session
session.commit()
break
except (DisconnectionError, OperationalError) as e:
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
if session:
session.rollback()
session.close()
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
else:
raise
except Exception:
if session:
session.rollback()
raise
finally:
if session:
session.close()
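This deletes the retry-aware copy of get_db_session from the API module in favor of the single definition in sqlalchemy_models, which has no reconnect loop. If transient DisconnectionError/OperationalError handling is still wanted somewhere, a thin helper over the new context manager could restore it; a sketch under that assumption, reusing this module's logger:

import time
from sqlalchemy.exc import DisconnectionError, OperationalError

def run_with_retry(unit_of_work, max_retries: int = 3, retry_delay: float = 1.0):
    # Hypothetical helper: rerun a callable(session) when the connection drops.
    for attempt in range(max_retries):
        try:
            with get_db_session() as session:
                return unit_of_work(session)
        except (DisconnectionError, OperationalError) as e:
            logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt == max_retries - 1:
                raise
            time.sleep(retry_delay * (attempt + 1))

# usage: count = run_with_retry(lambda s: s.execute(select(func.count(Messages.id))).scalar())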
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]): def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]):
"""构建查询过滤条件""" """构建查询过滤条件"""
conditions = [] conditions = []
@@ -296,6 +267,7 @@ async def db_save(
# 创建新记录 # 创建新记录
new_record = model_class(**data) new_record = model_class(**data)
session.add(new_record) session.add(new_record)
session.commit()
session.flush() session.flush()
# 转换为字典格式返回 # 转换为字典格式返回
@@ -415,8 +387,3 @@ async def store_action_info(
traceback.print_exc() traceback.print_exc()
return None return None
# 兼容性函数方便从Peewee迁移
def get_model_class(model_name: str) -> Optional[Type[Base]]:
"""根据模型名称获取模型类"""
return MODEL_MAPPING.get(model_name)

View File

@@ -8,7 +8,7 @@ from typing import Optional
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import ( from src.common.database.sqlalchemy_models import (
Base, get_engine, get_session, initialize_database Base, get_engine, initialize_database
) )
logger = get_logger("sqlalchemy_init") logger = get_logger("sqlalchemy_init")
@@ -72,36 +72,6 @@ def create_all_tables() -> bool:
return False return False
def check_database_connection() -> bool:
"""
检查数据库连接是否正常
Returns:
bool: 连接是否正常
"""
try:
session = get_session()
if session is None:
logger.error("无法获取数据库会话")
return False
# 检查会话是否可用(如果能获取到会话说明连接正常)
if session is None:
logger.error("数据库会话无效")
return False
session.close()
logger.info("数据库连接检查通过")
return True
except SQLAlchemyError as e:
logger.error(f"数据库连接检查失败: {e}")
return False
except Exception as e:
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
return False
def get_database_info() -> Optional[dict]: def get_database_info() -> Optional[dict]:
""" """
@@ -149,9 +119,6 @@ def initialize_database_compat() -> bool:
if success: if success:
success = create_all_tables() success = create_all_tables()
if success:
success = check_database_connection()
if success: if success:
_database_initialized = True _database_initialized = True

View File

@@ -29,102 +29,6 @@ def get_string_field(max_length=255, **kwargs):
return String(max_length, **kwargs) return String(max_length, **kwargs)
else: else:
return Text(**kwargs) return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
except:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
# commit后不要清除session只是刷新状态
pass # 保持session活跃
return result
except Exception:
try:
if session and hasattr(session, 'rollback'):
session.rollback()
except:
pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
# 测试session是否还可用
self._local.session.execute("SELECT 1")
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
class ChatStreams(Base): class ChatStreams(Base):
@@ -482,6 +386,22 @@ class MaiZoneScheduleStatus(Base):
) )
class BanUser(Base):
"""被禁用用户模型"""
__tablename__ = 'ban_users'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
violation_num = Column(Integer, nullable=False, default=0)
reason = Column(Text, nullable=False)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_violation_num', 'violation_num'),
Index('idx_banuser_user_id', 'user_id'),
)
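BanUser is the only genuinely new table in this commit. A usage sketch with a hypothetical helper name, following the explicit-commit convention the rest of the commit adopts:

from sqlalchemy import select

def record_violation(user_id: str, reason: str) -> int:
    # Hypothetical helper: bump a user's violation count, creating the row if needed.
    with get_db_session() as session:
        ban = session.execute(
            select(BanUser).where(BanUser.user_id == user_id)
        ).scalar()
        if ban is None:
            ban = BanUser(user_id=user_id, violation_num=0, reason=reason)
            session.add(ban)
        ban.violation_num += 1
        ban.reason = reason
        session.commit()  # explicit commit, per the new get_db_session contract
        return ban.violation_num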
# 数据库引擎和会话管理 # 数据库引擎和会话管理
_engine = None _engine = None
_SessionLocal = None _SessionLocal = None
@@ -593,7 +513,7 @@ def get_db_session():
_, SessionLocal = initialize_database() _, SessionLocal = initialize_database()
session = SessionLocal() session = SessionLocal()
yield session yield session
session.commit() # session.commit()
except Exception: except Exception:
if session: if session:
session.rollback() session.rollback()
@@ -601,6 +521,7 @@ def get_db_session():
finally: finally:
if session: if session:
session.close() session.close()
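Reassembled from this hunk, the canonical session helper now reads roughly as below. The significant change is the commented-out auto-commit: the context manager still rolls back on error and always closes, but successful writes must now call session.commit() themselves, which is why explicit commits appear at call sites throughout this commit:

from contextlib import contextmanager

@contextmanager
def get_db_session():
    session = None
    try:
        _, SessionLocal = initialize_database()
        session = SessionLocal()
        yield session
        # session.commit()  # no longer auto-committed; callers commit explicitly
    except Exception:
        if session:
            session.rollback()
        raise
    finally:
        if session:
            session.close()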
def get_engine(): def get_engine():

View File

@@ -8,7 +8,7 @@ from src.config.config import global_config
# from src.common.database.database_model import Messages # from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_session from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -44,92 +44,92 @@ def find_messages(
消息字典列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
session = get_session() with get_db_session() as session:
query = select(Messages) query = select(Messages)
# 应用过滤器 # 应用过滤器
if message_filter: if message_filter:
conditions = [] conditions = []
for key, value in message_filter.items(): for key, value in message_filter.items():
if hasattr(Messages, key): if hasattr(Messages, key):
field = getattr(Messages, key) field = getattr(Messages, key)
if isinstance(value, dict): if isinstance(value, dict):
# 处理 MongoDB 风格的操作符 # 处理 MongoDB 风格的操作符
for op, op_value in value.items(): for op, op_value in value.items():
if op == "$gt": if op == "$gt":
conditions.append(field > op_value) conditions.append(field > op_value)
elif op == "$lt": elif op == "$lt":
conditions.append(field < op_value) conditions.append(field < op_value)
elif op == "$gte": elif op == "$gte":
conditions.append(field >= op_value) conditions.append(field >= op_value)
elif op == "$lte": elif op == "$lte":
conditions.append(field <= op_value) conditions.append(field <= op_value)
elif op == "$ne": elif op == "$ne":
conditions.append(field != op_value) conditions.append(field != op_value)
elif op == "$in": elif op == "$in":
conditions.append(field.in_(op_value)) conditions.append(field.in_(op_value))
elif op == "$nin": elif op == "$nin":
conditions.append(field.not_in(op_value)) conditions.append(field.not_in(op_value))
else: else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else: else:
# 直接相等比较 logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
conditions.append(field == value) if conditions:
else: query = query.where(*conditions)
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if filter_bot: if filter_bot:
query = query.where(Messages.user_id != global_config.bot.qq_account) query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command: if filter_command:
query = query.where(not_(Messages.is_command)) query = query.where(not_(Messages.is_command))
if limit > 0: if limit > 0:
# 确保limit是正整数 # 确保limit是正整数
limit = max(1, int(limit)) limit = max(1, int(limit))
if limit_mode == "earliest": if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
sort_terms.append(field.asc())
elif direction == -1: # DESC
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if sort_terms:
query = query.order_by(*sort_terms)
try: try:
results = session.execute(query).scalars().all() results = session.execute(query).scalars().all()
except Exception as e: except Exception as e:
logger.error(f"执行earliest查询失败: {e}") logger.error(f"执行无限制查询失败: {e}")
results = [] results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
sort_terms.append(field.asc())
elif direction == -1: # DESC
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in results] return [_model_to_dict(msg) for msg in results]
except Exception as e: except Exception as e:
@@ -152,50 +152,50 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
session = get_session() with get_db_session() as session:
query = select(func.count(Messages.id)) query = select(func.count(Messages.id))
# 应用过滤器 # 应用过滤器
if message_filter: if message_filter:
conditions = [] conditions = []
for key, value in message_filter.items(): for key, value in message_filter.items():
if hasattr(Messages, key): if hasattr(Messages, key):
field = getattr(Messages, key) field = getattr(Messages, key)
if isinstance(value, dict): if isinstance(value, dict):
# 处理 MongoDB 风格的操作符 # 处理 MongoDB 风格的操作符
for op, op_value in value.items(): for op, op_value in value.items():
if op == "$gt": if op == "$gt":
conditions.append(field > op_value) conditions.append(field > op_value)
elif op == "$lt": elif op == "$lt":
conditions.append(field < op_value) conditions.append(field < op_value)
elif op == "$gte": elif op == "$gte":
conditions.append(field >= op_value) conditions.append(field >= op_value)
elif op == "$lte": elif op == "$lte":
conditions.append(field <= op_value) conditions.append(field <= op_value)
elif op == "$ne": elif op == "$ne":
conditions.append(field != op_value) conditions.append(field != op_value)
elif op == "$in": elif op == "$in":
conditions.append(field.in_(op_value)) conditions.append(field.in_(op_value))
elif op == "$nin": elif op == "$nin":
conditions.append(field.not_in(op_value)) conditions.append(field.not_in(op_value))
else: else:
logger.warning( logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
) )
else:
# 直接相等比较
conditions.append(field == value)
else: else:
# 直接相等比较 logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
conditions.append(field == value) if conditions:
else: query = query.where(*conditions)
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = session.execute(query).scalar() count = session.execute(query).scalar()
return count or 0 return count or 0
except Exception as e: except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message) logger.error(log_message)
return 0 return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
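The $gt/$lt operator chain above is duplicated verbatim between find_messages and count_messages; a table-driven sketch of the same translation (helper name assumed) that both could share:

import operator

MONGO_OPS = {
    "$gt": operator.gt,
    "$lt": operator.lt,
    "$gte": operator.ge,
    "$lte": operator.le,
    "$ne": operator.ne,
    "$in": lambda field, value: field.in_(value),
    "$nin": lambda field, value: field.not_in(value),
}

def build_conditions(model, message_filter: dict) -> list:
    # Translate a MongoDB-style filter dict into SQLAlchemy column expressions.
    conditions = []
    for key, value in (message_filter or {}).items():
        if not hasattr(model, key):
            logger.warning(f"过滤器键 '{key}' 在 {model.__name__} 模型中未找到。将跳过此条件。")
            continue
        field = getattr(model, key)
        if isinstance(value, dict):
            for op, op_value in value.items():
                if op in MONGO_OPS:
                    conditions.append(MONGO_OPS[op](field, op_value))
                else:
                    logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
        else:
            conditions.append(field == value)  # plain equality
    return conditions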

View File

@@ -5,7 +5,7 @@ from PIL import Image
from datetime import datetime from datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import LLMUsage, get_session from src.common.database.sqlalchemy_models import LLMUsage, get_db_session
from src.config.api_ada_configs import ModelInfo from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord from .model_client.base_client import UsageRecord
@@ -156,9 +156,8 @@ class LLMUsageRecorder:
session = None session = None
try: try:
# 使用 SQLAlchemy 会话创建记录 # 使用 SQLAlchemy 会话创建记录
session = get_session() with get_db_session() as session:
usage_record = LLMUsage( usage_record = LLMUsage(
model_name=model_info.model_identifier, model_name=model_info.model_identifier,
model_assign_name=model_info.name, model_assign_name=model_info.name,
model_api_provider=model_info.api_provider, model_api_provider=model_info.api_provider,
@@ -174,8 +173,8 @@ class LLMUsageRecorder:
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
) )
session.add(usage_record) session.add(usage_record)
session.commit() session.commit()
logger.debug( logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, " f"Token使用情况 - 模型: {model_usage.model_name}, "
@@ -184,11 +183,7 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}" f"总计: {model_usage.total_tokens}"
) )
except Exception as e: except Exception as e:
if session:
session.rollback()
logger.error(f"记录token使用情况失败: {str(e)}") logger.error(f"记录token使用情况失败: {str(e)}")
finally:
if session:
session.close()
llm_usage_recorder = LLMUsageRecorder() llm_usage_recorder = LLMUsageRecorder()
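With get_db_session owning rollback and close, the recorder's error handling collapses to a single logging except. A condensed view of the resulting write path (method name and parameters assumed, token and cost columns elided):

from datetime import datetime

def record_usage(self, model_info, model_usage, user_id, request_type, endpoint):
    try:
        with get_db_session() as session:
            session.add(LLMUsage(
                model_name=model_info.model_identifier,
                model_assign_name=model_info.name,
                model_api_provider=model_info.api_provider,
                # ... token counts, cost, user_id, request_type, endpoint as in the hunk ...
                timestamp=datetime.now(),
            ))
            session.commit()  # explicit commit; rollback/close belong to the context manager
    except Exception as e:
        logger.error(f"记录token使用情况失败: {str(e)}")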

View File

@@ -212,6 +212,7 @@ class ScheduleManager:
setattr(new_schedule, 'date', today_str) setattr(new_schedule, 'date', today_str)
setattr(new_schedule, 'schedule_data', json.dumps(schedule_data)) setattr(new_schedule, 'schedule_data', json.dumps(schedule_data))
session.add(new_schedule) session.add(new_schedule)
session.commit()
# 美化输出 # 美化输出
schedule_str = f"已成功生成并保存今天的日程 ({today_str})\n" schedule_str = f"已成功生成并保存今天的日程 ({today_str})\n"

View File

@@ -11,10 +11,9 @@ from sqlalchemy import select
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.sqlalchemy_models import PersonInfo from src.common.database.sqlalchemy_models import PersonInfo
from src.common.database.sqlalchemy_database_api import get_session from src.common.database.sqlalchemy_database_api import get_db_session
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
session = get_session()
""" """
PersonInfoManager 类方法功能摘要: PersonInfoManager 类方法功能摘要:
@@ -52,56 +51,37 @@ person_info_default = {
"attitude": 50, "attitude": 50,
} }
# 统一的会话管理函数
def with_session(func):
"""装饰器为函数自动注入session参数"""
if asyncio.iscoroutinefunction(func):
async def async_wrapper(*args, **kwargs):
return await func(session, *args, **kwargs)
return async_wrapper
else:
def sync_wrapper(*args, **kwargs):
return func(session, *args, **kwargs)
return sync_wrapper
# 全局会话获取函数用于替换所有裸露的session使用
def _get_session():
"""获取数据库会话的统一函数"""
return get_session()
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
"""初始化PersonInfoManager""" """初始化PersonInfoManager"""
from src.common.database.sqlalchemy_models import PersonInfo
self.person_name_list = {} self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try: # try:
db.connect(reuse_if_open=True) # with get_db_session() as session:
# 设置连接池参数仅对SQLite有效 # db.connect(reuse_if_open=True)
if hasattr(db, "execute_sql"): # # 设置连接池参数仅对SQLite有效
# 检查数据库类型只对SQLite执行PRAGMA语句 # if hasattr(db, "execute_sql"):
if global_config.database.database_type == "sqlite": # # 检查数据库类型只对SQLite执行PRAGMA语句
# 设置SQLite优化参数 # if global_config.database.database_type == "sqlite":
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存 # # 设置SQLite优化参数
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中 # db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射 # db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.create_tables([PersonInfo], safe=True) # db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
except Exception as e: # db.create_tables([PersonInfo], safe=True)
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") # except Exception as e:
# logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name # # 初始化时读取所有person_name
try: try:
from src.common.database.sqlalchemy_models import PersonInfo # 在这里获取会话
# 在这里获取会话 with get_db_session() as session:
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where( for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_not(None) PersonInfo.person_name.is_not(None)
)).fetchall(): )).fetchall():
if record.person_name: if record.person_name:
self.person_name_list[record.person_id] = record.person_name self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e: except Exception as e:
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
@@ -121,7 +101,8 @@ class PersonInfoManager:
def _db_check_known_sync(p_id: str): def _db_check_known_sync(p_id: str):
# 在需要时获取会话 # 在需要时获取会话
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
try: try:
return await asyncio.to_thread(_db_check_known_sync, person_id) return await asyncio.to_thread(_db_check_known_sync, person_id)
@@ -133,7 +114,8 @@ class PersonInfoManager:
"""根据用户名获取用户ID""" """根据用户名获取用户ID"""
try: try:
# 在需要时获取会话 # 在需要时获取会话
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar() with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
return record.person_id if record else "" return record.person_id if record else ""
except Exception as e: except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -176,15 +158,16 @@ class PersonInfoManager:
# If it's already a string, assume it's valid JSON or a non-JSON string field # If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict): def _db_create_sync(p_data: dict):
try: with get_db_session() as session:
new_person = PersonInfo(**p_data) try:
session.add(new_person) new_person = PersonInfo(**p_data)
session.commit() session.add(new_person)
return True session.commit()
except Exception as e:
session.rollback() return True
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") except Exception as e:
return False logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data) await asyncio.to_thread(_db_create_sync, final_data)
@@ -223,25 +206,26 @@ class PersonInfoManager:
final_data[key] = json.dumps([], ensure_ascii=False) final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(p_data: dict): def _db_safe_create_sync(p_data: dict):
try: with get_db_session() as session:
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar() try:
if existing: existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") if existing:
return True logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建 # 尝试创建
new_person = PersonInfo(**p_data) new_person = PersonInfo(**p_data)
session.add(new_person) session.add(new_person)
session.commit() session.commit()
return True
except Exception as e: return True
session.rollback() except Exception as e:
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误") logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功 return True # 其他协程已创建,视为成功
else: else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False return False
await asyncio.to_thread(_db_safe_create_sync, final_data) await asyncio.to_thread(_db_safe_create_sync, final_data)
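Each of these _db_*_sync helpers runs on a worker thread via asyncio.to_thread, and after this commit each opens its own short-lived session inside the thread rather than touching the removed thread-local SessionProxy, so no Session object ever crosses threads. The shape of the pattern, condensed from the hunks above:

import asyncio
from sqlalchemy import select

async def is_person_known(person_id: str) -> bool:
    def _check_sync(p_id: str) -> bool:
        # Session is created, used, and closed entirely on the worker thread.
        with get_db_session() as session:
            return session.execute(
                select(PersonInfo).where(PersonInfo.person_id == p_id)
            ).scalar() is not None

    return await asyncio.to_thread(_check_sync, person_id)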
@@ -263,32 +247,33 @@ class PersonInfoManager:
def _db_update_sync(p_id: str, f_name: str, val_to_set): def _db_update_sync(p_id: str, f_name: str, val_to_set):
start_time = time.time() start_time = time.time()
try: with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() try:
query_time = time.time() record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
query_time = time.time()
if record: if record:
setattr(record, f_name, val_to_set) setattr(record, f_name, val_to_set)
session.commit()
save_time = time.time() save_time = time.time()
total_time = save_time - start_time total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志 if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning( logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
) )
session.commit()
return True, False # Found and updated, no creation needed return True, False # Found and updated, no creation needed
else: else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
total_time = time.time() - start_time total_time = time.time() - start_time
if total_time > 0.5: logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}") raise
return False, True # Not found, needs creation
except Exception as e:
session.rollback()
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value) found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
@@ -320,7 +305,8 @@ class PersonInfoManager:
return False return False
def _db_has_field_sync(p_id: str, f_name: str): def _db_has_field_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
return bool(record) return bool(record)
try: try:
@@ -430,7 +416,8 @@ class PersonInfoManager:
else: else:
def _db_check_name_exists_sync(name_to_check): def _db_check_name_exists_sync(name_to_check):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True is_duplicate = True
@@ -471,14 +458,14 @@ class PersonInfoManager:
def _db_delete_sync(p_id: str): def _db_delete_sync(p_id: str):
try: try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() with get_db_session() as session:
if record: record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
session.delete(record) if record:
session.commit() session.delete(record)
session.commit()
return 1 return 1
return 0 return 0
except Exception as e: except Exception as e:
session.rollback()
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}") logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0 return 0
@@ -497,7 +484,8 @@ class PersonInfoManager:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str): def _db_get_value_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record: if record:
val = getattr(record, f_name, None) val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS: if f_name in JSON_SERIALIZED_FIELDS:
@@ -531,27 +519,28 @@ class PersonInfoManager:
def get_value_sync(person_id: str, field_name: str): def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值""" """同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name) default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: with get_db_session() as session:
default_value_for_field = [] if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar(): if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None) val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS: if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str): if isinstance(val, str):
try: try:
return json.loads(val) return json.loads(val)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.") logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return [] return []
elif val is None: return val
return []
return val return val
return val
if field_name in person_info_default: if field_name in person_info_default:
return default_value_for_field return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None return None
@staticmethod @staticmethod
async def get_values(person_id: str, field_names: list) -> dict: async def get_values(person_id: str, field_names: list) -> dict:
@@ -563,7 +552,8 @@ class PersonInfoManager:
result = {} result = {}
def _db_get_record_sync(p_id: str): def _db_get_record_sync(p_id: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id) record = await asyncio.to_thread(_db_get_record_sync, person_id)
@@ -608,10 +598,11 @@ class PersonInfoManager:
def _db_get_specific_sync(f_name: str): def _db_get_specific_sync(f_name: str):
found_results = {} found_results = {}
try: try:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall(): with get_db_session() as session:
value = getattr(record, f_name) for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
if way(value): value = getattr(record, f_name)
found_results[record.person_id] = value if way(value):
found_results[record.person_id] = value
except Exception as e_query: except Exception as e_query:
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True) logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results return found_results
@@ -634,19 +625,20 @@ class PersonInfoManager:
def _db_get_or_create_sync(p_id: str, init_data: dict): def _db_get_or_create_sync(p_id: str, init_data: dict):
"""原子性的获取或创建操作""" """原子性的获取或创建操作"""
# 首先尝试获取现有记录 with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() # 首先尝试获取现有记录
if record: record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
return record, False # 记录存在,未创建 if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建 # 记录不存在,尝试创建
try: try:
new_person = PersonInfo(**init_data) new_person = PersonInfo(**init_data)
session.add(new_person) session.add(new_person)
session.commit() session.commit()
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功 return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
except Exception as e: except Exception as e:
session.rollback()
# 如果创建失败(可能是因为竞态条件),再次尝试获取 # 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
@@ -709,7 +701,8 @@ class PersonInfoManager:
if not found_person_id: if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str): def _db_find_by_name_sync(p_name_to_find: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar() with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name) record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record: if record:

View File

@@ -14,7 +14,6 @@ from src.common.database.sqlalchemy_database_api import (
db_save, db_save,
db_get, db_get,
store_action_info, store_action_info,
get_model_class,
MODEL_MAPPING MODEL_MAPPING
) )
@@ -24,6 +23,5 @@ __all__ = [
'db_save', 'db_save',
'db_get', 'db_get',
'store_action_info', 'store_action_info',
'get_model_class',
'MODEL_MAPPING' 'MODEL_MAPPING'
] ]

View File

@@ -151,7 +151,7 @@ class ScheduleManager:
) )
session.add(new_record) session.add(new_record)
session.commit()
logger.info(f"已更新日程处理状态: {datetime_hour} - {activity} - 成功: {success}") logger.info(f"已更新日程处理状态: {datetime_hour} - {activity} - 成功: {success}")
except Exception as e: except Exception as e: