fix: 修复返回的 embedding 为空时的处理逻辑

This commit is contained in:
Windpicker-owo
2025-11-27 22:08:22 +08:00
parent 26520c123a
commit 25571bf0ec
2 changed files with 71 additions and 56 deletions

View File

@@ -58,6 +58,7 @@ from sqlalchemy import (
Table,
inspect,
text,
types as sqltypes,
)
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.exc import SQLAlchemyError
@@ -191,7 +192,7 @@ def get_database_config_from_toml(db_type: str) -> dict | None:
def create_sqlite_engine(sqlite_path: str) -> Engine:
"""创建 SQLite 引擎"""
"""创建 SQLite 引擎"""
if not os.path.isabs(sqlite_path):
sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path)
@@ -200,28 +201,18 @@ def create_sqlite_engine(sqlite_path: str) -> Engine:
url = f"sqlite:///{sqlite_path}"
logger.info("使用 SQLite 数据库: %s", sqlite_path)
return create_engine(url, future=True)
def create_mysql_engine(
host: str,
port: int,
database: str,
user: str,
password: str,
charset: str = "utf8mb4",
) -> Engine:
"""创建 MySQL 引擎"""
# 延迟导入 pymysql,以便友好提示
try:
import pymysql # noqa: F401
except ImportError:
logger.error("需要安装 pymysql 才能连接 MySQL: pip install pymysql")
raise
url = f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset={charset}"
logger.info("使用 MySQL 数据库: %s@%s:%s/%s", user, host, port, database)
return create_engine(url, future=True)
engine = create_engine(
url,
future=True,
connect_args={
"timeout": 30, # wait a bit if the db is locked
"check_same_thread": False,
},
)
# Increase busy timeout to reduce "database is locked" errors on SQLite
with engine.connect() as conn:
conn.execute(text("PRAGMA busy_timeout=30000"))
return engine
def create_postgresql_engine(
@@ -324,22 +315,35 @@ def get_table_row_count(conn: Connection, table: Table) -> int:
def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table:
"""在目标数据库中创建与源表结构相同的表
"""复制表结构到目标数据库,使其结构保持一致"""
target_is_sqlite = target_engine.dialect.name == "sqlite"
target_is_pg = target_engine.dialect.name == "postgresql"
Args:
source_table: 源表对象
target_metadata: 目标元数据对象
target_engine: 目标数据库引擎
columns = []
for c in source_table.columns:
new_col = c.copy()
Returns:
Table: 目标表对象
"""
# 复制表结构
# SQLite 不支持 nextval 等 server_default
if target_is_sqlite:
new_col.server_default = None
# PostgreSQL 需要将部分 SQLite 特有类型转换
if target_is_pg:
col_type = new_col.type
# SQLite DATETIME -> 通用 DateTime
if isinstance(col_type, sqltypes.DateTime) or col_type.__class__.__name__ in {"DATETIME", "DateTime"}:
new_col.type = sqltypes.DateTime()
# TEXT(50) 等长度受限的 TEXT 在 PG 无效,改用 String(length)
elif isinstance(col_type, sqltypes.Text) and getattr(col_type, "length", None):
new_col.type = sqltypes.String(length=col_type.length)
columns.append(new_col)
# 为避免迭代约束集合时出现 “Set changed size during iteration”,这里不复制表级约束
target_table = Table(
source_table.name,
target_metadata,
*[c.copy() for c in source_table.columns],
*[c.copy() for c in source_table.constraints],
*columns,
)
target_metadata.create_all(target_engine, tables=[target_table])
return target_table
@@ -383,8 +387,6 @@ def migrate_table_data(
logger.error("查询表 %s 失败: %s", source_table.name, e)
return 0, 1
columns = source_table.columns.keys()
def insert_batch(rows: list[dict]):
nonlocal migrated_rows, error_count
if not rows:
@@ -399,7 +401,8 @@ def migrate_table_data(
batch: list[dict] = []
for row in result:
row_dict = {col: row[col] for col in columns}
# Use column objects to access row mapping to avoid quoted_name keys
row_dict = {col.key: row._mapping[col] for col in source_table.columns}
batch.append(row_dict)
if len(batch) >= batch_size:
insert_batch(batch)
@@ -535,6 +538,14 @@ class DatabaseMigrator:
# 目标数据库配置
target_config = self._load_target_config()
# 防止源/目标 SQLite 指向同一路径导致自我覆盖及锁
if (
self.source_type == "sqlite"
and self.target_type == "sqlite"
and os.path.abspath(source_config.get("path", "")) == os.path.abspath(target_config.get("path", ""))
):
raise ValueError("源数据库与目标数据库不能是同一个 SQLite 文件,请为目标指定不同的路径")
# 创建引擎
self.source_engine = create_engine_by_type(self.source_type, source_config)
self.target_engine = create_engine_by_type(self.target_type, target_config)
@@ -589,32 +600,36 @@ class DatabaseMigrator:
return sorted_tables
def _drop_target_tables(self, conn: Connection):
"""删除目标数据库中已经存在的表(谨慎操作)
def _drop_target_tables(self):
"""删除目标数据库中已有的表(如果有)
这里为了避免冲突,迁移前会询问用户是否删除目标库中已经存在的同名表。
使用 Engine.begin() 进行连接以支持 autobegin 和 begin 兼容 SQLAlchemy 2.0 的写法
"""
inspector = inspect(conn)
existing_tables = inspector.get_table_names()
if not existing_tables:
logger.info("目标数据库中没有已存在的表,无需删除")
if self.target_engine is None:
logger.warning("目标数据库引擎尚未初始化,无法删除表")
return
logger.info("目标数据库中当前存在的表: %s", ", ".join(existing_tables))
if confirm_action("是否删除目标数据库中已有的所有表?此操作不可恢复!", default=False):
with conn.begin():
with self.target_engine.begin() as conn:
inspector = inspect(conn)
existing_tables = inspector.get_table_names()
if not existing_tables:
logger.info("目标数据库中没有已存在的表,无需删除")
return
logger.info("目标数据库中的当前表: %s", ", ".join(existing_tables))
if confirm_action("是否删除目标数据库中现有的表列表?此操作不可撤销", default=False):
for table_name in existing_tables:
try:
logger.info("删除目标数据库表: %s", table_name)
logger.info("删除目标数据库表: %s", table_name)
conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE"))
except SQLAlchemyError as e:
logger.error("删除 %s 失败: %s", table_name, e)
logger.error("删除 %s 失败: %s", table_name, e)
self.stats["errors"].append(
f"删除 {table_name} 失败: {e}"
f"删除 {table_name} 失败: {e}"
)
else:
logger.info("用户选择保留目标数据库中已有的表,可能会与迁移数据发生冲突。")
else:
logger.info("跳过删除目标数据库中的表,继续迁移过程")
def migrate(self):
"""执行迁移操作"""
@@ -630,8 +645,7 @@ class DatabaseMigrator:
logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables))
# 删除目标库中已有表(可选)
with self.target_engine.connect() as target_conn:
self._drop_target_tables(target_conn)
self._drop_target_tables()
# 开始迁移
with self.source_engine.connect() as source_conn, self.target_engine.connect() as target_conn:
@@ -937,7 +951,7 @@ def interactive_setup() -> dict:
if target_type == "sqlite":
target_path = _ask_str(
"目标 SQLite 文件路径(若不存在会自动创建)",
default="data/MaiBot_target.db",
default="data/MaiBot.db",
)
target_config = {"path": target_path}
else:

View File

@@ -367,6 +367,7 @@ class BotInterestManager:
self.embedding_dimension,
current_dim,
)
return embedding
else:
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")