"""Auto-migration runner. Applies pending SQL migrations on startup.""" from __future__ import annotations import logging from pathlib import Path import asyncpg logger = logging.getLogger(__name__) MIGRATIONS_DIR = Path(__file__).parent async def run_migrations(pool: asyncpg.Pool) -> None: """Check schema_version and apply any pending .sql migration files.""" async with pool.acquire() as conn: # Ensure the version-tracking table exists await conn.execute(""" CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), description TEXT NOT NULL DEFAULT '' ) """) row = await conn.fetchrow( "SELECT COALESCE(MAX(version), 0) AS v FROM schema_version" ) current_version: int = row["v"] logger.info("Current schema version: %d", current_version) # Discover and sort SQL files by their numeric prefix sql_files = sorted( MIGRATIONS_DIR.glob("[0-9]*.sql"), key=lambda p: int(p.stem.split("_")[0]), ) applied = 0 for sql_file in sql_files: version = int(sql_file.stem.split("_")[0]) if version <= current_version: continue logger.info("Applying migration %03d: %s", version, sql_file.name) sql = sql_file.read_text(encoding="utf-8") # Execute the entire migration in a transaction async with conn.transaction(): await conn.execute(sql) logger.info("Migration %03d applied successfully", version) applied += 1 if applied == 0: logger.info("No pending migrations") else: logger.info("Applied %d migration(s)", applied)