From a2c86f36052d1b9fe8447dc67055b93bc8437d53 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 8 Aug 2025 12:34:21 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E9=83=A8=E5=88=86=E5=A4=84?= =?UTF-8?q?=E7=90=86notify=EF=BC=8C=E8=87=AA=E5=8A=A8=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93null=E7=BA=A6=E6=9D=9F=E5=8F=98?= =?UTF-8?q?=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit notify存储至message库 --- scripts/import_openie.py | 1 - src/chat/message_receive/bot.py | 8 +- src/chat/message_receive/chat_stream.py | 3 +- src/chat/message_receive/message.py | 1 + src/chat/message_receive/storage.py | 3 + src/common/database/database_model.py | 278 +++++++++++++++++- .../mais4u_chat/s4u_stream_generator.py | 3 +- 7 files changed, 285 insertions(+), 12 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 1177650d4..eabeb9965 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -15,7 +15,6 @@ from src.chat.knowledge.kg_manager import KGManager from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 from src.manager.local_store_manager import local_storage -from dotenv import load_dotenv # 添加项目根目录到 sys.path diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index a4228b89a..a6a8aeb16 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -146,7 +146,10 @@ class ChatBot: async def hanle_notice_message(self, message: MessageRecv): if message.message_info.message_id == "notice": - logger.info("收到notice消息,暂时不支持处理") + message.is_notify = True + logger.info("notice消息") + print(message) + return True async def do_s4u(self, message_data: Dict[str, Any]): @@ -207,7 +210,8 @@ class ChatBot: message = MessageRecv(message_data) if await self.hanle_notice_message(message): - return + # return + pass group_info = message.message_info.group_info user_info = message.message_info.user_info diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 2ee2be05a..5108643fe 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -217,7 +217,8 @@ class ChatManager: # 更新用户信息和群组信息 stream.update_active_time() stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 - stream.user_info = user_info + if user_info.platform and user_info.user_id: + stream.user_info = user_info if group_info: stream.group_info = group_info from .message import MessageRecv # 延迟导入,避免循环引用 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 58dd6d689..5c7e0940e 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -109,6 +109,7 @@ class MessageRecv(Message): self.has_picid = False self.is_voice = False self.is_mentioned = None + self.is_notify = False self.is_command = False diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 9659bb417..5f54b15fb 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -43,6 +43,7 @@ class MessageStorage: priority_info = {} is_emoji = False is_picid = False + is_notify = False is_command = False else: filtered_display_message = "" @@ -53,6 +54,7 @@ class MessageStorage: priority_info = message.priority_info is_emoji = message.is_emoji is_picid = message.is_picid + is_notify = message.is_notify is_command = message.is_command chat_info_dict = chat_stream.to_dict() @@ -98,6 +100,7 @@ class MessageStorage: priority_info=priority_info, is_emoji=is_emoji, is_picid=is_picid, + is_notify=is_notify, is_command=is_command, ) except Exception: diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index d2b3acce7..e095c1891 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -146,9 +146,9 @@ class Messages(BaseModel): chat_info_last_active_time = DoubleField() # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) - user_platform = TextField() - user_id = TextField() - user_nickname = TextField() + user_platform = TextField(null=True) + user_id = TextField(null=True) + user_nickname = TextField(null=True) user_cardname = TextField(null=True) processed_plain_text = TextField(null=True) # 处理后的纯文本消息 @@ -162,6 +162,7 @@ class Messages(BaseModel): is_emoji = BooleanField(default=False) is_picid = BooleanField(default=False) is_command = BooleanField(default=False) + is_notify = BooleanField(default=False) class Meta: # database = db # 继承自 BaseModel @@ -252,7 +253,7 @@ class PersonInfo(BaseModel): name_reason = TextField(null=True) # 名称设定的原因 platform = TextField() # 平台 user_id = TextField(index=True) # 用户ID - nickname = TextField() # 用户昵称 + nickname = TextField(null=True) # 用户昵称 impression = TextField(null=True) # 个人印象 short_impression = TextField(null=True) # 个人印象的简短描述 points = TextField(null=True) # 个人印象的点 @@ -378,10 +379,14 @@ def create_tables(): ) -def initialize_database(): +def initialize_database(sync_constraints=False): """ 检查所有定义的表是否存在,如果不存在则创建它们。 检查所有表的所有字段是否存在,如果缺失则自动添加。 + + Args: + sync_constraints (bool): 是否同步字段约束。默认为 False。 + 如果为 True,会检查并修复字段的 NULL 约束不一致问题。 """ models = [ @@ -462,6 +467,13 @@ def initialize_database(): logger.info(f"字段 '{field_name}' 删除成功") except Exception as e: logger.error(f"删除字段 '{field_name}' 失败: {e}") + + # 如果启用了约束同步,执行约束检查和修复 + if sync_constraints: + logger.debug("开始同步数据库字段约束...") + sync_field_constraints() + logger.debug("数据库字段约束同步完成") + except Exception as e: logger.exception(f"检查表或字段是否存在时出错: {e}") # 如果检查失败(例如数据库不可用),则退出 @@ -470,5 +482,259 @@ def initialize_database(): logger.info("数据库初始化完成") +def sync_field_constraints(): + """ + 同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。 + 如果发现不一致,会自动修复字段约束。 + """ + + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Expression, + Memory, + ThinkingLog, + GraphNodes, + GraphEdges, + ActionRecords, + ] + + try: + with db: + for model in models: + table_name = model._meta.table_name + if not db.table_exists(model): + logger.warning(f"表 '{table_name}' 不存在,跳过约束检查") + continue + + logger.debug(f"检查表 '{table_name}' 的字段约束...") + + # 获取当前表结构信息 + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} + for row in cursor.fetchall()} + + # 检查每个模型字段的约束 + constraints_to_fix = [] + for field_name, field_obj in model._meta.fields.items(): + if field_name not in current_schema: + continue # 字段不存在,跳过 + + current_notnull = current_schema[field_name]['notnull'] + model_allows_null = field_obj.null + + # 如果模型允许 null 但数据库字段不允许 null,需要修复 + if model_allows_null and current_notnull: + constraints_to_fix.append({ + 'field_name': field_name, + 'field_obj': field_obj, + 'action': 'allow_null', + 'current_constraint': 'NOT NULL', + 'target_constraint': 'NULL' + }) + logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL,但数据库为NOT NULL") + + # 如果模型不允许 null 但数据库字段允许 null,也需要修复(但要小心) + elif not model_allows_null and not current_notnull: + constraints_to_fix.append({ + 'field_name': field_name, + 'field_obj': field_obj, + 'action': 'disallow_null', + 'current_constraint': 'NULL', + 'target_constraint': 'NOT NULL' + }) + logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL,但数据库允许NULL") + + # 修复约束不一致的字段 + if constraints_to_fix: + logger.info(f"表 '{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束") + _fix_table_constraints(table_name, model, constraints_to_fix) + else: + logger.debug(f"表 '{table_name}' 的字段约束已同步") + + except Exception as e: + logger.exception(f"同步字段约束时出错: {e}") + + +def _fix_table_constraints(table_name, model, constraints_to_fix): + """ + 修复表的字段约束。 + 对于 SQLite,由于不支持直接修改列约束,需要重建表。 + """ + try: + # 备份表名 + backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}" + + logger.info(f"开始修复表 '{table_name}' 的字段约束...") + + # 1. 创建备份表 + db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}") + logger.info(f"已创建备份表 '{backup_table}'") + + # 2. 删除原表 + db.execute_sql(f"DROP TABLE {table_name}") + logger.info(f"已删除原表 '{table_name}'") + + # 3. 重新创建表(使用当前模型定义) + db.create_tables([model]) + logger.info(f"已重新创建表 '{table_name}' 使用新的约束") + + # 4. 从备份表恢复数据 + # 获取字段列表 + fields = list(model._meta.fields.keys()) + fields_str = ', '.join(fields) + + # 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据 + # 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值 + insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}" + + # 检查是否有字段需要从 NULL 改为 NOT NULL + null_to_notnull_fields = [ + constraint['field_name'] for constraint in constraints_to_fix + if constraint['action'] == 'disallow_null' + ] + + if null_to_notnull_fields: + # 需要处理 NULL 值,为这些字段设置默认值 + logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL,需要处理现有的NULL值") + + # 构建更复杂的 SELECT 语句来处理 NULL 值 + select_fields = [] + for field_name in fields: + if field_name in null_to_notnull_fields: + field_obj = model._meta.fields[field_name] + # 根据字段类型设置默认值 + if isinstance(field_obj, (TextField,)): + default_value = "''" + elif isinstance(field_obj, (IntegerField, FloatField, DoubleField)): + default_value = "0" + elif isinstance(field_obj, BooleanField): + default_value = "0" + elif isinstance(field_obj, DateTimeField): + default_value = f"'{datetime.datetime.now()}'" + else: + default_value = "''" + + select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}") + else: + select_fields.append(field_name) + + select_str = ', '.join(select_fields) + insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}" + + db.execute_sql(insert_sql) + logger.info(f"已从备份表恢复数据到 '{table_name}'") + + # 5. 验证数据完整性 + original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0] + new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0] + + if original_count == new_count: + logger.info(f"数据完整性验证通过: {original_count} 行数据") + # 删除备份表 + db.execute_sql(f"DROP TABLE {backup_table}") + logger.info(f"已删除备份表 '{backup_table}'") + else: + logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count} 行") + logger.error(f"备份表 '{backup_table}' 已保留,请手动检查") + + # 记录修复的约束 + for constraint in constraints_to_fix: + logger.info(f"已修复字段 '{constraint['field_name']}': " + f"{constraint['current_constraint']} -> {constraint['target_constraint']}") + + except Exception as e: + logger.exception(f"修复表 '{table_name}' 约束时出错: {e}") + # 尝试恢复 + try: + if db.table_exists(backup_table): + logger.info(f"尝试从备份表 '{backup_table}' 恢复...") + db.execute_sql(f"DROP TABLE IF EXISTS {table_name}") + db.execute_sql(f"ALTER TABLE {backup_table} RENAME TO {table_name}") + logger.info(f"已从备份恢复表 '{table_name}'") + except Exception as restore_error: + logger.exception(f"恢复表失败: {restore_error}") + + +def check_field_constraints(): + """ + 检查但不修复字段约束,返回不一致的字段信息。 + 用于在修复前预览需要修复的内容。 + """ + + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Expression, + Memory, + ThinkingLog, + GraphNodes, + GraphEdges, + ActionRecords, + ] + + inconsistencies = {} + + try: + with db: + for model in models: + table_name = model._meta.table_name + if not db.table_exists(model): + continue + + # 获取当前表结构信息 + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} + for row in cursor.fetchall()} + + table_inconsistencies = [] + + # 检查每个模型字段的约束 + for field_name, field_obj in model._meta.fields.items(): + if field_name not in current_schema: + continue + + current_notnull = current_schema[field_name]['notnull'] + model_allows_null = field_obj.null + + if model_allows_null and current_notnull: + table_inconsistencies.append({ + 'field_name': field_name, + 'issue': 'model_allows_null_but_db_not_null', + 'model_constraint': 'NULL', + 'db_constraint': 'NOT NULL', + 'recommended_action': 'allow_null' + }) + elif not model_allows_null and not current_notnull: + table_inconsistencies.append({ + 'field_name': field_name, + 'issue': 'model_not_null_but_db_allows_null', + 'model_constraint': 'NOT NULL', + 'db_constraint': 'NULL', + 'recommended_action': 'disallow_null' + }) + + if table_inconsistencies: + inconsistencies[table_name] = table_inconsistencies + + except Exception as e: + logger.exception(f"检查字段约束时出错: {e}") + + return inconsistencies + + + # 模块加载时调用初始化函数 -initialize_database() +initialize_database(sync_constraints=True) \ No newline at end of file diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index c0ca26581..43bf3599b 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -1,7 +1,6 @@ -import os from typing import AsyncGenerator from src.mais4u.openai_client import AsyncOpenAIClient -from src.config.config import global_config, model_config +from src.config.config import model_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder from src.common.logger import get_logger