161 lines
5.7 KiB
Python
161 lines
5.7 KiB
Python
from __future__ import annotations
|
||
|
||
from sqlalchemy import Engine, inspect, text
|
||
|
||
from app.db.models import Base
|
||
|
||
|
||
def _has_column(engine: Engine, table: str, col: str) -> bool:
    """Report whether ``table`` currently exposes a column named ``col``.

    The column list is reflected from the live database through a
    SQLAlchemy inspector created for ``engine``.
    """
    for column in inspect(engine).get_columns(table):
        if column.get("name") == col:
            return True
    return False
|
||
|
||
|
||
def _sqlite_table_sql(conn, table: str) -> str:
    """Return the SQL text SQLite recorded in ``sqlite_master`` for ``table``.

    An empty string is returned when no matching table row exists or when
    the stored ``sql`` value is NULL/empty.
    """
    query = text("SELECT sql FROM sqlite_master WHERE type='table' AND name=:name")
    record = conn.execute(query, {"name": table}).fetchone()
    if not record:
        return ""
    return str(record[0] or "")
|
||
|
||
|
||
def _rebuild_sqlite_job_logs(conn, existing_cols: set[str]) -> None:
    """Rebuild the SQLite ``job_logs`` table without its old status CHECK."""
    ddl = _sqlite_table_sql(conn, "job_logs")
    # Nothing to migrate when the table has no CHECK constraint at all, or
    # when its definition already mentions RUNNING.
    if not ddl or "CHECK" not in ddl or "RUNNING" in ddl:
        return

    # Rebuild the table: the new definition carries no CHECK (so RUNNING is
    # accepted) and every expected column exists; columns missing from the
    # old table are filled with a default while copying.
    # NOTE(review): SQLite's ALTER TABLE documentation recommends the
    # create-new / copy / drop-old / rename-new ordering when other schema
    # objects reference the table; the original rename-first ordering is
    # kept here - confirm nothing else references job_logs.
    conn.execute(text("ALTER TABLE job_logs RENAME TO job_logs_old"))
    conn.execute(
        text(
            """
            CREATE TABLE job_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                job_id VARCHAR NOT NULL,
                status VARCHAR(16) NOT NULL,
                snapshot_params TEXT NOT NULL DEFAULT '{}',
                message TEXT NOT NULL DEFAULT '',
                traceback TEXT NOT NULL DEFAULT '',
                run_log TEXT NOT NULL DEFAULT '',
                celery_task_id VARCHAR NOT NULL DEFAULT '',
                attempt INTEGER NOT NULL DEFAULT 0,
                started_at DATETIME NOT NULL,
                finished_at DATETIME
            )
            """
        )
    )

    # (column, SQL expression used when the old table lacks that column).
    # A single table keeps the INSERT column list and the SELECT list in
    # lock-step.
    column_defaults = (
        ("id", "NULL"),
        ("job_id", "''"),
        ("status", "''"),
        ("snapshot_params", "'{}'"),
        ("message", "''"),
        ("traceback", "''"),
        ("run_log", "''"),
        ("celery_task_id", "''"),
        ("attempt", "0"),
        ("started_at", "CURRENT_TIMESTAMP"),
        ("finished_at", "NULL"),
    )
    insert_cols = [name for name, _ in column_defaults]
    select_exprs = [
        name if name in existing_cols else f"{default} AS {name}"
        for name, default in column_defaults
    ]
    conn.execute(
        text(
            f"INSERT INTO job_logs ({', '.join(insert_cols)}) "
            f"SELECT {', '.join(select_exprs)} FROM job_logs_old"
        )
    )

    conn.execute(text("DROP TABLE job_logs_old"))
    # Restore the job_id index (SQLAlchemy's default name is
    # ix_job_logs_job_id).
    conn.execute(text("CREATE INDEX IF NOT EXISTS ix_job_logs_job_id ON job_logs (job_id)"))


def _relax_postgres_status_check(conn, insp) -> None:
    """Replace PostgreSQL status CHECK constraints that lack RUNNING."""
    try:
        checks = insp.get_check_constraints("job_logs") or []
    except Exception:
        # Best effort: if constraints cannot be reflected, treat the table
        # as having none, which means there is nothing to migrate.
        checks = []

    stale_names: list[str] = []
    needs_migration = False
    for ck in checks:
        sqltext = str(ck.get("sqltext") or "")
        if "status" in sqltext and "RUNNING" not in sqltext:
            needs_migration = True
            name = str(ck.get("name") or "")
            if name:
                stale_names.append(name)

    if not needs_migration:
        return

    # Drop the old constraints first (their names are not known up front),
    # then create one constraint under a fixed name.
    for name in stale_names:
        # Double any embedded double quote so the reflected name stays a
        # single, valid quoted identifier.
        quoted = name.replace('"', '""')
        conn.execute(text(f'ALTER TABLE job_logs DROP CONSTRAINT IF EXISTS "{quoted}"'))
    conn.execute(text("ALTER TABLE job_logs DROP CONSTRAINT IF EXISTS ck_job_logs_status"))
    conn.execute(
        text(
            "ALTER TABLE job_logs "
            "ADD CONSTRAINT ck_job_logs_status "
            "CHECK (status IN ('RUNNING','SUCCESS','FAILURE','RETRY'))"
        )
    )


def _ensure_job_logs_status_allows_running(engine: Engine) -> None:
    """Lightweight self-upgrade so ``job_logs.status`` accepts RUNNING.

    - SQLite: if the stored table definition contains a CHECK and does not
      mention RUNNING, the table is rebuilt without the old CHECK.
    - PostgreSQL: if a CHECK constraint on ``status`` does not mention
      RUNNING, it is dropped and a new ``ck_job_logs_status`` constraint
      is created.

    Any other dialect is left untouched. If the columns of ``job_logs``
    cannot be reflected, the function returns without doing anything.
    Errors raised while executing the migration statements propagate to
    the caller.
    """
    dialect = engine.dialect.name
    if dialect not in ("sqlite", "postgresql"):
        return

    insp = inspect(engine)
    try:
        cols = insp.get_columns("job_logs")
    except Exception:
        # Best effort: a table that cannot be reflected is skipped.
        return
    existing_cols = {c.get("name") for c in cols if c.get("name")}

    with engine.begin() as conn:
        if dialect == "sqlite":
            _rebuild_sqlite_job_logs(conn, existing_cols)
        else:
            _relax_postgres_status_check(conn, insp)
|
||
|
||
|
||
def ensure_schema(engine: Engine) -> None:
    """Lightweight schema self-upgrade (works on SQLite and PostgreSQL).

    - ``create_all`` does not update the structure of existing tables, so
      missing columns are added with inspector checks plus ALTER TABLE.
    - A failure here must not break the main flow; exceptions are not
      caught in this function, so the caller may choose to ignore them.
    """
    Base.metadata.create_all(bind=engine)

    # job_logs.run_log
    # Probe the column before opening a transaction, and only open one when
    # there is something to alter. The ALTER is committed before the status
    # migration below opens its own connection, so the two steps never
    # compete for the same table from different connections.
    if not _has_column(engine, "job_logs", "run_log"):
        with engine.begin() as conn:
            conn.execute(text("ALTER TABLE job_logs ADD COLUMN run_log TEXT NOT NULL DEFAULT ''"))

    # job_logs.status: ensure the new enum value RUNNING is accepted by the
    # database constraints.
    _ensure_job_logs_status_allows_running(engine)
|
||
|
||
|