feat(pg-schema): init
This commit is contained in:
379
package/postgres/pg-schema/src/main.rs
Normal file
379
package/postgres/pg-schema/src/main.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use dotenv::dotenv;
|
||||
use postgres::{Client, NoTls};
|
||||
use regex::Regex;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::env;
|
||||
use std::process::Command;
|
||||
use log::info;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Column {
|
||||
name: String,
|
||||
data_type: String,
|
||||
is_nullable: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ForeignKeyGroup {
|
||||
constraint_name: String,
|
||||
source_schema: String,
|
||||
source_table: String,
|
||||
target_schema: String,
|
||||
target_table: String,
|
||||
source_columns: Vec<String>,
|
||||
target_columns: Vec<String>,
|
||||
is_unique: bool,
|
||||
}
|
||||
|
||||
// --- Sanitization functions ---
|
||||
|
||||
fn sanitize_type(data_type: &str) -> String {
|
||||
let re = Regex::new(r"\s+").unwrap();
|
||||
re.replace_all(&data_type.to_lowercase(), "_")
|
||||
.to_uppercase()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn sanitize_name(name: &str) -> String {
|
||||
// Replace spaces and hyphens with underscores and remove invalid characters
|
||||
let re = Regex::new(r"[\s\-]+").unwrap();
|
||||
let name = re.replace_all(name, "_");
|
||||
let re = Regex::new(r"[^\w_]").unwrap();
|
||||
re.replace_all(&name, "").to_string()
|
||||
}
|
||||
|
||||
// --- Database Schema queries ---
|
||||
|
||||
fn get_tables(client: &mut Client) -> Result<Vec<(String, String)>, postgres::Error> {
|
||||
let rows = client.query(
|
||||
"
|
||||
SELECT table_schema, table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'cron')
|
||||
AND table_schema NOT LIKE 'pg_toast%'
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_schema, table_name;
|
||||
",
|
||||
&[],
|
||||
)?;
|
||||
info!("{rows:#?}");
|
||||
Ok(rows
|
||||
.iter()
|
||||
.map(|row| (row.get::<_, String>(0), row.get::<_, String>(1)))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn get_columns(
|
||||
client: &mut Client,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
) -> Result<Vec<Column>, postgres::Error> {
|
||||
let rows = client.query(
|
||||
"
|
||||
SELECT column_name, data_type, is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'cron')
|
||||
AND table_schema NOT LIKE 'pg_toast%'
|
||||
AND table_schema = $1
|
||||
AND table_name = $2
|
||||
ORDER BY ordinal_position;
|
||||
",
|
||||
&[&schema, &table],
|
||||
)?;
|
||||
let mut columns = Vec::new();
|
||||
for row in rows {
|
||||
let col_name: String = row.get(0);
|
||||
let data_type: String = row.get(1);
|
||||
let is_nullable_str: String = row.get(2);
|
||||
let is_nullable = is_nullable_str.to_lowercase() == "yes";
|
||||
let sanitized_col_name = sanitize_name(&col_name.to_lowercase());
|
||||
let sanitized_type = sanitize_type(&data_type);
|
||||
columns.push(Column {
|
||||
name: sanitized_col_name,
|
||||
data_type: sanitized_type,
|
||||
is_nullable,
|
||||
});
|
||||
}
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
// Improved foreign key query that gathers constraint_name and uniqueness info,
|
||||
// so that composite foreign keys are grouped together.
|
||||
fn get_foreign_keys(client: &mut Client) -> Result<Vec<ForeignKeyGroup>, postgres::Error> {
|
||||
let query = "
|
||||
SELECT
|
||||
tc.constraint_name,
|
||||
tc.table_schema AS source_schema,
|
||||
tc.table_name AS source_table,
|
||||
kcu.column_name AS source_column,
|
||||
ccu.table_schema AS target_schema,
|
||||
ccu.table_name AS target_table,
|
||||
ccu.column_name AS target_column,
|
||||
(
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.table_constraints as utc
|
||||
JOIN information_schema.key_column_usage as ukcu
|
||||
ON utc.constraint_name = ukcu.constraint_name
|
||||
WHERE utc.table_schema = tc.table_schema
|
||||
AND utc.table_name = tc.table_name
|
||||
AND utc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
|
||||
AND ukcu.column_name = kcu.column_name
|
||||
)
|
||||
) as is_unique
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON tc.constraint_name = ccu.constraint_name
|
||||
WHERE
|
||||
tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema NOT IN ('pg_catalog', 'information_schema', 'cron')
|
||||
AND tc.table_schema NOT LIKE 'pg_toast%'
|
||||
ORDER BY tc.constraint_name, kcu.ordinal_position;
|
||||
";
|
||||
let rows = client.query(query, &[])?;
|
||||
let mut map: HashMap<String, ForeignKeyGroup> = HashMap::new();
|
||||
for row in rows {
|
||||
let constraint_name: String = row.get("constraint_name");
|
||||
let source_schema: String = row.get("source_schema");
|
||||
let source_table: String = row.get("source_table");
|
||||
let target_schema: String = row.get("target_schema");
|
||||
let target_table: String = row.get("target_table");
|
||||
let source_column: String = row.get("source_column");
|
||||
let target_column: String = row.get("target_column");
|
||||
let is_unique: bool = row.get("is_unique");
|
||||
|
||||
let entry = map.entry(constraint_name.clone()).or_insert(ForeignKeyGroup {
|
||||
constraint_name: constraint_name.clone(),
|
||||
source_schema: source_schema.clone(),
|
||||
source_table: source_table.clone(),
|
||||
target_schema: target_schema.clone(),
|
||||
target_table: target_table.clone(),
|
||||
source_columns: Vec::new(),
|
||||
target_columns: Vec::new(),
|
||||
is_unique: true,
|
||||
});
|
||||
entry.source_columns.push(source_column);
|
||||
entry.target_columns.push(target_column);
|
||||
entry.is_unique = entry.is_unique && is_unique;
|
||||
}
|
||||
Ok(map.into_iter().map(|(_, v)| v).collect())
|
||||
}
|
||||
|
||||
// --- Utility for generating diagram ---
|
||||
|
||||
fn get_hash(text: &str) -> String {
|
||||
use sha2::{Sha256, Digest};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(text.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
format!("_{}", hex::encode(result))
|
||||
}
|
||||
|
||||
/// Generates the Mermaid diagram.
|
||||
/// - `tables`: all non-join tables (after join tables are removed)
|
||||
/// - `foreign_keys`: grouped foreign key relationships
|
||||
/// - `join_tables`: a vector of join table keys with their foreign key groups
|
||||
/// - `join_table_keys`: set of join table identifiers to filter out regular relationships
|
||||
fn generate_mermaid(
|
||||
tables: &HashMap<(String, String), Vec<Column>>,
|
||||
foreign_keys: &[ForeignKeyGroup],
|
||||
join_tables: &Vec<((String, String), Vec<&ForeignKeyGroup>)>,
|
||||
join_table_keys: &HashSet<(String, String)>
|
||||
) -> String {
|
||||
let mut mermaid = String::from("erDiagram\n");
|
||||
|
||||
// Define entities for non-join tables
|
||||
for ((schema, table), columns) in tables {
|
||||
let sanitized_schema = sanitize_name(schema);
|
||||
let sanitized_table = sanitize_name(table);
|
||||
let full_name = format!("{}.{}", sanitized_schema, sanitized_table);
|
||||
mermaid.push_str(&format!(" {}[\"{}\"] {{\n", get_hash(&full_name), full_name));
|
||||
for column in columns {
|
||||
mermaid.push_str(&format!(" {} {}\n", column.data_type, sanitize_name(&column.name)));
|
||||
}
|
||||
mermaid.push_str(" }\n");
|
||||
}
|
||||
mermaid.push_str("\n");
|
||||
|
||||
// Normal relationships (one-to-one or one-to-many)
|
||||
for fk in foreign_keys {
|
||||
// Skip relationships that originate from join tables (handled separately)
|
||||
if join_table_keys.contains(&(fk.source_schema.clone(), fk.source_table.clone())) {
|
||||
continue;
|
||||
}
|
||||
let sanitized_source_table = sanitize_name(&fk.source_table);
|
||||
let sanitized_target_table = sanitize_name(&fk.target_table);
|
||||
let sanitized_source_schema = sanitize_name(&fk.source_schema);
|
||||
let sanitized_target_schema = sanitize_name(&fk.target_schema);
|
||||
let source_full_name = format!("{}.{}", sanitized_source_schema, sanitized_source_table);
|
||||
let target_full_name = format!("{}.{}", sanitized_target_schema, sanitized_target_table);
|
||||
let relationship_label = format!("{} -> {}", fk.source_columns.join(", "), fk.target_columns.join(", "));
|
||||
let relationship_type = if fk.is_unique { "||--||" } else { "}o--||" };
|
||||
mermaid.push_str(&format!(
|
||||
" {} {} {} : \"{}\"\n",
|
||||
get_hash(&source_full_name),
|
||||
relationship_type,
|
||||
get_hash(&target_full_name),
|
||||
relationship_label
|
||||
));
|
||||
}
|
||||
|
||||
// Many-to-many relationships for detected join tables
|
||||
for (join_key, fk_list) in join_tables {
|
||||
if fk_list.len() == 2 {
|
||||
let fk1 = fk_list[0];
|
||||
let fk2 = fk_list[1];
|
||||
// Draw an edge between the two target tables of the join table.
|
||||
let source_full_name = format!("{}.{}", sanitize_name(&fk1.target_schema), sanitize_name(&fk1.target_table));
|
||||
let target_full_name = format!("{}.{}", sanitize_name(&fk2.target_schema), sanitize_name(&fk2.target_table));
|
||||
let label = format!("join: {} <-> {}", fk1.target_columns.join(", "), fk2.target_columns.join(", "));
|
||||
mermaid.push_str(&format!(
|
||||
" {} }}|--|{{ {} : \"{}\"\n",
|
||||
get_hash(&source_full_name),
|
||||
get_hash(&target_full_name),
|
||||
label
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
mermaid
|
||||
}
|
||||
|
||||
fn generate_svg(mermaid_file: &str, svg_file: &str) -> Result<(), String> {
|
||||
let mmdc_path = which::which("mmdc").map_err(|_| {
|
||||
"Mermaid CLI (mmdc) is not installed or not found in PATH.\nPlease install it by running: npm install -g @mermaid-js/mermaid-cli".to_string()
|
||||
})?;
|
||||
let status = Command::new(mmdc_path)
|
||||
.arg("-i")
|
||||
.arg(mermaid_file)
|
||||
.arg("-o")
|
||||
.arg(svg_file)
|
||||
.status()
|
||||
.map_err(|e| format!("Failed to execute mmdc: {}", e))?;
|
||||
if status.success() {
|
||||
println!("SVG diagram generated successfully as '{}'.", svg_file);
|
||||
Ok(())
|
||||
} else {
|
||||
Err("An error occurred while generating the SVG.".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
dotenv().ok();
|
||||
env_logger::init();
|
||||
|
||||
// If DB_URL is provided, use it. Otherwise, fall back to individual parameters.
|
||||
let conn_str = match env::var("DB_URL") {
|
||||
Ok(url) => {
|
||||
info!("Using DB_URL environment variable for connection.");
|
||||
url
|
||||
},
|
||||
Err(_) => {
|
||||
let db_host = env::var("DB_HOST").unwrap_or_else(|_| {
|
||||
eprintln!("No DB_HOST environment variable provided.");
|
||||
std::process::exit(1);
|
||||
});
|
||||
info!("DB_HOST: {:?}", db_host);
|
||||
|
||||
let db_port = env::var("DB_PORT").unwrap_or_else(|_| {
|
||||
eprintln!("No DB_PORT environment variable provided.");
|
||||
std::process::exit(1);
|
||||
});
|
||||
let db_name = env::var("DB_NAME").unwrap_or_else(|_| {
|
||||
eprintln!("No DB_NAME environment variable provided.");
|
||||
std::process::exit(1);
|
||||
});
|
||||
let db_user = env::var("DB_USER").unwrap_or_else(|_| {
|
||||
eprintln!("No DB_USER environment variable provided.");
|
||||
std::process::exit(1);
|
||||
});
|
||||
let db_password = env::var("DB_PASSWORD").unwrap_or_else(|_| {
|
||||
eprintln!("No DB_PASSWORD environment variable provided.");
|
||||
std::process::exit(1);
|
||||
});
|
||||
format!("host={} port={} dbname={} user={} password={}", db_host, db_port, db_name, db_user, db_password)
|
||||
}
|
||||
};
|
||||
|
||||
let mut client = match Client::connect(&conn_str, NoTls) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
eprintln!("Unable to connect to the database:\n{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch tables and columns
|
||||
let tables_list = match get_tables(&mut client) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
eprintln!("Error fetching tables:\n{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
let mut tables: HashMap<(String, String), Vec<Column>> = HashMap::new();
|
||||
for (schema, table) in &tables_list {
|
||||
match get_columns(&mut client, schema, table) {
|
||||
Ok(cols) => {
|
||||
tables.insert((schema.clone(), table.clone()), cols);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error fetching columns for table '{}':\n{}", table, e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch grouped foreign keys (composite keys handled together)
|
||||
let foreign_keys = match get_foreign_keys(&mut client) {
|
||||
Ok(fks) => fks,
|
||||
Err(e) => {
|
||||
eprintln!("Error fetching foreign keys:\n{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
// --- Detect join tables for many-to-many relationships ---
|
||||
// Build a map from (schema, table) to foreign key groups where the table is the source.
|
||||
let mut fk_by_source: HashMap<(String, String), Vec<&ForeignKeyGroup>> = HashMap::new();
|
||||
for fk in &foreign_keys {
|
||||
fk_by_source
|
||||
.entry((fk.source_schema.clone(), fk.source_table.clone()))
|
||||
.or_default()
|
||||
.push(fk);
|
||||
}
|
||||
|
||||
// A join table is defined as having exactly two columns and exactly two foreign keys.
|
||||
let mut join_table_keys_vec: Vec<(String, String)> = Vec::new();
|
||||
for ((schema, table), columns) in &tables {
|
||||
if let Some(fk_list) = fk_by_source.get(&(schema.clone(), table.clone())) {
|
||||
if columns.len() == fk_list.len() && fk_list.len() == 2 {
|
||||
join_table_keys_vec.push((schema.clone(), table.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
let join_table_keys: HashSet<(String, String)> = join_table_keys_vec.iter().cloned().collect();
|
||||
|
||||
// Build a vector with join table key and its foreign key groups.
|
||||
let mut join_tables: Vec<((String, String), Vec<&ForeignKeyGroup>)> = Vec::new();
|
||||
for key in join_table_keys_vec {
|
||||
if let Some(fk_list) = fk_by_source.get(&key) {
|
||||
join_tables.push((key, fk_list.clone()));
|
||||
}
|
||||
}
|
||||
// Remove join tables from the main entities so they aren’t drawn separately.
|
||||
for key in &join_table_keys {
|
||||
tables.remove(key);
|
||||
}
|
||||
|
||||
// Generate the Mermaid diagram
|
||||
let mermaid_diagram = generate_mermaid(&tables, &foreign_keys, &join_tables, &join_table_keys);
|
||||
|
||||
// For now, we simply print the diagram (or you could write it to a file)
|
||||
println!("{mermaid_diagram}");
|
||||
|
||||
client.close().expect("Failed to close the database connection.");
|
||||
}
|
||||
Reference in New Issue
Block a user