From d46d689c43047b1386fe93756fc3a2cd491bb28f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com>
Date: Sat, 16 Aug 2025 23:43:45 +0800
Subject: [PATCH] Database refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/chat/emoji_system/emoji_manager.py     | 138 ++---
 src/chat/express/expression_learner.py     | 109 ++--
 src/chat/express/expression_selector.py    |  24 +-
 src/chat/memory_system/Hippocampus.py      | 550 +++++++++---------
 src/chat/memory_system/instant_memory.py   |  33 +-
 src/chat/message_receive/bot.py            |   1 +
 src/chat/message_receive/chat_stream.py    | 138 ++---
 src/chat/message_receive/storage.py        |  13 +-
 src/chat/utils/chat_message_builder.py     |   2 +-
 src/chat/utils/statistic.py                |   2 +-
 src/chat/utils/utils_image.py              |  74 +--
 src/common/database/database.py            |  57 +-
 .../database/sqlalchemy_database_api.py    |  39 +-
 src/common/database/sqlalchemy_init.py     |  35 +-
 src/common/database/sqlalchemy_models.py   | 115 +---
 src/common/message_repository.py           | 234 ++++----
 src/llm_models/utils.py                    |  17 +-
 src/manager/schedule_manager.py            |   1 +
 src/person_info/person_info.py             | 255 ++++----
 src/plugin_system/apis/database_api.py     |   2 -
 src/plugins/built_in/maizone/scheduler.py  |   2 +-
 21 files changed, 834 insertions(+), 1007 deletions(-)

diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index c2a2e7f1e..1afa6ad86 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -14,7 +14,7 @@ from PIL import Image
 from rich.traceback import install
 from sqlalchemy import select
 from src.common.database.database import db
-from src.common.database.sqlalchemy_database_api import get_session
+from src.common.database.sqlalchemy_database_api import get_db_session
 from src.common.database.sqlalchemy_models import Emoji, Images
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
@@ -30,8 +30,6 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji")  # 表情包存储目录
 EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed")  # 已注册的表情包注册目录
 MAX_EMOJI_FOR_PROMPT = 20  # 最大允许的表情包描述数量于图片替换的 prompt 中
 
-session = get_session()
-
 """
 还没经过测试,有些地方数据库和内存数据同步可能不完全
 
@@ -152,28 +150,29 @@ class MaiEmoji:
         # --- 数据库操作 ---
         try:
             # 准备数据库记录 for emoji collection
-            emotion_str = ",".join(self.emotion) if self.emotion else ""
+            with get_db_session() as session:
+                emotion_str = ",".join(self.emotion) if self.emotion else ""
 
-            emoji = Emoji(
-                emoji_hash=self.hash,
-                full_path=self.full_path,
-                format=self.format,
-                description=self.description,
-                emotion=emotion_str,  # Store as comma-separated string
-                query_count=0,  # Default value
-                is_registered=True,
-                is_banned=False,  # Default value
-                record_time=self.register_time,  # Use MaiEmoji's register_time for DB record_time
-                register_time=self.register_time,
-                usage_count=self.usage_count,
-                last_used_time=self.last_used_time,
-            )
-            session.add(emoji)
-            session.commit()
+                emoji = Emoji(
+                    emoji_hash=self.hash,
+                    full_path=self.full_path,
+                    format=self.format,
+                    description=self.description,
+                    emotion=emotion_str,  # Store as comma-separated string
+                    query_count=0,  # Default value
+                    is_registered=True,
+                    is_banned=False,  # Default value
+                    record_time=self.register_time,  # Use MaiEmoji's register_time for DB record_time
+                    register_time=self.register_time,
+                    usage_count=self.usage_count,
+                    last_used_time=self.last_used_time,
+                )
+                
session.add(emoji) + session.commit() + + logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") - logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") - - return True + return True except Exception as db_error: logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") @@ -205,14 +204,15 @@ class MaiEmoji: # 2. 删除数据库记录 try: - will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none() - if will_delete_emoji is None: - logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted - else: - session.delete(will_delete_emoji) - session.commit() - result = 1 # Successfully deleted one record + with get_db_session() as session: + will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none() + if will_delete_emoji is None: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 # Indicate no DB record was deleted + else: + session.delete(will_delete_emoji) + result = 1 # Successfully deleted one record + session.commit() except Exception as e: logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") result = 0 @@ -403,35 +403,36 @@ class EmojiManager: def initialize(self) -> None: """初始化数据库连接和表情目录""" - try: - db.connect(reuse_if_open=True) - if db.is_closed(): - raise RuntimeError("数据库连接失败") - _ensure_emoji_dir() - self._initialized = True # 标记为已初始化 - logger.info("EmojiManager初始化成功") - except Exception as e: - logger.error(f"EmojiManager初始化失败: {e}") - self._initialized = False - raise + # try: + # db.connect(reuse_if_open=True) + # if db.is_closed(): + # raise RuntimeError("数据库连接失败") + # _ensure_emoji_dir() + # self._initialized = True # 标记为已初始化 + # logger.info("EmojiManager初始化成功") + # except Exception as e: + # logger.error(f"EmojiManager初始化失败: {e}") + # self._initialized = False + # raise - def _ensure_db(self) -> None: - """确保数据库已初始化""" - if not self._initialized: - self.initialize() - if not self._initialized: - raise RuntimeError("EmojiManager not initialized") + # def _ensure_db(self) -> None: + # """确保数据库已初始化""" + # if not self._initialized: + # self.initialize() + # if not self._initialized: + # raise RuntimeError("EmojiManager not initialized") def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() - if emoji_update is None: - logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") - else: - emoji_update.usage_count += 1 + with get_db_session() as session: + emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() + if emoji_update is None: + logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") + else: + emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() # Update last used time - session.commit() # Persist changes to DB + session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -659,11 +660,12 @@ class EmojiManager: async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: - self._ensure_db() - logger.debug("[数据库] 开始加载所有表情包记录 ...") + with get_db_session() as session: + self._ensure_db() + logger.debug("[数据库] 开始加载所有表情包记录 ...") - emoji_instances = session.execute(select(Emoji)).scalars().all() - emoji_objects, load_errors = _to_emoji_objects(emoji_instances) + emoji_instances = session.execute(select(Emoji)).scalars().all() + emoji_objects, load_errors = 
_to_emoji_objects(emoji_instances) # 更新内存中的列表和数量 self.emoji_objects = emoji_objects @@ -672,6 +674,7 @@ class EmojiManager: logger.info(f"[数据库] 加载完成: 共加载 {self.emoji_num} 个表情包记录。") if load_errors > 0: logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。") + except Exception as e: logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}") @@ -688,7 +691,8 @@ class EmojiManager: list[MaiEmoji]: 表情包对象列表 """ try: - self._ensure_db() + with get_db_session() as session: + self._ensure_db() if emoji_hash: query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() @@ -703,7 +707,6 @@ class EmojiManager: if load_errors > 0: logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") - return emoji_objects except Exception as e: @@ -744,13 +747,15 @@ class EmojiManager: # 如果内存中没有,从数据库查找 self._ensure_db() try: - emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() + with get_db_session() as session: + emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() if emoji_record and emoji_record.description: logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") return emoji_record.description except Exception as e: logger.error(f"从数据库查询表情包描述时出错: {e}") + return None except Exception as e: @@ -905,13 +910,14 @@ class EmojiManager: # 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成) existing_description = None try: + with get_db_session() as session: # from src.common.database.database_model_compat import Images - stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji")) - existing_image = session.execute(stmt).scalar_one_or_none() - if existing_image and existing_image.description: - existing_description = existing_image.description - logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") + stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + existing_image = session.execute(stmt).scalar_one_or_none() + if existing_image and existing_image.description: + existing_description = existing_image.description + logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") except Exception as e: logger.debug(f"查询已有描述时出错: {e}") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 3f6533547..d040e5ba6 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from sqlalchemy import select from src.common.database.sqlalchemy_models import Expression from src.llm_models.utils_model import LLMRequest @@ -22,7 +22,6 @@ DECAY_DAYS = 30 # 30天衰减到0.01 DECAY_MIN = 0.01 # 最小衰减值 logger = get_logger("expressor") -session = get_session() def format_create_date(timestamp: float) -> str: """ @@ -204,7 +203,8 @@ class ExpressionLearner: learnt_grammar_expressions = [] # 直接从数据库查询 - style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))) + with get_db_session() as session: + style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))) for expr in style_query.scalars(): # 确保create_date存在,如果不存在则使用last_active_time create_date = expr.create_date if 
expr.create_date is not None else expr.last_active_time @@ -247,8 +247,9 @@ class ExpressionLearner: 对数据库中的所有表达方式应用全局衰减 """ try: - # 获取所有表达方式 - all_expressions = session.execute(select(Expression)).scalars() + with get_db_session() as session: + # 获取所有表达方式 + all_expressions = session.execute(select(Expression)).scalars() updated_count = 0 deleted_count = 0 @@ -265,19 +266,19 @@ class ExpressionLearner: if new_count <= 0.01: # 如果count太小,删除这个表达方式 session.delete(expr) + session.commit() deleted_count += 1 else: # 更新count expr.count = new_count updated_count += 1 - session.commit() + if updated_count > 0 or deleted_count > 0: logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") except Exception as e: - session.rollback() logger.error(f"数据库全局衰减失败: {e}") def calculate_decay_factor(self, time_diff_days: float) -> float: @@ -355,43 +356,46 @@ class ExpressionLearner: for chat_id, expr_list in chat_dict.items(): for new_expr in expr_list: # 查找是否已存在相似表达方式 - query = session.execute(select(Expression).where( + with get_db_session() as session: + query = session.execute(select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == type) & (Expression.situation == new_expr["situation"]) & (Expression.style == new_expr["style"]) )).scalar() - if query: - expr_obj = query + if query: + expr_obj = query # 50%概率替换内容 - if random.random() < 0.5: - expr_obj.situation = new_expr["situation"] - expr_obj.style = new_expr["style"] - expr_obj.count = expr_obj.count + 1 - expr_obj.last_active_time = current_time - else: - new_expression = Expression( - situation=new_expr["situation"], - style=new_expr["style"], - count=1, - last_active_time=current_time, - chat_id=chat_id, - type=type, - create_date=current_time, # 手动设置创建日期 - ) - session.add(new_expression) - # 限制最大数量 - exprs = list( - session.execute(select(Expression) - .where((Expression.chat_id == chat_id) & (Expression.type == type)) - .order_by(Expression.count.asc())).scalars() - ) - if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除count最小的多余表达方式 - for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: - session.delete(expr) - session.commit() - return learnt_expressions + if random.random() < 0.5: + expr_obj.situation = new_expr["situation"] + expr_obj.style = new_expr["style"] + expr_obj.count = expr_obj.count + 1 + expr_obj.last_active_time = current_time + else: + new_expression = Expression( + situation=new_expr["situation"], + style=new_expr["style"], + count=1, + last_active_time=current_time, + chat_id=chat_id, + type=type, + create_date=current_time, # 手动设置创建日期 + ) + session.add(new_expression) + session.commit() + # 限制最大数量 + exprs = list( + session.execute(select(Expression) + .where((Expression.chat_id == chat_id) & (Expression.type == type)) + .order_by(Expression.count.asc())).scalars() + ) + if len(exprs) > MAX_EXPRESSION_COUNT: + # 删除count最小的多余表达方式 + for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: + session.delete(expr) + session.commit() + + return learnt_expressions async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: """从指定聊天流学习表达方式 @@ -574,8 +578,8 @@ class ExpressionLearnerManager: continue # 查重:同chat_id+type+situation+style - - query = session.execute(select(Expression).where( + with get_db_session() as session: + query = session.execute(select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == type_str) & (Expression.situation == situation) @@ -596,6 +600,8 @@ class ExpressionLearnerManager: 
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间 ) session.add(new_expression) + session.commit() + migrated_count += 1 logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式") except json.JSONDecodeError as e: @@ -627,21 +633,22 @@ class ExpressionLearnerManager: 使用last_active_time作为create_date的默认值 """ try: - # 查找所有create_date为空的表达方式 - old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars() - updated_count = 0 + with get_db_session() as session: + # 查找所有create_date为空的表达方式 + old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars() + updated_count = 0 - for expr in old_expressions: - # 使用last_active_time作为create_date - expr.create_date = expr.last_active_time - updated_count += 1 + for expr in old_expressions: + # 使用last_active_time作为create_date + expr.create_date = expr.last_active_time + updated_count += 1 - session.commit() + - if updated_count > 0: - logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") + if updated_count > 0: + logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") + session.commit() except Exception as e: - session.rollback() logger.error(f"迁移老数据创建日期失败: {e}") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 2338aa426..61b5d48b9 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -12,8 +12,7 @@ from src.common.logger import get_logger from sqlalchemy import select from src.common.database.sqlalchemy_models import Expression from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.common.database.sqlalchemy_database_api import get_session -session = get_session() +from src.common.database.sqlalchemy_database_api import get_db_session logger = get_logger("expression_selector") @@ -132,14 +131,14 @@ class ExpressionSelector: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - + with get_db_session() as session: # 优化:一次性查询所有相关chat_id的表达方式 - style_query = session.execute(select(Expression).where( - (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") - )) - grammar_query = session.execute(select(Expression).where( - (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar") - )) + style_query = session.execute(select(Expression).where( + (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") + )) + grammar_query = session.execute(select(Expression).where( + (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar") + )) style_exprs = [ { @@ -180,6 +179,7 @@ class ExpressionSelector: selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num) else: selected_grammar = [] + return selected_style, selected_grammar def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): @@ -199,7 +199,8 @@ class ExpressionSelector: if key not in updates_by_key: updates_by_key[key] = expr for chat_id, expr_type, situation, style in updates_by_key: - query = session.execute(select(Expression).where( + with get_db_session() as session: + query = session.execute(select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == expr_type) & (Expression.situation == situation) @@ -211,10 +212,11 @@ class ExpressionSelector: new_count = min(current_count + increment, 5.0) expr_obj.count = new_count expr_obj.last_active_time = time.time() - 
session.commit() + logger.debug( f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" ) + session.commit() async def select_suitable_expressions_llm( self, diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 719afd005..e084cbe57 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -19,7 +19,7 @@ from src.config.config import global_config, model_config from sqlalchemy import select,insert,update,delete from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入 from src.common.logger import get_logger -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp, @@ -30,7 +30,6 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable install(extra_lines=3) -session = get_session() def calculate_information_content(text): """计算文本的信息量(熵)""" @@ -862,13 +861,13 @@ class EntorhinalCortex: for message in messages: # 确保在更新前获取最新的 memorized_times current_memorized_times = message.get("memorized_times", 0) - # 使用 SQLAlchemy 2.0 更新记录 - session.execute( + with get_db_session() as session: + session.execute( update(Messages) .where(Messages.message_id == message["message_id"]) .values(memorized_times=current_memorized_times + 1) ) - session.commit() + session.commit() return messages # 直接返回原始的消息列表 target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 @@ -882,253 +881,260 @@ class EntorhinalCortex: current_time = datetime.datetime.now().timestamp() # 获取数据库中所有节点和内存中所有节点 - db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()} - memory_nodes = list(self.memory_graph.G.nodes(data=True)) + with get_db_session() as session: + db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()} + memory_nodes = list(self.memory_graph.G.nodes(data=True)) - # 批量准备节点数据 - nodes_to_create = [] - nodes_to_update = [] - nodes_to_delete = set() + # 批量准备节点数据 + nodes_to_create = [] + nodes_to_update = [] + nodes_to_delete = set() - # 处理节点 - for concept, data in memory_nodes: - if not concept or not isinstance(concept, str): - self.memory_graph.G.remove_node(concept) - continue - - memory_items = data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if not memory_items: - self.memory_graph.G.remove_node(concept) - continue - - # 计算内存中节点的特征值 - memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items) - created_time = data.get("created_time", current_time) - last_modified = data.get("last_modified", current_time) - - # 将memory_items转换为JSON字符串 - try: - memory_items = [str(item) for item in memory_items] - memory_items_json = json.dumps(memory_items, ensure_ascii=False) - if not memory_items_json: + # 处理节点 + for concept, data in memory_nodes: + if not concept or not isinstance(concept, str): + self.memory_graph.G.remove_node(concept) continue - except Exception: - self.memory_graph.G.remove_node(concept) - continue - if concept not in db_nodes: - nodes_to_create.append( - { - "concept": concept, - "memory_items": memory_items_json, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - } - ) - else: - db_node = db_nodes[concept] - if 
db_node.hash != memory_hash: - nodes_to_update.append( + memory_items = data.get("memory_items", []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + if not memory_items: + self.memory_graph.G.remove_node(concept) + continue + + # 计算内存中节点的特征值 + memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items) + created_time = data.get("created_time", current_time) + last_modified = data.get("last_modified", current_time) + + # 将memory_items转换为JSON字符串 + try: + memory_items = [str(item) for item in memory_items] + memory_items_json = json.dumps(memory_items, ensure_ascii=False) + if not memory_items_json: + continue + except Exception: + self.memory_graph.G.remove_node(concept) + continue + + if concept not in db_nodes: + nodes_to_create.append( { "concept": concept, "memory_items": memory_items_json, "hash": memory_hash, + "created_time": created_time, + "last_modified": last_modified, + } + ) + else: + db_node = db_nodes[concept] + if db_node.hash != memory_hash: + nodes_to_update.append( + { + "concept": concept, + "memory_items": memory_items_json, + "hash": memory_hash, + "last_modified": last_modified, + } + ) + + # 计算需要删除的节点 + memory_concepts = {concept for concept, _ in memory_nodes} + nodes_to_delete = set(db_nodes.keys()) - memory_concepts + + # 批量处理节点 + if nodes_to_create: + batch_size = 100 + for i in range(0, len(nodes_to_create), batch_size): + batch = nodes_to_create[i : i + batch_size] + session.execute(insert(GraphNodes), batch) + + + if nodes_to_update: + batch_size = 100 + for i in range(0, len(nodes_to_update), batch_size): + batch = nodes_to_update[i : i + batch_size] + for node_data in batch: + session.execute( + update(GraphNodes) + .where(GraphNodes.concept == node_data["concept"]) + .values(**{k: v for k, v in node_data.items() if k != "concept"}) + ) + + + if nodes_to_delete: + session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) + + + # 处理边的信息 + db_edges = list(session.execute(select(GraphEdges)).scalars()) + memory_edges = list(self.memory_graph.G.edges(data=True)) + + # 创建边的哈希值字典 + db_edge_dict = {} + for edge in db_edges: + edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) + db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} + + # 批量准备边数据 + edges_to_create = [] + edges_to_update = [] + + # 处理边 + for source, target, data in memory_edges: + edge_hash = self.hippocampus.calculate_edge_hash(source, target) + edge_key = (source, target) + strength = data.get("strength", 1) + created_time = data.get("created_time", current_time) + last_modified = data.get("last_modified", current_time) + + if edge_key not in db_edge_dict: + edges_to_create.append( + { + "source": source, + "target": target, + "strength": strength, + "hash": edge_hash, + "created_time": created_time, + "last_modified": last_modified, + } + ) + elif db_edge_dict[edge_key]["hash"] != edge_hash: + edges_to_update.append( + { + "source": source, + "target": target, + "strength": strength, + "hash": edge_hash, "last_modified": last_modified, } ) - # 计算需要删除的节点 - memory_concepts = {concept for concept, _ in memory_nodes} - nodes_to_delete = set(db_nodes.keys()) - memory_concepts + # 计算需要删除的边 + memory_edge_keys = {(source, target) for source, target, _ in memory_edges} + edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys - # 批量处理节点 - if nodes_to_create: - batch_size = 100 - for i in range(0, len(nodes_to_create), batch_size): - batch = 
nodes_to_create[i : i + batch_size] - session.execute(insert(GraphNodes), batch) - session.commit() + # 批量处理边 + if edges_to_create: + batch_size = 100 + for i in range(0, len(edges_to_create), batch_size): + batch = edges_to_create[i : i + batch_size] + session.execute(insert(GraphEdges), batch) + - if nodes_to_update: - batch_size = 100 - for i in range(0, len(nodes_to_update), batch_size): - batch = nodes_to_update[i : i + batch_size] - for node_data in batch: + if edges_to_update: + batch_size = 100 + for i in range(0, len(edges_to_update), batch_size): + batch = edges_to_update[i : i + batch_size] + for edge_data in batch: + session.execute( + update(GraphEdges) + .where( + (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) + ) + .values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}) + ) + + + if edges_to_delete: + for source, target in edges_to_delete: session.execute( - update(GraphNodes) - .where(GraphNodes.concept == node_data["concept"]) - .values(**{k: v for k, v in node_data.items() if k != "concept"}) + delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target)) ) - session.commit() - if nodes_to_delete: - session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) + # 提交事务 session.commit() - # 处理边的信息 - db_edges = list(session.execute(select(GraphEdges)).scalars()) - memory_edges = list(self.memory_graph.G.edges(data=True)) - - # 创建边的哈希值字典 - db_edge_dict = {} - for edge in db_edges: - edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) - db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} - - # 批量准备边数据 - edges_to_create = [] - edges_to_update = [] - - # 处理边 - for source, target, data in memory_edges: - edge_hash = self.hippocampus.calculate_edge_hash(source, target) - edge_key = (source, target) - strength = data.get("strength", 1) - created_time = data.get("created_time", current_time) - last_modified = data.get("last_modified", current_time) - - if edge_key not in db_edge_dict: - edges_to_create.append( - { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "created_time": created_time, - "last_modified": last_modified, - } - ) - elif db_edge_dict[edge_key]["hash"] != edge_hash: - edges_to_update.append( - { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "last_modified": last_modified, - } - ) - - # 计算需要删除的边 - memory_edge_keys = {(source, target) for source, target, _ in memory_edges} - edges_to_delete = set(db_edge_dict.keys()) - memory_edge_keys - - # 批量处理边 - if edges_to_create: - batch_size = 100 - for i in range(0, len(edges_to_create), batch_size): - batch = edges_to_create[i : i + batch_size] - session.execute(insert(GraphEdges), batch) - session.commit() - - if edges_to_update: - batch_size = 100 - for i in range(0, len(edges_to_update), batch_size): - batch = edges_to_update[i : i + batch_size] - for edge_data in batch: - session.execute( - update(GraphEdges) - .where( - (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) - ) - .values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}) - ) - session.commit() - - if edges_to_delete: - for source, target in edges_to_delete: - session.execute( - delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target)) - ) - session.commit() + end_time = time.time() logger.info(f"[同步] 总耗时: {end_time - 
start_time:.2f}秒") logger.info(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边") - + async def resync_memory_to_db(self): """清空数据库并重新同步所有记忆数据""" start_time = time.time() logger.info("[数据库] 开始重新同步所有记忆数据...") # 清空数据库 - clear_start = time.time() - session.execute(delete(GraphNodes)) - session.execute(delete(GraphEdges)) - session.commit() - clear_end = time.time() - logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") + with get_db_session() as session: + clear_start = time.time() + session.execute(delete(GraphNodes)) + session.execute(delete(GraphEdges)) + + clear_end = time.time() + logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") - # 获取所有节点和边 - memory_nodes = list(self.memory_graph.G.nodes(data=True)) - memory_edges = list(self.memory_graph.G.edges(data=True)) - current_time = datetime.datetime.now().timestamp() + # 获取所有节点和边 + memory_nodes = list(self.memory_graph.G.nodes(data=True)) + memory_edges = list(self.memory_graph.G.edges(data=True)) + current_time = datetime.datetime.now().timestamp() - # 批量准备节点数据 - nodes_data = [] - for concept, data in memory_nodes: - memory_items = data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] + # 批量准备节点数据 + nodes_data = [] + for concept, data in memory_nodes: + memory_items = data.get("memory_items", []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] - try: - memory_items = [str(item) for item in memory_items] - if memory_items_json := json.dumps(memory_items, ensure_ascii=False): - nodes_data.append( + try: + memory_items = [str(item) for item in memory_items] + if memory_items_json := json.dumps(memory_items, ensure_ascii=False): + nodes_data.append( + { + "concept": concept, + "memory_items": memory_items_json, + "hash": self.hippocampus.calculate_node_hash(concept, memory_items), + "created_time": data.get("created_time", current_time), + "last_modified": data.get("last_modified", current_time), + } + ) + + except Exception as e: + logger.error(f"准备节点 {concept} 数据时发生错误: {e}") + continue + + # 批量准备边数据 + edges_data = [] + for source, target, data in memory_edges: + try: + edges_data.append( { - "concept": concept, - "memory_items": memory_items_json, - "hash": self.hippocampus.calculate_node_hash(concept, memory_items), + "source": source, + "target": target, + "strength": data.get("strength", 1), + "hash": self.hippocampus.calculate_edge_hash(source, target), "created_time": data.get("created_time", current_time), "last_modified": data.get("last_modified", current_time), } ) + except Exception as e: + logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}") + continue - except Exception as e: - logger.error(f"准备节点 {concept} 数据时发生错误: {e}") - continue + # 批量写入节点 + node_start = time.time() + if nodes_data: + batch_size = 500 # 增加批量大小 + for i in range(0, len(nodes_data), batch_size): + batch = nodes_data[i : i + batch_size] + session.execute(insert(GraphNodes), batch) + + node_end = time.time() + logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") - # 批量准备边数据 - edges_data = [] - for source, target, data in memory_edges: - try: - edges_data.append( - { - "source": source, - "target": target, - "strength": data.get("strength", 1), - "hash": self.hippocampus.calculate_edge_hash(source, target), - "created_time": data.get("created_time", current_time), - "last_modified": data.get("last_modified", current_time), - } - ) - except Exception as e: - logger.error(f"准备边 {source}-{target} 
数据时发生错误: {e}") - continue - - # 批量写入节点 - node_start = time.time() - if nodes_data: - batch_size = 500 # 增加批量大小 - for i in range(0, len(nodes_data), batch_size): - batch = nodes_data[i : i + batch_size] - session.execute(insert(GraphNodes), batch) - session.commit() - node_end = time.time() - logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") - - # 批量写入边 - edge_start = time.time() - if edges_data: - batch_size = 500 # 增加批量大小 - for i in range(0, len(edges_data), batch_size): - batch = edges_data[i : i + batch_size] - session.execute(insert(GraphEdges), batch) - session.commit() + # 批量写入边 + edge_start = time.time() + if edges_data: + batch_size = 500 # 增加批量大小 + for i in range(0, len(edges_data), batch_size): + batch = edges_data[i : i + batch_size] + session.execute(insert(GraphEdges), batch) + session.commit() + edge_end = time.time() logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒") @@ -1145,77 +1151,79 @@ class EntorhinalCortex: self.memory_graph.G.clear() # 从数据库加载所有节点 - nodes = list(session.execute(select(GraphNodes)).scalars()) - for node in nodes: - concept = node.concept - try: - memory_items = json.loads(node.memory_items) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] + with get_db_session() as session: + nodes = list(session.execute(select(GraphNodes)).scalars()) + for node in nodes: + concept = node.concept + try: + memory_items = json.loads(node.memory_items) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 检查时间字段是否存在 + if not node.created_time or not node.last_modified: + need_update = True + # 更新数据库中的节点 + update_data = {} + if not node.created_time: + update_data["created_time"] = current_time + if not node.last_modified: + update_data["last_modified"] = current_time + + session.execute( + update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data) + ) + + + # 获取时间信息(如果不存在则使用当前时间) + created_time = node.created_time or current_time + last_modified = node.last_modified or current_time + + # 添加节点到图中 + self.memory_graph.G.add_node( + concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified + ) + except Exception as e: + logger.error(f"加载节点 {concept} 时发生错误: {e}") + continue + + # 从数据库加载所有边 + edges = list(session.execute(select(GraphEdges)).scalars()) + for edge in edges: + source = edge.source + target = edge.target + strength = edge.strength # 检查时间字段是否存在 - if not node.created_time or not node.last_modified: + if not edge.created_time or not edge.last_modified: need_update = True - # 更新数据库中的节点 + # 更新数据库中的边 update_data = {} - if not node.created_time: + if not edge.created_time: update_data["created_time"] = current_time - if not node.last_modified: + if not edge.last_modified: update_data["last_modified"] = current_time session.execute( - update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data) + update(GraphEdges) + .where((GraphEdges.source == source) & (GraphEdges.target == target)) + .values(**update_data) ) - session.commit() + # 获取时间信息(如果不存在则使用当前时间) - created_time = node.created_time or current_time - last_modified = node.last_modified or current_time + created_time = edge.created_time or current_time + last_modified = edge.last_modified or current_time - # 添加节点到图中 - self.memory_graph.G.add_node( - concept, memory_items=memory_items, created_time=created_time, last_modified=last_modified - ) - except Exception as e: - logger.error(f"加载节点 {concept} 
时发生错误: {e}") - continue + # 只有当源节点和目标节点都存在时才添加边 + if source in self.memory_graph.G and target in self.memory_graph.G: + self.memory_graph.G.add_edge( + source, target, strength=strength, created_time=created_time, last_modified=last_modified + ) + session.commit() - # 从数据库加载所有边 - edges = list(session.execute(select(GraphEdges)).scalars()) - for edge in edges: - source = edge.source - target = edge.target - strength = edge.strength - - # 检查时间字段是否存在 - if not edge.created_time or not edge.last_modified: - need_update = True - # 更新数据库中的边 - update_data = {} - if not edge.created_time: - update_data["created_time"] = current_time - if not edge.last_modified: - update_data["last_modified"] = current_time - - session.execute( - update(GraphEdges) - .where((GraphEdges.source == source) & (GraphEdges.target == target)) - .values(**update_data) - ) - session.commit() - - # 获取时间信息(如果不存在则使用当前时间) - created_time = edge.created_time or current_time - last_modified = edge.last_modified or current_time - - # 只有当源节点和目标节点都存在时才添加边 - if source in self.memory_graph.G and target in self.memory_graph.G: - self.memory_graph.G.add_edge( - source, target, strength=strength, created_time=created_time, last_modified=last_modified - ) - - if need_update: - logger.info("[数据库] 已为缺失的时间字段进行补充") + if need_update: + logger.info("[数据库] 已为缺失的时间字段进行补充") # 负责整合,遗忘,合并记忆 diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index d7fdd32e3..b55784c2e 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -11,12 +11,11 @@ from datetime import datetime, timedelta from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入 -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from src.config.config import model_config from sqlalchemy import select logger = get_logger(__name__) -session = get_session() class MemoryItem: def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]): @@ -113,7 +112,8 @@ class InstantMemory: logger.info(f"不需要记忆:{text}") async def store_memory(self, memory_item: MemoryItem): - memory = Memory( + with get_db_session() as session: + memory = Memory( memory_id=memory_item.memory_id, chat_id=memory_item.chat_id, memory_text=memory_item.memory_text, @@ -121,8 +121,8 @@ class InstantMemory: create_time=memory_item.create_time, last_view_time=memory_item.last_view_time, ) - session.add(memory) - session.commit() + session.add(memory) + session.commit() async def get_memory(self, target: str): from json_repair import repair_json @@ -165,17 +165,18 @@ class InstantMemory: logger.info(f"start_time: {start_time}, end_time: {end_time}") # 检索包含关键词的记忆 memories_set = set() - if start_time and end_time: - start_ts = start_time.timestamp() - end_ts = end_time.timestamp() - query = session.execute(select(Memory).where( - (Memory.chat_id == self.chat_id) - & (Memory.create_time >= start_ts) - & (Memory.create_time < end_ts) - )).scalars() - else: - query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars() - + with get_db_session() as session: + if start_time and end_time: + start_ts = start_time.timestamp() + end_ts = end_time.timestamp() + + query = session.execute(select(Memory).where( + (Memory.chat_id == self.chat_id) + & (Memory.create_time >= start_ts) + & (Memory.create_time < 
end_ts) + )).scalars() + else: + query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars() for mem in query: # 对每条记忆 mem_keywords = mem.keywords or "" diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index dcc616bfb..47655dd09 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -274,6 +274,7 @@ class ChatBot: # logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}") # return + # 过滤检查 if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore message.raw_message, # type: ignore diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 04e0299e2..5ad7f2654 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -12,7 +12,7 @@ from sqlalchemy import select, text from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.dialects.mysql import insert as mysql_insert from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from src.config.config import global_config # 新增导入 # 避免循环导入,使用TYPE_CHECKING进行类型提示 if TYPE_CHECKING: @@ -23,7 +23,6 @@ install(extra_lines=3) logger = get_logger("chat_stream") -session = get_session() class ChatMessageContext: """聊天消息上下文,存储消息的上下文信息""" @@ -132,13 +131,14 @@ class ChatManager: if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message - try: - db.connect(reuse_if_open=True) - # 确保 ChatStreams 表存在 - session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) - session.commit() - except Exception as e: - logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + # try: + # with get_db_session() as session: + # db.connect(reuse_if_open=True) + # # 确保 ChatStreams 表存在 + # session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)")) + # session.commit() + # except Exception as e: + # logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") self._initialized = True # 在事件循环中启动初始化 @@ -236,7 +236,8 @@ class ChatManager: # 检查数据库中是否存在 def _db_find_stream_sync(s_id: str): - return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar() + with get_db_session() as session: + return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar() model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) @@ -331,44 +332,45 @@ class ChatManager: stream_data_dict = stream.to_dict() def _db_save_stream_sync(s_data_dict: dict): - user_info_d = s_data_dict.get("user_info") - group_info_d = s_data_dict.get("group_info") + with get_db_session() as session: + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") - fields_to_save = { - "platform": s_data_dict["platform"], - "create_time": 
s_data_dict["create_time"], - "last_active_time": s_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d["platform"] if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", - } + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } - # 根据数据库类型选择插入语句 - if global_config.database.database_type == "sqlite": - stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - stmt = stmt.on_conflict_do_update( - index_elements=['stream_id'], - set_=fields_to_save - ) - elif global_config.database.database_type == "mysql": - stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - stmt = stmt.on_duplicate_key_update( - **{key: value for key, value in fields_to_save.items() if key != "stream_id"} - ) - else: - # 默认使用通用插入,尝试SQLite语法 - stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - stmt = stmt.on_conflict_do_update( - index_elements=['stream_id'], - set_=fields_to_save - ) + # 根据数据库类型选择插入语句 + if global_config.database.database_type == "sqlite": + stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) + stmt = stmt.on_conflict_do_update( + index_elements=['stream_id'], + set_=fields_to_save + ) + elif global_config.database.database_type == "mysql": + stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) + stmt = stmt.on_duplicate_key_update( + **{key: value for key, value in fields_to_save.items() if key != "stream_id"} + ) + else: + # 默认使用通用插入,尝试SQLite语法 + stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) + stmt = stmt.on_conflict_do_update( + index_elements=['stream_id'], + set_=fields_to_save + ) - session.execute(stmt) - session.commit() + session.execute(stmt) + session.commit() try: await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) @@ -387,30 +389,32 @@ class ChatManager: def _db_load_all_streams_sync(): loaded_streams_data = [] - for model_instance in session.execute(select(ChatStreams)).scalars(): - user_info_data = { - "platform": model_instance.user_platform, - "user_id": model_instance.user_id, - "user_nickname": model_instance.user_nickname, - "user_cardname": model_instance.user_cardname or "", - } - group_info_data = None - if model_instance.group_id: - group_info_data = { - "platform": model_instance.group_platform, - "group_id": model_instance.group_id, - "group_name": model_instance.group_name, + with get_db_session() as session: + for model_instance in 
session.execute(select(ChatStreams)).scalars(): + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } - data_for_from_dict = { - "stream_id": model_instance.stream_id, - "platform": model_instance.platform, - "user_info": user_info_data, - "group_info": group_info_data, - "create_time": model_instance.create_time, - "last_active_time": model_instance.last_active_time, - } - loaded_streams_data.append(data_for_from_dict) + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + loaded_streams_data.append(data_for_from_dict) + session.commit() return loaded_streams_data try: diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 9df3997ba..e2d7808ea 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -7,7 +7,7 @@ from src.common.database.sqlalchemy_models import Messages, Images from src.common.logger import get_logger from .chat_stream import ChatStream from .message import MessageSending, MessageRecv -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from sqlalchemy import select, update, desc logger = get_logger("message_storage") @@ -70,7 +70,6 @@ class MessageStorage: priority_info_json = json.dumps(priority_info) if priority_info else None # 获取数据库会话 - session = get_session() new_message = Messages( message_id=msg_id, @@ -104,8 +103,10 @@ class MessageStorage: is_notify=is_notify, is_command=is_command, ) - session.add(new_message) - session.commit() + with get_db_session() as session: + session.add(new_message) + session.commit() + except Exception: logger.exception("存储消息失败") logger.error(f"消息:{message}") @@ -155,7 +156,8 @@ class MessageStorage: session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) - # session.commit() 会在上下文管理器中自动调用 + session.commit() + # 会在上下文管理器中自动调用 logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: logger.warning(f"未找到匹配的消息记录: {mmc_message_id}") @@ -184,6 +186,7 @@ class MessageStorage: image_record = session.execute( select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) ).scalar() + session.commit() return f"[picid:{image_record.image_id}]" if image_record else match.group(0) except Exception: return match.group(0) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 5d57cc793..4dde7de2b 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -10,7 +10,7 @@ from src.common.message_repository import find_messages, count_messages from src.common.database.sqlalchemy_models import ActionRecords, Images from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids -from src.common.database.sqlalchemy_database_api 
import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from sqlalchemy import select, and_ install(extra_lines=3) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 667582a7b..b2ddcd039 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -420,7 +420,7 @@ class StatisticOutputTask(AsyncTask): stats[period_key][COST_BY_MODULE][module_name] += cost # 收集time_cost数据 - time_cost = record.time_cost or 0.0 + time_cost = record.get('time_cost') or 0.0 if time_cost > 0: # 只记录有效的time_cost stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost) stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost) diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 3ba4084da..92d8d3bd6 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -41,12 +41,12 @@ class ImageManager: self._initialized = True self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") - try: - db.connect(reuse_if_open=True) - # 使用SQLAlchemy创建表已在初始化时完成 - logger.debug("使用SQLAlchemy进行表管理") - except Exception as e: - logger.error(f"数据库连接失败: {e}") + # try: + # db.connect(reuse_if_open=True) + # # 使用SQLAlchemy创建表已在初始化时完成 + # logger.debug("使用SQLAlchemy进行表管理") + # except Exception as e: + # logger.error(f"数据库连接失败: {e}") self._initialized = True @@ -105,7 +105,8 @@ class ImageManager: timestamp=current_timestamp ) session.add(new_desc) - # session.commit() 会在上下文管理器中自动调用 + session.commit() + # 会在上下文管理器中自动调用 except Exception as e: logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") @@ -246,7 +247,8 @@ class ImageManager: timestamp=current_timestamp, ) session.add(new_img) - # session.commit() 会在上下文管理器中自动调用 + session.commit() + # 会在上下文管理器中自动调用 except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -323,7 +325,7 @@ class ImageManager: existing_image.image_id = str(uuid.uuid4()) if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None: existing_image.vlm_processed = True - session.commit() + logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...") else: new_img = Images( @@ -337,7 +339,7 @@ class ImageManager: count=1, ) session.add(new_img) - session.commit() + logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") except Exception as e: logger.error(f"保存图片文件或元数据失败: {str(e)}") @@ -511,35 +513,35 @@ class ImageManager: existing_image.vlm_processed = False existing_image.count += 1 - session.commit() + return existing_image.image_id, f"[picid:{existing_image.image_id}]" - # print(f"图片不存在: {image_hash}") - image_id = str(uuid.uuid4()) + # print(f"图片不存在: {image_hash}") + image_id = str(uuid.uuid4()) - # 保存新图片 - current_timestamp = time.time() - image_dir = os.path.join(self.IMAGE_DIR, "images") - os.makedirs(image_dir, exist_ok=True) - filename = f"{image_id}.png" - file_path = os.path.join(image_dir, filename) + # 保存新图片 + current_timestamp = time.time() + image_dir = os.path.join(self.IMAGE_DIR, "images") + os.makedirs(image_dir, exist_ok=True) + filename = f"{image_id}.png" + file_path = os.path.join(image_dir, filename) - # 保存文件 - with open(file_path, "wb") as f: - f.write(image_bytes) + # 保存文件 + with open(file_path, "wb") as f: + f.write(image_bytes) - # 保存到数据库 - new_img = Images( - image_id=image_id, - emoji_hash=image_hash, - path=file_path, - type="image", - timestamp=current_timestamp, - vlm_processed=False, - count=1, - ) - session.add(new_img) - session.commit() + # 保存到数据库 + new_img = Images( + image_id=image_id, + 
emoji_hash=image_hash, + path=file_path, + type="image", + timestamp=current_timestamp, + vlm_processed=False, + count=1, + ) + session.add(new_img) + session.commit() # 启动异步VLM处理 asyncio.create_task(self._process_image_with_vlm(image_id, image_base64)) @@ -581,7 +583,7 @@ class ImageManager: logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") image.description = existing_with_description.description image.vlm_processed = True - session.commit() + # 同时保存到ImageDescriptions表作为备用缓存 self._save_description_to_db(image_hash, existing_with_description.description, "image") return @@ -591,7 +593,7 @@ class ImageManager: logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") image.description = cached_description image.vlm_processed = True - session.commit() + return # 获取图片格式 diff --git a/src/common/database/database.py b/src/common/database/database.py index 817bc084a..876c6846d 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -6,7 +6,7 @@ from src.common.logger import get_logger # SQLAlchemy相关导入 from src.common.database.sqlalchemy_init import initialize_database_compat -from src.common.database.sqlalchemy_models import get_engine, get_session +from src.common.database.sqlalchemy_models import get_engine, get_db_session install(extra_lines=3) @@ -18,7 +18,7 @@ logger = get_logger("database") # 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy class DatabaseProxy: - """数据库代理类,提供Peewee到SQLAlchemy的兼容性接口""" + """数据库代理类""" def __init__(self): self._engine = None @@ -28,56 +28,7 @@ class DatabaseProxy: """初始化数据库连接""" return initialize_database_compat() - def connect(self, reuse_if_open=True): - """连接数据库(兼容性方法)""" - try: - self._engine = get_engine() - return True - except Exception as e: - logger.error(f"数据库连接失败: {e}") - return False - - def is_closed(self): - """检查数据库是否关闭(兼容性方法)""" - return self._engine is None - - def create_tables(self, models, safe=True): - """创建表(兼容性方法)""" - try: - from src.common.database.sqlalchemy_models import Base - engine = get_engine() - Base.metadata.create_all(bind=engine) - return True - except Exception as e: - logger.error(f"创建表失败: {e}") - return False - - def table_exists(self, model): - """检查表是否存在(兼容性方法)""" - try: - from sqlalchemy import inspect - engine = get_engine() - inspector = inspect(engine) - table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower()) - return table_name in inspector.get_table_names() - except Exception: - return False - - def execute_sql(self, sql): - """执行SQL(兼容性方法)""" - try: - from sqlalchemy import text - session = get_session() - result = session.execute(text(sql)) - session.close() - return result - except Exception as e: - logger.error(f"执行SQL失败: {e}") - raise - - def atomic(self): - """事务上下文管理器(兼容性方法)""" - return SQLAlchemyTransaction() + class SQLAlchemyTransaction: """SQLAlchemy事务上下文管理器""" @@ -86,7 +37,7 @@ class SQLAlchemyTransaction: self.session = None def __enter__(self): - self.session = get_session() + self.session = get_db_session() return self.session def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/src/common/database/sqlalchemy_database_api.py b/src/common/database/sqlalchemy_database_api.py index 0ed797795..4c773f74e 100644 --- a/src/common/database/sqlalchemy_database_api.py +++ b/src/common/database/sqlalchemy_database_api.py @@ -15,7 +15,7 @@ from src.common.logger import get_logger from src.common.database.sqlalchemy_models import ( Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams, 
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory, - Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus, get_session + Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus ) logger = get_logger("sqlalchemy_database_api") @@ -41,38 +41,9 @@ MODEL_MAPPING = { } -@contextmanager -def get_db_session(): - """数据库会话上下文管理器,自动处理事务和连接错误""" - session = None - max_retries = 3 - retry_delay = 1.0 - - for attempt in range(max_retries): - try: - session = get_session() - yield session - session.commit() - break - except (DisconnectionError, OperationalError) as e: - logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}") - if session: - session.rollback() - session.close() - if attempt < max_retries - 1: - time.sleep(retry_delay * (attempt + 1)) - else: - raise - except Exception: - if session: - session.rollback() - raise - finally: - if session: - session.close() -def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]): +def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]): """构建查询过滤条件""" conditions = [] @@ -296,6 +267,7 @@ async def db_save( # 创建新记录 new_record = model_class(**data) session.add(new_record) + session.commit() session.flush() # 转换为字典格式返回 @@ -415,8 +387,3 @@ async def store_action_info( traceback.print_exc() return None - -# 兼容性函数,方便从Peewee迁移 -def get_model_class(model_name: str) -> Optional[Type[Base]]: - """根据模型名称获取模型类""" - return MODEL_MAPPING.get(model_name) diff --git a/src/common/database/sqlalchemy_init.py b/src/common/database/sqlalchemy_init.py index a1fdb7763..3b4fb4f88 100644 --- a/src/common/database/sqlalchemy_init.py +++ b/src/common/database/sqlalchemy_init.py @@ -8,7 +8,7 @@ from typing import Optional from sqlalchemy.exc import SQLAlchemyError from src.common.logger import get_logger from src.common.database.sqlalchemy_models import ( - Base, get_engine, get_session, initialize_database + Base, get_engine, initialize_database ) logger = get_logger("sqlalchemy_init") @@ -72,36 +72,6 @@ def create_all_tables() -> bool: return False -def check_database_connection() -> bool: - """ - 检查数据库连接是否正常 - - Returns: - bool: 连接是否正常 - """ - try: - session = get_session() - if session is None: - logger.error("无法获取数据库会话") - return False - - # 检查会话是否可用(如果能获取到会话说明连接正常) - if session is None: - logger.error("数据库会话无效") - return False - - session.close() - - logger.info("数据库连接检查通过") - return True - - except SQLAlchemyError as e: - logger.error(f"数据库连接检查失败: {e}") - return False - except Exception as e: - logger.error(f"数据库连接检查过程中发生未知错误: {e}") - return False - def get_database_info() -> Optional[dict]: """ @@ -149,9 +119,6 @@ def initialize_database_compat() -> bool: if success: success = create_all_tables() - if success: - success = check_database_connection() - if success: _database_initialized = True diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 983bf7578..5bd342a88 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -29,102 +29,6 @@ def get_string_field(max_length=255, **kwargs): return String(max_length, **kwargs) else: return Text(**kwargs) - - - -class SessionProxy: - """线程安全的Session代理类,自动管理session生命周期""" - - def __init__(self): - self._local = threading.local() - - def _get_current_session(self): - """获取当前线程的session,如果没有则创建新的""" - if not hasattr(self._local, 'session') or self._local.session is None: - _, SessionLocal = initialize_database() - 
self._local.session = SessionLocal() - return self._local.session - - def _close_current_session(self): - """关闭当前线程的session""" - if hasattr(self._local, 'session') and self._local.session is not None: - try: - self._local.session.close() - except: - pass - finally: - self._local.session = None - - def __getattr__(self, name): - """代理所有session方法""" - session = self._get_current_session() - attr = getattr(session, name) - - # 如果是方法,需要特殊处理一些关键方法 - if callable(attr): - if name in ['commit', 'rollback']: - def wrapper(*args, **kwargs): - try: - result = attr(*args, **kwargs) - if name == 'commit': - # commit后不要清除session,只是刷新状态 - pass # 保持session活跃 - return result - except Exception: - try: - if session and hasattr(session, 'rollback'): - session.rollback() - except: - pass - # 发生错误时重新创建session - self._close_current_session() - raise - return wrapper - elif name == 'close': - def wrapper(*args, **kwargs): - result = attr(*args, **kwargs) - self._close_current_session() - return result - return wrapper - elif name in ['execute', 'query', 'add', 'delete', 'merge']: - def wrapper(*args, **kwargs): - try: - return attr(*args, **kwargs) - except Exception as e: - # 如果是连接相关错误,重新创建session再试一次 - if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e): - logger.warning(f"Session问题,重新创建session: {e}") - self._close_current_session() - new_session = self._get_current_session() - new_attr = getattr(new_session, name) - return new_attr(*args, **kwargs) - raise - return wrapper - - return attr - - def new_session(self): - """强制创建新的session(关闭当前的,创建新的)""" - self._close_current_session() - return self._get_current_session() - - def ensure_fresh_session(self): - """确保使用新鲜的session(如果当前session有问题则重新创建)""" - if hasattr(self._local, 'session') and self._local.session is not None: - try: - # 测试session是否还可用 - self._local.session.execute("SELECT 1") - except Exception: - # session有问题,重新创建 - self._close_current_session() - return self._get_current_session() - -# 创建全局session代理实例 -_global_session_proxy = SessionProxy() - -def get_session(): - """返回线程安全的session代理,自动管理生命周期""" - return _global_session_proxy class ChatStreams(Base): @@ -482,6 +386,22 @@ class MaiZoneScheduleStatus(Base): ) +class BanUser(Base): + """被禁用用户模型""" + __tablename__ = 'ban_users' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(get_string_field(50), nullable=False, index=True) + violation_num = Column(Integer, nullable=False, default=0) + reason = Column(Text, nullable=False) + created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + + __table_args__ = ( + Index('idx_violation_num', 'violation_num'), + Index('idx_banuser_user_id', 'user_id'), + ) + + # 数据库引擎和会话管理 _engine = None _SessionLocal = None @@ -593,7 +513,7 @@ def get_db_session(): _, SessionLocal = initialize_database() session = SessionLocal() yield session - session.commit() + # session.commit() except Exception: if session: session.rollback() @@ -601,6 +521,7 @@ def get_db_session(): finally: if session: session.close() + def get_engine(): diff --git a/src/common/message_repository.py b/src/common/message_repository.py index bba1e2e05..0b59dbfc7 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -8,7 +8,7 @@ from src.config.config import global_config # from src.common.database.database_model import Messages from src.common.database.sqlalchemy_models import Messages -from src.common.database.sqlalchemy_database_api import get_session +from 
src.common.database.sqlalchemy_database_api import get_db_session from src.common.logger import get_logger logger = get_logger(__name__) @@ -44,92 +44,92 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - session = get_session() - query = select(Messages) + with get_db_session() as session: + query = select(Messages) - # 应用过滤器 - if message_filter: - conditions = [] - for key, value in message_filter.items(): - if hasattr(Messages, key): - field = getattr(Messages, key) - if isinstance(value, dict): - # 处理 MongoDB 风格的操作符 - for op, op_value in value.items(): - if op == "$gt": - conditions.append(field > op_value) - elif op == "$lt": - conditions.append(field < op_value) - elif op == "$gte": - conditions.append(field >= op_value) - elif op == "$lte": - conditions.append(field <= op_value) - elif op == "$ne": - conditions.append(field != op_value) - elif op == "$in": - conditions.append(field.in_(op_value)) - elif op == "$nin": - conditions.append(field.not_in(op_value)) - else: - logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + else: + # 直接相等比较 + conditions.append(field == value) else: - # 直接相等比较 - conditions.append(field == value) - else: - logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - if conditions: - query = query.where(*conditions) + logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) - if filter_bot: - query = query.where(Messages.user_id != global_config.bot.qq_account) + if filter_bot: + query = query.where(Messages.user_id != global_config.bot.qq_account) - if filter_command: - query = query.where(not_(Messages.is_command)) + if filter_command: + query = query.where(not_(Messages.is_command)) - if limit > 0: - # 确保limit是正整数 - limit = max(1, int(limit)) + if limit > 0: + # 确保limit是正整数 + limit = max(1, int(limit)) - if limit_mode == "earliest": - # 获取时间最早的 limit 条记录,已经是正序 - query = query.order_by(Messages.time.asc()).limit(limit) + if limit_mode == "earliest": + # 获取时间最早的 limit 条记录,已经是正序 + query = query.order_by(Messages.time.asc()).limit(limit) + try: + results = session.execute(query).scalars().all() + except Exception as e: + logger.error(f"执行earliest查询失败: {e}") + results = [] + else: # 默认为 'latest' + # 获取时间最晚的 limit 条记录 + query = query.order_by(Messages.time.desc()).limit(limit) + try: + latest_results = session.execute(query).scalars().all() + # 将结果按时间正序排列 + results = sorted(latest_results, key=lambda msg: msg.time) + except Exception as e: + logger.error(f"执行latest查询失败: {e}") + results = [] + else: + # limit 为 0 时,应用传入的 sort 参数 + if sort: + sort_terms = [] + for field_name, direction in sort: + if hasattr(Messages, field_name): + field = getattr(Messages, field_name) + if direction == 1: # ASC + sort_terms.append(field.asc()) + elif direction == -1: # DESC + 
sort_terms.append(field.desc()) + else: + logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") + else: + logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") + if sort_terms: + query = query.order_by(*sort_terms) try: results = session.execute(query).scalars().all() except Exception as e: - logger.error(f"执行earliest查询失败: {e}") + logger.error(f"执行无限制查询失败: {e}") results = [] - else: # 默认为 'latest' - # 获取时间最晚的 limit 条记录 - query = query.order_by(Messages.time.desc()).limit(limit) - try: - latest_results = session.execute(query).scalars().all() - # 将结果按时间正序排列 - results = sorted(latest_results, key=lambda msg: msg.time) - except Exception as e: - logger.error(f"执行latest查询失败: {e}") - results = [] - else: - # limit 为 0 时,应用传入的 sort 参数 - if sort: - sort_terms = [] - for field_name, direction in sort: - if hasattr(Messages, field_name): - field = getattr(Messages, field_name) - if direction == 1: # ASC - sort_terms.append(field.asc()) - elif direction == -1: # DESC - sort_terms.append(field.desc()) - else: - logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") - else: - logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") - if sort_terms: - query = query.order_by(*sort_terms) - try: - results = session.execute(query).scalars().all() - except Exception as e: - logger.error(f"执行无限制查询失败: {e}") - results = [] return [_model_to_dict(msg) for msg in results] except Exception as e: @@ -152,50 +152,50 @@ def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - session = get_session() - query = select(func.count(Messages.id)) + with get_db_session() as session: + query = select(func.count(Messages.id)) - # 应用过滤器 - if message_filter: - conditions = [] - for key, value in message_filter.items(): - if hasattr(Messages, key): - field = getattr(Messages, key) - if isinstance(value, dict): - # 处理 MongoDB 风格的操作符 - for op, op_value in value.items(): - if op == "$gt": - conditions.append(field > op_value) - elif op == "$lt": - conditions.append(field < op_value) - elif op == "$gte": - conditions.append(field >= op_value) - elif op == "$lte": - conditions.append(field <= op_value) - elif op == "$ne": - conditions.append(field != op_value) - elif op == "$in": - conditions.append(field.in_(op_value)) - elif op == "$nin": - conditions.append(field.not_in(op_value)) - else: - logger.warning( - f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" - ) + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + # 直接相等比较 + conditions.append(field == value) else: - # 直接相等比较 - conditions.append(field == value) - else: - logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - if conditions: - query = query.where(*conditions) + logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) - count = 
session.execute(query).scalar() - return count or 0 + count = session.execute(query).scalar() + return count or 0 except Exception as e: - log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" - logger.error(log_message) - return 0 + log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" + logger.error(log_message) + return 0 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 630e3e3fb..83dce2f4e 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -5,7 +5,7 @@ from PIL import Image from datetime import datetime from src.common.logger import get_logger -from src.common.database.sqlalchemy_models import LLMUsage, get_session +from src.common.database.sqlalchemy_models import LLMUsage, get_db_session from src.config.api_ada_configs import ModelInfo from .payload_content.message import Message, MessageBuilder from .model_client.base_client import UsageRecord @@ -156,9 +156,8 @@ class LLMUsageRecorder: session = None try: # 使用 SQLAlchemy 会话创建记录 - session = get_session() - - usage_record = LLMUsage( + with get_db_session() as session: + usage_record = LLMUsage( model_name=model_info.model_identifier, model_assign_name=model_info.name, model_api_provider=model_info.api_provider, @@ -174,8 +173,8 @@ class LLMUsageRecorder: timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段 ) - session.add(usage_record) - session.commit() + session.add(usage_record) + session.commit() logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " @@ -184,11 +183,7 @@ class LLMUsageRecorder: f"总计: {model_usage.total_tokens}" ) except Exception as e: - if session: - session.rollback() logger.error(f"记录token使用情况失败: {str(e)}") - finally: - if session: - session.close() + llm_usage_recorder = LLMUsageRecorder() \ No newline at end of file diff --git a/src/manager/schedule_manager.py b/src/manager/schedule_manager.py index 2419e3bdd..91ee7de01 100644 --- a/src/manager/schedule_manager.py +++ b/src/manager/schedule_manager.py @@ -212,6 +212,7 @@ class ScheduleManager: setattr(new_schedule, 'date', today_str) setattr(new_schedule, 'schedule_data', json.dumps(schedule_data)) session.add(new_schedule) + session.commit() # 美化输出 schedule_str = f"已成功生成并保存今天的日程 ({today_str}):\n" diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 1c7ffbd3e..4889b4b99 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -11,10 +11,9 @@ from sqlalchemy import select from src.common.logger import get_logger from src.common.database.database import db from src.common.database.sqlalchemy_models import PersonInfo -from src.common.database.sqlalchemy_database_api import get_session +from src.common.database.sqlalchemy_database_api import get_db_session from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config -session = get_session() """ PersonInfoManager 类方法功能摘要: @@ -52,56 +51,37 @@ person_info_default = { "attitude": 50, } -# 统一的会话管理函数 -def with_session(func): - """装饰器:为函数自动注入session参数""" - if asyncio.iscoroutinefunction(func): - async def async_wrapper(*args, **kwargs): - - return await func(session, *args, **kwargs) - return async_wrapper - else: - def sync_wrapper(*args, **kwargs): - - return func(session, *args, **kwargs) - return sync_wrapper - -# 全局会话获取函数,用于替换所有裸露的session使用 -def _get_session(): - """获取数据库会话的统一函数""" - return 
get_session() - class PersonInfoManager: def __init__(self): """初始化PersonInfoManager""" - from src.common.database.sqlalchemy_models import PersonInfo self.person_name_list = {} self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") - try: - db.connect(reuse_if_open=True) - # 设置连接池参数(仅对SQLite有效) - if hasattr(db, "execute_sql"): - # 检查数据库类型,只对SQLite执行PRAGMA语句 - if global_config.database.database_type == "sqlite": - # 设置SQLite优化参数 - db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存 - db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中 - db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射 - db.create_tables([PersonInfo], safe=True) - except Exception as e: - logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") + # try: + # with get_db_session() as session: + # db.connect(reuse_if_open=True) + # # 设置连接池参数(仅对SQLite有效) + # if hasattr(db, "execute_sql"): + # # 检查数据库类型,只对SQLite执行PRAGMA语句 + # if global_config.database.database_type == "sqlite": + # # 设置SQLite优化参数 + # db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存 + # db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中 + # db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射 + # db.create_tables([PersonInfo], safe=True) + # except Exception as e: + # logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") - # 初始化时读取所有person_name + # # 初始化时读取所有person_name try: - from src.common.database.sqlalchemy_models import PersonInfo - # 在这里获取会话 - for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where( - PersonInfo.person_name.is_not(None) - )).fetchall(): - if record.person_name: - self.person_name_list[record.person_id] = record.person_name - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") + # 在这里获取会话 + with get_db_session() as session: + for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where( + PersonInfo.person_name.is_not(None) + )).fetchall(): + if record.person_name: + self.person_name_list[record.person_id] = record.person_name + logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") except Exception as e: logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") @@ -121,7 +101,8 @@ class PersonInfoManager: def _db_check_known_sync(p_id: str): # 在需要时获取会话 - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None + with get_db_session() as session: + return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None try: return await asyncio.to_thread(_db_check_known_sync, person_id) @@ -133,7 +114,8 @@ class PersonInfoManager: """根据用户名获取用户ID""" try: # 在需要时获取会话 - record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar() + with get_db_session() as session: + record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar() return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") @@ -176,15 +158,16 @@ class PersonInfoManager: # If it's already a string, assume it's valid JSON or a non-JSON string field def _db_create_sync(p_data: dict): - try: - new_person = PersonInfo(**p_data) - session.add(new_person) - session.commit() - return True - except Exception as e: - session.rollback() - logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") - return False + with get_db_session() as session: + try: + new_person = PersonInfo(**p_data) + 
session.add(new_person) + session.commit() + + return True + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") + return False await asyncio.to_thread(_db_create_sync, final_data) @@ -223,25 +206,26 @@ class PersonInfoManager: final_data[key] = json.dumps([], ensure_ascii=False) def _db_safe_create_sync(p_data: dict): - try: - existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar() - if existing: - logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") - return True + with get_db_session() as session: + try: + existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar() + if existing: + logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") + return True - # 尝试创建 - new_person = PersonInfo(**p_data) - session.add(new_person) - session.commit() - return True - except Exception as e: - session.rollback() - if "UNIQUE constraint failed" in str(e): - logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误") - return True # 其他协程已创建,视为成功 - else: - logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") - return False + # 尝试创建 + new_person = PersonInfo(**p_data) + session.add(new_person) + session.commit() + + return True + except Exception as e: + if "UNIQUE constraint failed" in str(e): + logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误") + return True # 其他协程已创建,视为成功 + else: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") + return False await asyncio.to_thread(_db_safe_create_sync, final_data) @@ -263,32 +247,33 @@ class PersonInfoManager: def _db_update_sync(p_id: str, f_name: str, val_to_set): start_time = time.time() - try: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - query_time = time.time() + with get_db_session() as session: + try: + record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + query_time = time.time() - if record: - setattr(record, f_name, val_to_set) - session.commit() - save_time = time.time() + if record: + setattr(record, f_name, val_to_set) + + save_time = time.time() - total_time = save_time - start_time - if total_time > 0.5: # 如果超过500ms就记录日志 - logger.warning( - f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" - ) + total_time = save_time - start_time + if total_time > 0.5: # 如果超过500ms就记录日志 + logger.warning( + f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" + ) + session.commit() - return True, False # Found and updated, no creation needed - else: + return True, False # Found and updated, no creation needed + else: + total_time = time.time() - start_time + if total_time > 0.5: + logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}") + return False, True # Not found, needs creation + except Exception as e: total_time = time.time() - start_time - if total_time > 0.5: - logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}") - return False, True # Not found, needs creation - except Exception as e: - session.rollback() - total_time = time.time() - start_time - logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}") - raise + logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}") + raise found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, 
processed_value) @@ -320,7 +305,8 @@ class PersonInfoManager: return False def _db_has_field_sync(p_id: str, f_name: str): - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + with get_db_session() as session: + record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() return bool(record) try: @@ -430,7 +416,8 @@ class PersonInfoManager: else: def _db_check_name_exists_sync(name_to_check): - return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None + with get_db_session() as session: + return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): is_duplicate = True @@ -471,14 +458,14 @@ class PersonInfoManager: def _db_delete_sync(p_id: str): try: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - session.delete(record) - session.commit() + with get_db_session() as session: + record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + if record: + session.delete(record) + session.commit() return 1 return 0 except Exception as e: - session.rollback() logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}") return 0 @@ -497,7 +484,8 @@ class PersonInfoManager: default_value_for_field = [] # Ensure JSON fields default to [] if not in DB def _db_get_value_sync(p_id: str, f_name: str): - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + with get_db_session() as session: + record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() if record: val = getattr(record, f_name, None) if f_name in JSON_SERIALIZED_FIELDS: @@ -531,27 +519,28 @@ class PersonInfoManager: def get_value_sync(person_id: str, field_name: str): """同步获取指定用户指定字段的值""" default_value_for_field = person_info_default.get(field_name) - if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] + with get_db_session() as session: + if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: + default_value_for_field = [] - if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar(): - val = getattr(record, field_name, None) - if field_name in JSON_SERIALIZED_FIELDS: - if isinstance(val, str): - try: - return json.loads(val) - except json.JSONDecodeError: - logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.") + if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar(): + val = getattr(record, field_name, None) + if field_name in JSON_SERIALIZED_FIELDS: + if isinstance(val, str): + try: + return json.loads(val) + except json.JSONDecodeError: + logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 
返回默认值.") + return [] + elif val is None: return [] - elif val is None: - return [] + return val return val - return val - if field_name in person_info_default: - return default_value_for_field - logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None + if field_name in person_info_default: + return default_value_for_field + logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") + return None @staticmethod async def get_values(person_id: str, field_names: list) -> dict: @@ -563,7 +552,8 @@ class PersonInfoManager: result = {} def _db_get_record_sync(p_id: str): - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + with get_db_session() as session: + return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() record = await asyncio.to_thread(_db_get_record_sync, person_id) @@ -608,10 +598,11 @@ class PersonInfoManager: def _db_get_specific_sync(f_name: str): found_results = {} try: - for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall(): - value = getattr(record, f_name) - if way(value): - found_results[record.person_id] = value + with get_db_session() as session: + for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall(): + value = getattr(record, f_name) + if way(value): + found_results[record.person_id] = value except Exception as e_query: logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True) return found_results @@ -634,19 +625,20 @@ class PersonInfoManager: def _db_get_or_create_sync(p_id: str, init_data: dict): """原子性的获取或创建操作""" - # 首先尝试获取现有记录 - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - return record, False # 记录存在,未创建 + with get_db_session() as session: + # 首先尝试获取现有记录 + record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + if record: + return record, False # 记录存在,未创建 # 记录不存在,尝试创建 try: new_person = PersonInfo(**init_data) session.add(new_person) session.commit() + return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功 except Exception as e: - session.rollback() # 如果创建失败(可能是因为竞态条件),再次尝试获取 if "UNIQUE constraint failed" in str(e): logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") @@ -709,7 +701,8 @@ class PersonInfoManager: if not found_person_id: def _db_find_by_name_sync(p_name_to_find: str): - return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar() + with get_db_session() as session: + return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar() record = await asyncio.to_thread(_db_find_by_name_sync, person_name) if record: diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index bd9f19448..64db72ac1 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -14,7 +14,6 @@ from src.common.database.sqlalchemy_database_api import ( db_save, db_get, store_action_info, - get_model_class, MODEL_MAPPING ) @@ -24,6 +23,5 @@ __all__ = [ 'db_save', 'db_get', 'store_action_info', - 'get_model_class', 'MODEL_MAPPING' ] diff --git a/src/plugins/built_in/maizone/scheduler.py b/src/plugins/built_in/maizone/scheduler.py index 5761258a7..395eadf9d 100644 --- a/src/plugins/built_in/maizone/scheduler.py +++ b/src/plugins/built_in/maizone/scheduler.py @@ 
-151,7 +151,7 @@ class ScheduleManager:
                 )
 
                 session.add(new_record)
-                session.commit()
+                session.commit()  # keep the explicit commit: get_db_session() no longer commits on exit
 
             logger.info(f"已更新日程处理状态: {datetime_hour} - {activity} - 成功: {success}")
         except Exception as e:
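
A note on the session convention this patch converges on: `get_db_session()` now only guarantees rollback-on-error and close, and every write path commits explicitly. A minimal sketch of that contract, with the engine URL as a placeholder assumption (the real setup lives in `initialize_database()`):

```python
# Minimal sketch of the patched get_db_session() contract: no auto-commit,
# rollback on error, always close. Engine/session setup is simplified here;
# the real code obtains SessionLocal from initialize_database().
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///maibot.db")  # placeholder URL, an assumption
SessionLocal = sessionmaker(bind=engine)

@contextmanager
def get_db_session():
    session = SessionLocal()
    try:
        yield session  # the caller decides when (and whether) to commit
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
```

Under this contract a forgotten `session.commit()` silently discards the write when the session closes, which is why the hunks above add explicit commits after `session.add(...)`. Note also that `session.flush()` is only useful before `commit()` (it sends the INSERT and populates generated primary keys inside the open transaction); the `db_save` hunk commits first, so its trailing `flush()` is a no-op.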
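The Mongo-style filter handling in `find_messages()` and `count_messages()` is duplicated verbatim in both functions; it can be read (and factored) as a single translation table. A condensed sketch, assuming the same operator set as in the hunks:

```python
# Condensed sketch of the Mongo-style filter translation used in
# find_messages()/count_messages(): OPS maps operator strings to
# SQLAlchemy column comparisons via the operator module.
import operator

OPS = {
    "$gt": operator.gt,
    "$lt": operator.lt,
    "$gte": operator.ge,
    "$lte": operator.le,
    "$ne": operator.ne,
}

def build_conditions(model, message_filter: dict):
    conditions = []
    for key, value in message_filter.items():
        if not hasattr(model, key):
            continue  # the real code logs a warning and skips the key
        field = getattr(model, key)
        if isinstance(value, dict):
            for op, op_value in value.items():
                if op in OPS:
                    conditions.append(OPS[op](field, op_value))
                elif op == "$in":
                    conditions.append(field.in_(op_value))
                elif op == "$nin":
                    conditions.append(field.not_in(op_value))
                # unknown operators are skipped, as in the original
        else:
            conditions.append(field == value)  # plain equality
    return conditions
```

`select(Messages).where(*build_conditions(Messages, message_filter))` then reproduces the behaviour of the inline version in a single place.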
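person_info.py repeats one pattern throughout this patch: a synchronous helper that opens its own short-lived session, dispatched off the event loop with `asyncio.to_thread`. The shape, reduced to its essentials (imports assume the module layout shown in the hunks; the async wrapper name is illustrative):

```python
import asyncio

from sqlalchemy import select

from src.common.database.sqlalchemy_models import PersonInfo, get_db_session

def _db_check_known_sync(p_id: str) -> bool:
    # one session per call: opened, used, and closed inside the worker thread
    with get_db_session() as session:
        return session.execute(
            select(PersonInfo).where(PersonInfo.person_id == p_id)
        ).scalar() is not None

async def is_person_known(person_id: str) -> bool:
    # keep the event loop free while the query runs
    return await asyncio.to_thread(_db_check_known_sync, person_id)
```

The per-call session replaces the removed module-level `session = get_session()` and the `SessionProxy` machinery: no state is shared across threads, so no proxy is needed.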
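The create paths (`_db_create_sync`, `_db_safe_create_sync`, `_db_get_or_create_sync`) all tolerate concurrent creation by string-matching the UNIQUE-violation error. A compact sketch of that idiom; the error text is SQLite's wording, as in the patch, so other backends would need their own match (or a check against `sqlalchemy.exc.IntegrityError`):

```python
from sqlalchemy import select

from src.common.database.sqlalchemy_models import PersonInfo, get_db_session

def safe_create(p_data: dict) -> bool:
    """Create a PersonInfo row if absent; a concurrent creation counts as success."""
    with get_db_session() as session:
        if session.execute(
            select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])
        ).scalar():
            return True  # already exists, nothing to do
        try:
            session.add(PersonInfo(**p_data))
            session.commit()
            return True
        except Exception as e:
            session.rollback()
            if "UNIQUE constraint failed" in str(e):  # SQLite wording
                return True  # another writer won the race; treat as success
            raise
```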
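The new `BanUser` table arrives in this diff without a visible caller; a hypothetical usage sketch under the explicit-commit convention (`record_violation` is illustrative, not part of the patch):

```python
from sqlalchemy import select

from src.common.database.sqlalchemy_models import BanUser, get_db_session

def record_violation(user_id: str, reason: str) -> None:
    """Hypothetical helper: bump a user's violation count, creating the row if needed."""
    with get_db_session() as session:
        ban = session.execute(
            select(BanUser).where(BanUser.user_id == user_id)
        ).scalar_one_or_none()
        if ban is None:
            ban = BanUser(user_id=user_id, violation_num=1, reason=reason)
            session.add(ban)
        else:
            ban.violation_num += 1
            ban.reason = reason
        session.commit()  # required: the session helper no longer commits
```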