Files
util.nix/package/postgres/pg-migration/src/main.rs

211 lines
7.0 KiB
Rust

use clap::{Arg, Command, ArgAction};
use postgres::{Client, NoTls};
use chrono::Utc;
use pg_migration_lib::init_db;
use std::{fs, path::Path, process::{self, Command as ProcessCommand}};
use rand::Rng;
fn main() {
if let Err(code) = run_app() {
process::exit(code);
}
}
fn run_app() -> Result<(), i32> {
check_psql_installed();
let matches = Command::new("Rust PG Migration Tool")
.version("0.1")
.arg(
Arg::new("migration_dir")
.short('d')
.long("migration-dir")
.env("MIGRATION_DIR")
.num_args(1)
.default_value("migration"),
)
.arg(
Arg::new("inherits")
.long("inherits")
.num_args(1..)
.help("List one or more tables the migration table must inherit from"),
)
.subcommand(
Command::new("migrate").arg(
Arg::new("force")
.short('f')
.long("force")
.action(ArgAction::SetTrue),
)
.arg(
Arg::new("db_url")
.short('u')
.long("db-url")
.env("PG_URL")
.required(true)
.num_args(1),
)
.arg(
Arg::new("set")
.short('v')
.long("set")
.num_args(1..)
.help("Pass variable assignments to psql in the format key=value"),
),
)
.subcommand(
Command::new("create").arg(
Arg::new("name")
.short('n')
.long("name")
.num_args(1),
),
)
.subcommand(
Command::new("fetch").arg(
Arg::new("db_url")
.short('u')
.long("db-url")
.env("PG_URL")
.required(true)
.num_args(1),
),
)
.get_matches();
let migration_dir = matches.get_one::<String>("migration_dir").unwrap();
let inherits: Vec<String> = matches
.get_many::<String>("inherits")
.map(|vals| vals.cloned().collect())
.unwrap_or_else(Vec::new);
match matches.subcommand() {
Some(("create", sub_m)) => {
let name = sub_m
.get_one::<String>("name")
.cloned()
.unwrap_or_else(generate_migration_name);
create_migration_file(migration_dir, &name);
Ok(())
}
Some(("migrate", sub_m)) => {
let db_url = sub_m.get_one::<String>("db_url").unwrap();
let set_vars: Vec<String> = sub_m
.get_many::<String>("set")
.map(|vals| vals.cloned().collect())
.unwrap_or_else(Vec::new);
let mut client = Client::connect(db_url, NoTls).expect("DB connection failed");
init_db(&mut client, &inherits);
let force = sub_m.get_flag("force");
apply_migrations(&mut client, migration_dir, db_url, force, &set_vars)
}
Some(("fetch", sub_m)) => {
let db_url = sub_m.get_one::<String>("db_url").unwrap();
let mut client = Client::connect(db_url, NoTls).expect("DB connection failed");
init_db(&mut client, &inherits);
fetch_migrations(&mut client, migration_dir);
Ok(())
}
_ => Ok(()),
}
}
fn check_psql_installed() {
if ProcessCommand::new("psql")
.arg("--version")
.output()
.is_err()
{
eprintln!("Error: psql is not installed or not in PATH.");
std::process::exit(1);
}
}
fn apply_migrations(client: &mut Client, migration_dir: &str, db_url: &str, force: bool, set_vars: &[String]) -> Result<(), i32> {
// Get the list of new migrations from disk
let mut fs_entries: Vec<_> = fs::read_dir(migration_dir)
.expect("Reading migration directory failed")
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().and_then(|s| s.to_str()) == Some("sql"))
.collect();
fs_entries.sort_by_key(|e| e.path());
let fs_migrations: Vec<String> = fs_entries
.iter()
.map(|e| e.path().file_name().unwrap().to_string_lossy().into_owned())
.collect();
// Get the list of already applied migrations from DB
let rows = client
.query("SELECT name FROM hectic.migration ORDER BY name ASC", &[])
.expect("Query failed");
let db_migrations: Vec<String> = rows.iter().map(|row| row.get(0)).collect();
// Check if the DB migrations form a proper prefix of disk migrations
// (meaning all DB-applied migration filenames should appear in the same order at the start).
for (i, db_mig) in db_migrations.iter().enumerate() {
if i >= fs_migrations.len() || fs_migrations[i] != *db_mig {
// The DB has migrations that are not found in the same position on disk -> unrelated tree
if !force {
eprintln!("Unrelated migration tree detected. Use --force to proceed.");
return Err(2);
} else {
eprintln!("Unrelated migration tree forced. Proceeding...");
break;
}
}
}
for fs_mig in fs_migrations {
if db_migrations.contains(&fs_mig) {
continue;
}
let mut cmd = std::process::Command::new("psql");
cmd.arg("-d")
.arg(db_url);
// Add provided set variables as -v key=value
for var in set_vars {
cmd.arg("-v").arg(var);
}
cmd.arg("-f")
.arg(Path::new(migration_dir).join(&fs_mig).to_str().unwrap());
let status = cmd.status().expect("psql execution failed");
if !status.success() {
eprintln!("Migration failed: {}", fs_mig);
return Err(3);
}
client.execute("INSERT INTO hectic.migration (name) VALUES ($1)", &[&fs_mig])
.expect("Recording migration failed");
}
Ok(())
}
fn create_migration_file(migration_dir: &str, name: &str) {
fs::create_dir_all(migration_dir).expect("Creating migration directory failed");
let timestamp = Utc::now().timestamp();
let file_name = format!("{}_{}.sql", timestamp, name);
let file_path = Path::new(migration_dir).join(file_name);
fs::write(&file_path, "-- Write your migration SQL here\n")
.expect("Creating migration file failed");
println!("Created migration: {:?}", file_path);
}
fn fetch_migrations(_client: &mut Client, _migration_dir: &str) {
// (Fetch implementation omitted)
}
fn generate_migration_name() -> String {
let adjectives = ["quick", "lazy", "sleepy", "noisy", "hungry"];
let nouns = ["fox", "dog", "cat", "mouse", "bear"];
let mut rng = rand::rng();
let adj = adjectives[rng.random_range(0..adjectives.len())];
let noun = nouns[rng.random_range(0..nouns.len())];
format!("{}_{}", adj, noun)
}