From 998971c94eca20a8a124aaf8c21e7b1684678efa Mon Sep 17 00:00:00 2001 From: sinder Date: Fri, 27 Feb 2026 19:29:27 +0800 Subject: [PATCH 01/10] Add fix and tests --- sea-orm-macros/src/derives/value_type.rs | 22 ++++++++++++++++++++++ sea-orm-sync/tests/derive_tests.rs | 14 ++++++++++++++ tests/derive_tests.rs | 14 ++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/sea-orm-macros/src/derives/value_type.rs b/sea-orm-macros/src/derives/value_type.rs index 77c8630535..016f5dcbb4 100644 --- a/sea-orm-macros/src/derives/value_type.rs +++ b/sea-orm-macros/src/derives/value_type.rs @@ -185,6 +185,25 @@ impl DeriveValueTypeStruct { quote!() }; + let impl_try_getable_array = if cfg!(feature = "postgres-array") { + quote!( + #[automatically_derived] + impl sea_orm::TryGetableArray for #name { + fn try_get_by( + res: &sea_orm::QueryResult, + index: I, + ) -> std::result::Result, sea_orm::TryGetError> { + Ok( as sea_orm::TryGetable>::try_get_by(res, index)? + .into_iter() + .map(|value| Self(value)) + .collect()) + } + } + ) + } else { + quote!() + }; + let impl_not_u8 = if cfg!(feature = "postgres-array") { quote!( #[automatically_derived] @@ -198,6 +217,7 @@ impl DeriveValueTypeStruct { #[automatically_derived] impl std::convert::From<#name> for sea_orm::Value { fn from(source: #name) -> Self { + println!("Struct"); source.0.into() } } @@ -243,6 +263,8 @@ impl DeriveValueTypeStruct { } } + #impl_try_getable_array + #try_from_u64_impl #impl_not_u8 diff --git a/sea-orm-sync/tests/derive_tests.rs b/sea-orm-sync/tests/derive_tests.rs index e482280052..d972130ead 100644 --- a/sea-orm-sync/tests/derive_tests.rs +++ b/sea-orm-sync/tests/derive_tests.rs @@ -69,3 +69,17 @@ struct FromQueryResultNested { #[sea_orm(nested)] _test: SimpleTest, } + +#[cfg(feature = "postgres-array")] +mod postgres_array { + use crate::FromQueryResult; + use sea_orm::DeriveValueType; + + #[derive(DeriveValueType)] + pub struct GoodId(i32); + + #[derive(FromQueryResult)] + pub struct ArrayTest { + pub ingredient_path: Vec, + } +} diff --git a/tests/derive_tests.rs b/tests/derive_tests.rs index e482280052..3d392f47ec 100644 --- a/tests/derive_tests.rs +++ b/tests/derive_tests.rs @@ -69,3 +69,17 @@ struct FromQueryResultNested { #[sea_orm(nested)] _test: SimpleTest, } + +#[cfg(feature = "postgres-array")] +mod postgres_array { + use crate::FromQueryResult; + use sea_orm::DeriveValueType; + + #[derive(DeriveValueType)] + pub struct GoodId(i32); + + #[derive(FromQueryResult)] + pub struct ArrayTest { + pub ingredient_path: Vec, + } +} From 9f85b5ddca2351c9ff255bbbdd4a7b1ac0759f6b Mon Sep 17 00:00:00 2001 From: sinder Date: Tue, 10 Mar 2026 17:03:33 +0800 Subject: [PATCH 02/10] Fix NewType with vec bug Add a new derive attribute `no_vec_impl` to explicitly omit `sea_orm::TryGetableArray` --- sea-orm-macros/src/derives/attributes.rs | 1 + sea-orm-macros/src/derives/value_type.rs | 10 ++++++++- .../src/derives/value_type_match.rs | 22 +++++++++++++++++++ tests/common/features/value_type.rs | 6 +++++ 4 files changed, 38 insertions(+), 1 deletion(-) diff --git a/sea-orm-macros/src/derives/attributes.rs b/sea-orm-macros/src/derives/attributes.rs index a5cab499af..12d6b837ad 100644 --- a/sea-orm-macros/src/derives/attributes.rs +++ b/sea-orm-macros/src/derives/attributes.rs @@ -77,6 +77,7 @@ pub mod value_type_attr { pub from_str: Option, pub to_str: Option, pub try_from_u64: Option<()>, + pub no_vec_impl: Option<()>, } } diff --git a/sea-orm-macros/src/derives/value_type.rs b/sea-orm-macros/src/derives/value_type.rs index 016f5dcbb4..c7171f1479 100644 --- a/sea-orm-macros/src/derives/value_type.rs +++ b/sea-orm-macros/src/derives/value_type.rs @@ -1,3 +1,5 @@ +use crate::derives::value_type_match::omit_vec_impl; + use super::attributes::value_type_attr; use super::value_type_match::{array_type_expr, can_try_from_u64, column_type_expr}; use proc_macro2::TokenStream; @@ -15,6 +17,8 @@ struct DeriveValueTypeStruct { ty: Type, column_type: TokenStream, array_type: TokenStream, + /// Do not implement `sea_orm::TryGetableArray` for this type. Default: false. + no_vec_impl: bool, can_try_from_u64: bool, } @@ -23,6 +27,7 @@ struct DeriveValueTypeStructAttrs { column_type: Option, array_type: Option, try_from_u64: bool, + no_vec_impl: bool, } impl TryFrom for DeriveValueTypeStructAttrs { @@ -33,6 +38,7 @@ impl TryFrom for DeriveValueTypeStructAttrs { column_type: attrs.column_type.map(|s| s.parse()).transpose()?, array_type: attrs.array_type.map(|s| s.parse()).transpose()?, try_from_u64: attrs.try_from_u64.is_some(), + no_vec_impl: attrs.no_vec_impl.is_some(), }) } } @@ -151,12 +157,14 @@ impl DeriveValueTypeStruct { let column_type = column_type_expr(attrs.column_type, field_type, field_span); let array_type = array_type_expr(attrs.array_type, field_type, field_span); let can_try_from_u64 = attrs.try_from_u64 || can_try_from_u64(field_type); + let no_vec_impl = attrs.no_vec_impl || omit_vec_impl(field_type); Ok(Self { name, ty, column_type, array_type, + no_vec_impl, can_try_from_u64, }) } @@ -185,7 +193,7 @@ impl DeriveValueTypeStruct { quote!() }; - let impl_try_getable_array = if cfg!(feature = "postgres-array") { + let impl_try_getable_array = if cfg!(feature = "postgres-array") && !self.no_vec_impl { quote!( #[automatically_derived] impl sea_orm::TryGetableArray for #name { diff --git a/sea-orm-macros/src/derives/value_type_match.rs b/sea-orm-macros/src/derives/value_type_match.rs index cfba21d052..d1d0be1781 100644 --- a/sea-orm-macros/src/derives/value_type_match.rs +++ b/sea-orm-macros/src/derives/value_type_match.rs @@ -151,6 +151,28 @@ pub fn can_try_from_u64(field_type: &str) -> bool { ) } +/// Maximum depth of vector nesting allowed INSIDE NEW TYPE before omitting sea_orm::TryGetableArray +/// For example, `struct A (Vec>)` has dimensionality of 2 +/// Abosolute maximum would be 5, because of Postgres limit of 6 +const MAX_VEC_DIMENSIONALITY: u8 = 0; + +/// Determines whether to omit `sea_orm::TryGetableArray` implementation for a given field type +/// based on the vector dimensionality. +pub fn omit_vec_impl(field_type: &str) -> bool { + let mut depth = 0u8; + let mut current = field_type.trim(); + + while let Some(inner) = current.strip_prefix("Vec<") { + if depth >= MAX_VEC_DIMENSIONALITY { + return true; + } + depth += 1; + current = inner.trim_start(); + } + + false +} + /// Return whether it is nullable fn trim_option(s: &str) -> (bool, &str) { if s.starts_with("Option<") { diff --git a/tests/common/features/value_type.rs b/tests/common/features/value_type.rs index a58dbc283b..a963f2c522 100644 --- a/tests/common/features/value_type.rs +++ b/tests/common/features/value_type.rs @@ -69,9 +69,15 @@ where } } +// Automatically disable vec impl #[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] pub struct StringVec(pub Vec); +// Explicitly disable vec impl +#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +#[sea_orm(no_vec_impl)] +pub struct StringVecNoImpl(pub Vec); + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] #[sea_orm(value_type = "String")] pub enum Tag1 { From 9b465975b23ae37c8dc064fea904d480d92e3384 Mon Sep 17 00:00:00 2001 From: sinder Date: Tue, 10 Mar 2026 18:01:23 +0800 Subject: [PATCH 03/10] Fix array implementation for DerivValueTypeString --- sea-orm-macros/src/derives/value_type.rs | 32 +++++++++++++ tests/derive_tests.rs | 58 ++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/sea-orm-macros/src/derives/value_type.rs b/sea-orm-macros/src/derives/value_type.rs index c7171f1479..b3d2bf37e2 100644 --- a/sea-orm-macros/src/derives/value_type.rs +++ b/sea-orm-macros/src/derives/value_type.rs @@ -305,6 +305,36 @@ impl DeriveValueTypeString { None => "e!(String(sea_orm::sea_query::StringLen::None)), }; + let impl_try_getable_array = if cfg!(feature = "postgres-array") { + quote!( + #[automatically_derived] + impl sea_orm::TryGetableArray for #name { + fn try_get_by( + res: &sea_orm::QueryResult, + index: I, + ) -> std::result::Result, sea_orm::TryGetError> { + let mut result = Vec::new(); + for string in as sea_orm::TryGetable>::try_get_by(res, index)?.into_iter() { + result.push(#from_str(&string) + .map_err(|err| + { + sea_orm::TryGetError::DbErr( + sea_orm::DbErr::TryIntoErr { + from: "String", + into: stringify!(#name), + source: std::sync::Arc::new(err), + }) + } + )?); + } + Ok(result) + } + } + ) + } else { + quote!() + }; + let impl_not_u8 = if cfg!(feature = "postgres-array") { quote!( #[automatically_derived] @@ -372,6 +402,8 @@ impl DeriveValueTypeString { } #impl_not_u8 + + #impl_try_getable_array ) } } diff --git a/tests/derive_tests.rs b/tests/derive_tests.rs index 3d392f47ec..e33cae65c7 100644 --- a/tests/derive_tests.rs +++ b/tests/derive_tests.rs @@ -76,10 +76,62 @@ mod postgres_array { use sea_orm::DeriveValueType; #[derive(DeriveValueType)] - pub struct GoodId(i32); + pub struct IngredientId(i32); + + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] + #[sea_orm(value_type = "String")] + pub struct NumericLabel { + pub value: i64, + } + + impl std::fmt::Display for NumericLabel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value) + } + } + + impl std::str::FromStr for NumericLabel { + type Err = std::num::ParseIntError; + fn from_str(s: &str) -> Result { + Ok(Self { value: s.parse()? }) + } + } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] + #[sea_orm(value_type = "String")] + pub enum TextureKind { + Hard, + Soft, + } + + impl std::fmt::Display for TextureKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Hard => "hard", + Self::Soft => "soft", + } + ) + } + } + + impl std::str::FromStr for TextureKind { + type Err = sea_query::ValueTypeErr; + fn from_str(s: &str) -> Result { + Ok(match s { + "hard" => Self::Hard, + "soft" => Self::Soft, + _ => return Err(sea_query::ValueTypeErr), + }) + } + } #[derive(FromQueryResult)] - pub struct ArrayTest { - pub ingredient_path: Vec, + pub struct IngredientPathRow { + pub ingredient_path: Vec, + pub numeric_label_path: Vec, + pub texture_path: Vec, } } From 5e7ee973a1472b5fd67dc99e22f415b33ca305d6 Mon Sep 17 00:00:00 2001 From: sinder Date: Tue, 10 Mar 2026 18:22:47 +0800 Subject: [PATCH 04/10] add tests to sea-orm-sync allow absurd_extreme_comparisons --- .../src/derives/value_type_match.rs | 1 + .../tests/common/features/value_type.rs | 6 ++ sea-orm-sync/tests/derive_tests.rs | 58 ++++++++++++++++++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/sea-orm-macros/src/derives/value_type_match.rs b/sea-orm-macros/src/derives/value_type_match.rs index d1d0be1781..26360838f5 100644 --- a/sea-orm-macros/src/derives/value_type_match.rs +++ b/sea-orm-macros/src/derives/value_type_match.rs @@ -163,6 +163,7 @@ pub fn omit_vec_impl(field_type: &str) -> bool { let mut current = field_type.trim(); while let Some(inner) = current.strip_prefix("Vec<") { + #[allow(clippy::absurd_extreme_comparisons)] if depth >= MAX_VEC_DIMENSIONALITY { return true; } diff --git a/sea-orm-sync/tests/common/features/value_type.rs b/sea-orm-sync/tests/common/features/value_type.rs index a58dbc283b..a963f2c522 100644 --- a/sea-orm-sync/tests/common/features/value_type.rs +++ b/sea-orm-sync/tests/common/features/value_type.rs @@ -69,9 +69,15 @@ where } } +// Automatically disable vec impl #[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] pub struct StringVec(pub Vec); +// Explicitly disable vec impl +#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +#[sea_orm(no_vec_impl)] +pub struct StringVecNoImpl(pub Vec); + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] #[sea_orm(value_type = "String")] pub enum Tag1 { diff --git a/sea-orm-sync/tests/derive_tests.rs b/sea-orm-sync/tests/derive_tests.rs index d972130ead..e33cae65c7 100644 --- a/sea-orm-sync/tests/derive_tests.rs +++ b/sea-orm-sync/tests/derive_tests.rs @@ -76,10 +76,62 @@ mod postgres_array { use sea_orm::DeriveValueType; #[derive(DeriveValueType)] - pub struct GoodId(i32); + pub struct IngredientId(i32); + + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] + #[sea_orm(value_type = "String")] + pub struct NumericLabel { + pub value: i64, + } + + impl std::fmt::Display for NumericLabel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value) + } + } + + impl std::str::FromStr for NumericLabel { + type Err = std::num::ParseIntError; + fn from_str(s: &str) -> Result { + Ok(Self { value: s.parse()? }) + } + } + + #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] + #[sea_orm(value_type = "String")] + pub enum TextureKind { + Hard, + Soft, + } + + impl std::fmt::Display for TextureKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Hard => "hard", + Self::Soft => "soft", + } + ) + } + } + + impl std::str::FromStr for TextureKind { + type Err = sea_query::ValueTypeErr; + fn from_str(s: &str) -> Result { + Ok(match s { + "hard" => Self::Hard, + "soft" => Self::Soft, + _ => return Err(sea_query::ValueTypeErr), + }) + } + } #[derive(FromQueryResult)] - pub struct ArrayTest { - pub ingredient_path: Vec, + pub struct IngredientPathRow { + pub ingredient_path: Vec, + pub numeric_label_path: Vec, + pub texture_path: Vec, } } From 208c137dd6b601979e8dcb32f2a46f72abcc0611 Mon Sep 17 00:00:00 2001 From: sinder Date: Thu, 19 Mar 2026 17:36:23 +0800 Subject: [PATCH 05/10] Add separate functions for discovery to schema builder --- src/schema/builder.rs | 425 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 425 insertions(+) diff --git a/src/schema/builder.rs b/src/schema/builder.rs index 83aac77f1f..83a383c514 100644 --- a/src/schema/builder.rs +++ b/src/schema/builder.rs @@ -191,6 +191,164 @@ impl SchemaBuilder { Ok(()) } + /// This function fetches the changes needed to sync the database schema with this builder's entities. + /// * `db` - The database connection to use for fetching existing table schema. + /// * `dangerous` - If `true`, changes will contain drop statements. + #[cfg(feature = "schema-sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] + pub async fn discover(&self, db: &C, allow_dangerous: bool) -> Result, DbErr> + where + C: ConnectionTrait + sea_schema::Connection, + { + let _existing = match db.get_database_backend() { + #[cfg(feature = "sqlx-mysql")] + DbBackend::MySql => { + use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe}; + + let current_schema: String = db + .query_one( + sea_query::SelectStatement::new() + .expr(sea_schema::mysql::MySql::get_current_schema()), + ) + .await? + .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? + .try_get_by_index(0)?; + let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); + + let schema = schema_discovery + .discover_with(db) + .await + .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; + + DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + } + } + #[cfg(feature = "sqlx-postgres")] + DbBackend::Postgres => { + use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe}; + + let current_schema: String = db + .query_one( + sea_query::SelectStatement::new() + .expr(sea_schema::postgres::Postgres::get_current_schema()), + ) + .await? + .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? + .try_get_by_index(0)?; + let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); + + let schema = schema_discovery + .discover_with(db) + .await + .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; + + DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: schema.enums.iter().map(|def| def.write()).collect(), + } + } + #[cfg(feature = "sqlx-sqlite")] + DbBackend::Sqlite => { + use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; + let schema = SchemaDiscovery::discover_with(db) + .await + .map_err(|err| { + DbErr::Query(match err { + SqliteDiscoveryError::SqlxError(err) => { + crate::RuntimeErr::SqlxError(err.into()) + } + _ => crate::RuntimeErr::Internal(format!("{err:?}")), + }) + })? + .merge_indexes_into_table(); + DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + } + } + #[cfg(feature = "rusqlite")] + DbBackend::Sqlite => { + use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; + let schema = SchemaDiscovery::discover_with(db) + .map_err(|err| { + DbErr::Query(match err { + SqliteDiscoveryError::RusqliteError(err) => { + crate::RuntimeErr::Rusqlite(err.into()) + } + _ => crate::RuntimeErr::Internal(format!("{err:?}")), + }) + })? + .merge_indexes_into_table(); + DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + } + } + #[allow(unreachable_patterns)] + other => { + return Err(DbErr::BackendNotSupported { + db: other.as_str(), + ctx: "SchemaBuilder::discover", + }); + } + }; + + #[allow(unreachable_code)] + let db_backend = db.get_database_backend(); + + #[allow(unreachable_code)] + let mut changes: Vec = Vec::new(); + #[allow(unreachable_code)] + let mut created_enums: Vec = Default::default(); + + #[allow(unreachable_code)] + for table_name in self.sorted_tables() { + if let Some(entity) = self + .entities + .iter() + .find(|entity| table_name == get_table_name(entity.table.get_table_name())) + { + entity.discover_changes( + db_backend, + &_existing, + &mut created_enums, + &mut changes, + allow_dangerous, + ); + } + } + + // When dangerous is allowed, detect tables in the database that are not in the entity set + // and generate drop statements for them. + #[allow(unreachable_code)] + if allow_dangerous { + for existing_table in &_existing.tables { + let existing_name = get_table_name(existing_table.get_table_name()); + let in_entities = self + .entities + .iter() + .any(|e| get_table_name(e.table.get_table_name()) == existing_name); + if !in_entities { + let stmt = db_backend.build( + sea_query::Table::drop() + .table( + existing_table + .get_table_name() + .expect("table must have a name") + .clone(), + ) + .if_exists(), + ); + changes.push(stmt); + } + } + } + + Ok(changes) + } + /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys. /// Will fail if any table already exists. Use [`sync`] if you want an incremental version that can perform schema diff. pub async fn apply(self, db: &C) -> Result<(), DbErr> { @@ -492,6 +650,273 @@ impl EntitySchemaInfo { Ok(()) } + // better to always compile this function + #[allow(dead_code)] + fn discover_changes( + &self, + db_backend: DbBackend, + existing: &DiscoveredSchema, + created_enums: &mut Vec, + changes: &mut Vec, + allow_dangerous: bool, + ) { + // create enum before creating table + for stmt in self.enums.iter() { + let mut has_enum = false; + let new_stmt = db_backend.build(stmt); + for existing_enum in &existing.enums { + if db_backend.build(existing_enum) == new_stmt { + has_enum = true; + break; + } + } + if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) { + changes.push(new_stmt.clone()); + created_enums.push(new_stmt); + } + } + + let table_name = get_table_name(self.table.get_table_name()); + let mut existing_table = None; + for tbl in &existing.tables { + if get_table_name(tbl.get_table_name()) == table_name { + existing_table = Some(tbl); + break; + } + } + + if let Some(existing_table) = existing_table { + // Discover column additions / renames + for column_def in self.table.get_columns() { + let mut column_exists = false; + for existing_column in existing_table.get_columns() { + if column_def.get_column_name() == existing_column.get_column_name() { + column_exists = true; + break; + } + } + if !column_exists { + let mut renamed_from = ""; + if let Some(comment) = &column_def.get_column_spec().comment { + if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") { + if let Some((prefix, _)) = suffix.split_once('"') { + renamed_from = prefix; + } + } + } + if renamed_from.is_empty() { + changes.push( + db_backend.build( + TableAlterStatement::new() + .table( + self.table.get_table_name().expect("Checked above").clone(), + ) + .add_column(column_def.to_owned()), + ), + ); + } else { + changes.push( + db_backend.build( + TableAlterStatement::new() + .table( + self.table.get_table_name().expect("Checked above").clone(), + ) + .rename_column( + renamed_from.to_owned(), + column_def.get_column_name(), + ), + ), + ); + } + } + } + + // Discover missing foreign keys (not supported on SQLite) + if db_backend != DbBackend::Sqlite { + for foreign_key in self.table.get_foreign_key_create_stmts().iter() { + let mut key_exists = false; + for existing_key in existing_table.get_foreign_key_create_stmts().iter() { + if compare_foreign_key(foreign_key, existing_key) { + key_exists = true; + break; + } + } + if !key_exists { + changes.push(db_backend.build(foreign_key)); + } + } + } + + // Dangerous: drop columns that exist in DB but not in entity + if allow_dangerous { + for existing_column in existing_table.get_columns() { + let col_name = existing_column.get_column_name(); + let in_entity = self + .table + .get_columns() + .iter() + .any(|c| c.get_column_name() == col_name); + if !in_entity { + changes.push( + db_backend.build( + TableAlterStatement::new() + .table( + self.table.get_table_name().expect("Checked above").clone(), + ) + .drop_column(col_name.into_iden()), + ), + ); + } + } + + // Dangerous: drop foreign keys that exist in DB but not in entity + if db_backend != DbBackend::Sqlite { + for existing_key in existing_table.get_foreign_key_create_stmts().iter() { + let in_entity = self + .table + .get_foreign_key_create_stmts() + .iter() + .any(|fk| compare_foreign_key(fk, existing_key)); + if !in_entity { + if let Some(name) = existing_key.get_foreign_key().get_name() { + changes.push( + db_backend.build( + TableAlterStatement::new() + .table( + self.table + .get_table_name() + .expect("Checked above") + .clone(), + ) + .drop_foreign_key(name.to_owned()), + ), + ); + } + } + } + } + } + } else { + // Table doesn't exist in DB, create it + changes.push(db_backend.build(&self.table)); + } + + // Discover missing indexes + for stmt in self.indexes.iter() { + let mut has_index = false; + if let Some(existing_table) = existing_table { + for existing_index in existing_table.get_indexes() { + if existing_index.get_index_spec().get_column_names() + == stmt.get_index_spec().get_column_names() + { + has_index = true; + break; + } + } + } + if !has_index { + let mut stmt = stmt.clone(); + stmt.if_not_exists(); + changes.push(db_backend.build(&stmt)); + } + } + + if let Some(existing_table) = existing_table { + // Discover missing unique indexes for column-level UNIQUE constraints + for column_def in self.table.get_columns() { + if column_def.get_column_spec().unique { + let col_name = column_def.get_column_name(); + let col_exists = existing_table + .get_columns() + .iter() + .any(|c| c.get_column_name() == col_name); + if !col_exists { + continue; + } + let already_unique = existing_table.get_indexes().iter().any(|idx| { + if !idx.is_unique_key() { + return false; + } + let cols = idx.get_index_spec().get_column_names(); + cols.len() == 1 && cols[0] == col_name + }); + if !already_unique { + let table_name = + self.table.get_table_name().expect("table must have a name"); + let tbl_str = table_name.sea_orm_table().to_string(); + let table_ref = table_name.clone(); + changes.push( + db_backend.build( + Index::create() + .name(format!("idx-{tbl_str}-{col_name}")) + .table(table_ref) + .col(col_name.into_iden()) + .unique() + .if_not_exists(), + ), + ); + } + } + } + } + + if let Some(existing_table) = existing_table { + // Discover unique keys that need to be dropped + for existing_index in existing_table.get_indexes() { + if existing_index.is_unique_key() { + let mut has_index = false; + for stmt in self.indexes.iter() { + if existing_index.get_index_spec().get_column_names() + == stmt.get_index_spec().get_column_names() + { + has_index = true; + break; + } + } + if !has_index { + let index_cols = existing_index.get_index_spec().get_column_names(); + if index_cols.len() == 1 { + for column_def in self.table.get_columns() { + if column_def.get_column_name() == index_cols[0] + && column_def.get_column_spec().unique + { + has_index = true; + break; + } + } + } + } + if !has_index { + if let Some(drop_existing) = existing_index + .get_index_spec() + .get_name() + .map(|s| s.to_owned()) + { + if db_backend == DbBackend::Postgres { + changes.push( + db_backend.build( + TableAlterStatement::new() + .table( + self.table + .get_table_name() + .expect("Checked above") + .clone(), + ) + .drop_constraint(drop_existing), + ), + ); + } else { + changes.push( + db_backend.build(sea_query::Index::drop().name(drop_existing)), + ); + } + } + } + } + } + } + } + fn debug_print( &self, f: &mut std::fmt::Formatter<'_>, From 8251d3a3bcddf3e3181b9a5b74bc4809209b963d Mon Sep 17 00:00:00 2001 From: sinder Date: Fri, 20 Mar 2026 05:49:36 +0800 Subject: [PATCH 06/10] update sync to use discover() instead --- src/schema/builder.rs | 450 +++++++----------------------------------- 1 file changed, 69 insertions(+), 381 deletions(-) diff --git a/src/schema/builder.rs b/src/schema/builder.rs index 83a383c514..f8a5a0e6a9 100644 --- a/src/schema/builder.rs +++ b/src/schema/builder.rs @@ -79,101 +79,29 @@ impl SchemaBuilder { where C: ConnectionTrait + sea_schema::Connection, { - let _existing = match db.get_database_backend() { - #[cfg(feature = "sqlx-mysql")] - DbBackend::MySql => { - use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe}; - - let current_schema: String = db - .query_one( - sea_query::SelectStatement::new() - .expr(sea_schema::mysql::MySql::get_current_schema()), - ) - .await? - .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? - .try_get_by_index(0)?; - let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); - - let schema = schema_discovery - .discover_with(db) - .await - .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; - - DiscoveredSchema { - tables: schema.tables.iter().map(|table| table.write()).collect(), - enums: vec![], - } - } - #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres => { - use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe}; - - let current_schema: String = db - .query_one( - sea_query::SelectStatement::new() - .expr(sea_schema::postgres::Postgres::get_current_schema()), - ) - .await? - .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? - .try_get_by_index(0)?; - let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); + let changes = self.discover(db, false).await?; + for stmt in changes { + db.execute_raw(stmt).await?; + } + Ok(()) + } - let schema = schema_discovery - .discover_with(db) - .await - .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; + /// This function fetches the changes needed to sync the database schema with this builder's entities. + /// * `db` - The database connection to use for fetching existing table schema. + /// * `allow_dangerous` - If `true`, changes will contain drop statements (drop tables, drop columns, drop constraints). + #[cfg(feature = "schema-sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] + pub async fn discover(&self, db: &C, allow_dangerous: bool) -> Result, DbErr> + where + C: ConnectionTrait + sea_schema::Connection, + { + let _existing = Self::discover_existing_schema(db).await?; - DiscoveredSchema { - tables: schema.tables.iter().map(|table| table.write()).collect(), - enums: schema.enums.iter().map(|def| def.write()).collect(), - } - } - #[cfg(feature = "sqlx-sqlite")] - DbBackend::Sqlite => { - use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; - let schema = SchemaDiscovery::discover_with(db) - .await - .map_err(|err| { - DbErr::Query(match err { - SqliteDiscoveryError::SqlxError(err) => { - crate::RuntimeErr::SqlxError(err.into()) - } - _ => crate::RuntimeErr::Internal(format!("{err:?}")), - }) - })? - .merge_indexes_into_table(); - DiscoveredSchema { - tables: schema.tables.iter().map(|table| table.write()).collect(), - enums: vec![], - } - } - #[cfg(feature = "rusqlite")] - DbBackend::Sqlite => { - use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; - let schema = SchemaDiscovery::discover_with(db) - .map_err(|err| { - DbErr::Query(match err { - SqliteDiscoveryError::RusqliteError(err) => { - crate::RuntimeErr::Rusqlite(err.into()) - } - _ => crate::RuntimeErr::Internal(format!("{err:?}")), - }) - })? - .merge_indexes_into_table(); - DiscoveredSchema { - tables: schema.tables.iter().map(|table| table.write()).collect(), - enums: vec![], - } - } - #[allow(unreachable_patterns)] - other => { - return Err(DbErr::BackendNotSupported { - db: other.as_str(), - ctx: "SchemaBuilder::sync", - }); - } - }; + #[allow(unreachable_code)] + let db_backend = db.get_database_backend(); + #[allow(unreachable_code)] + let mut changes: Vec = Vec::new(); #[allow(unreachable_code)] let mut created_enums: Vec = Default::default(); @@ -184,23 +112,51 @@ impl SchemaBuilder { .iter() .find(|entity| table_name == get_table_name(entity.table.get_table_name())) { - entity.sync(db, &_existing, &mut created_enums).await?; + entity.discover_changes( + db_backend, + &_existing, + &mut created_enums, + &mut changes, + allow_dangerous, + ); } } - Ok(()) + // When dangerous is allowed, detect tables in the database that are not in the entity set + // and generate drop statements for them. + #[allow(unreachable_code)] + if allow_dangerous { + for existing_table in &_existing.tables { + let existing_name = get_table_name(existing_table.get_table_name()); + let in_entities = self + .entities + .iter() + .any(|e| get_table_name(e.table.get_table_name()) == existing_name); + if !in_entities { + let stmt = db_backend.build( + sea_query::Table::drop() + .table( + existing_table + .get_table_name() + .expect("table must have a name") + .clone(), + ) + .if_exists(), + ); + changes.push(stmt); + } + } + } + + Ok(changes) } - /// This function fetches the changes needed to sync the database schema with this builder's entities. - /// * `db` - The database connection to use for fetching existing table schema. - /// * `dangerous` - If `true`, changes will contain drop statements. #[cfg(feature = "schema-sync")] - #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] - pub async fn discover(&self, db: &C, allow_dangerous: bool) -> Result, DbErr> + async fn discover_existing_schema(db: &C) -> Result where C: ConnectionTrait + sea_schema::Connection, { - let _existing = match db.get_database_backend() { + match db.get_database_backend() { #[cfg(feature = "sqlx-mysql")] DbBackend::MySql => { use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe}; @@ -220,10 +176,10 @@ impl SchemaBuilder { .await .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; - DiscoveredSchema { + Ok(DiscoveredSchema { tables: schema.tables.iter().map(|table| table.write()).collect(), enums: vec![], - } + }) } #[cfg(feature = "sqlx-postgres")] DbBackend::Postgres => { @@ -244,10 +200,10 @@ impl SchemaBuilder { .await .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; - DiscoveredSchema { + Ok(DiscoveredSchema { tables: schema.tables.iter().map(|table| table.write()).collect(), enums: schema.enums.iter().map(|def| def.write()).collect(), - } + }) } #[cfg(feature = "sqlx-sqlite")] DbBackend::Sqlite => { @@ -263,10 +219,10 @@ impl SchemaBuilder { }) })? .merge_indexes_into_table(); - DiscoveredSchema { + Ok(DiscoveredSchema { tables: schema.tables.iter().map(|table| table.write()).collect(), enums: vec![], - } + }) } #[cfg(feature = "rusqlite")] DbBackend::Sqlite => { @@ -281,72 +237,17 @@ impl SchemaBuilder { }) })? .merge_indexes_into_table(); - DiscoveredSchema { + Ok(DiscoveredSchema { tables: schema.tables.iter().map(|table| table.write()).collect(), enums: vec![], - } + }) } #[allow(unreachable_patterns)] - other => { - return Err(DbErr::BackendNotSupported { - db: other.as_str(), - ctx: "SchemaBuilder::discover", - }); - } - }; - - #[allow(unreachable_code)] - let db_backend = db.get_database_backend(); - - #[allow(unreachable_code)] - let mut changes: Vec = Vec::new(); - #[allow(unreachable_code)] - let mut created_enums: Vec = Default::default(); - - #[allow(unreachable_code)] - for table_name in self.sorted_tables() { - if let Some(entity) = self - .entities - .iter() - .find(|entity| table_name == get_table_name(entity.table.get_table_name())) - { - entity.discover_changes( - db_backend, - &_existing, - &mut created_enums, - &mut changes, - allow_dangerous, - ); - } - } - - // When dangerous is allowed, detect tables in the database that are not in the entity set - // and generate drop statements for them. - #[allow(unreachable_code)] - if allow_dangerous { - for existing_table in &_existing.tables { - let existing_name = get_table_name(existing_table.get_table_name()); - let in_entities = self - .entities - .iter() - .any(|e| get_table_name(e.table.get_table_name()) == existing_name); - if !in_entities { - let stmt = db_backend.build( - sea_query::Table::drop() - .table( - existing_table - .get_table_name() - .expect("table must have a name") - .clone(), - ) - .if_exists(), - ); - changes.push(stmt); - } - } + other => Err(DbErr::BackendNotSupported { + db: other.as_str(), + ctx: "SchemaBuilder::discover_existing_schema", + }), } - - Ok(changes) } /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys. @@ -403,6 +304,7 @@ impl SchemaBuilder { } } +/// Stores the discovered schema from the database, including tables and enums struct DiscoveredSchema { tables: Vec, enums: Vec, @@ -438,220 +340,6 @@ impl EntitySchemaInfo { } // better to always compile this function - #[allow(dead_code)] - async fn sync( - &self, - db: &C, - existing: &DiscoveredSchema, - created_enums: &mut Vec, - ) -> Result<(), DbErr> { - let db_backend = db.get_database_backend(); - - // create enum before creating table - for stmt in self.enums.iter() { - let mut has_enum = false; - let new_stmt = db_backend.build(stmt); - for existing_enum in &existing.enums { - if db_backend.build(existing_enum) == new_stmt { - has_enum = true; - // TODO add enum variants - break; - } - } - if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) { - db.execute(stmt).await?; - created_enums.push(new_stmt); - } - } - let table_name = get_table_name(self.table.get_table_name()); - let mut existing_table = None; - for tbl in &existing.tables { - if get_table_name(tbl.get_table_name()) == table_name { - existing_table = Some(tbl); - break; - } - } - if let Some(existing_table) = existing_table { - for column_def in self.table.get_columns() { - let mut column_exists = false; - for existing_column in existing_table.get_columns() { - if column_def.get_column_name() == existing_column.get_column_name() { - column_exists = true; - break; - } - } - if !column_exists { - let mut renamed_from = ""; - if let Some(comment) = &column_def.get_column_spec().comment { - if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") { - if let Some((prefix, _)) = suffix.split_once('"') { - renamed_from = prefix; - } - } - } - if renamed_from.is_empty() { - db.execute( - TableAlterStatement::new() - .table(self.table.get_table_name().expect("Checked above").clone()) - .add_column(column_def.to_owned()), - ) - .await?; - } else { - db.execute( - TableAlterStatement::new() - .table(self.table.get_table_name().expect("Checked above").clone()) - .rename_column( - renamed_from.to_owned(), - column_def.get_column_name(), - ), - ) - .await?; - } - } - } - if db.get_database_backend() != DbBackend::Sqlite { - for foreign_key in self.table.get_foreign_key_create_stmts().iter() { - let mut key_exists = false; - for existing_key in existing_table.get_foreign_key_create_stmts().iter() { - if compare_foreign_key(foreign_key, existing_key) { - key_exists = true; - break; - } - } - if !key_exists { - db.execute(foreign_key).await?; - } - } - } - } else { - db.execute(&self.table).await?; - } - for stmt in self.indexes.iter() { - let mut has_index = false; - if let Some(existing_table) = existing_table { - for existing_index in existing_table.get_indexes() { - if existing_index.get_index_spec().get_column_names() - == stmt.get_index_spec().get_column_names() - { - has_index = true; - break; - } - } - } - if !has_index { - // shall we do alter table add constraint for unique index? - let mut stmt = stmt.clone(); - stmt.if_not_exists(); - db.execute(&stmt).await?; - } - } - if let Some(existing_table) = existing_table { - // For columns with a column-level UNIQUE constraint (#[sea_orm(unique)]) that - // already exist in the table but do not yet have a unique index, create one. - for column_def in self.table.get_columns() { - if column_def.get_column_spec().unique { - let col_name = column_def.get_column_name(); - let col_exists = existing_table - .get_columns() - .iter() - .any(|c| c.get_column_name() == col_name); - if !col_exists { - // Column is being added in this sync pass; the ALTER TABLE ADD COLUMN - // will include the UNIQUE inline, so no separate index needed. - continue; - } - let already_unique = existing_table.get_indexes().iter().any(|idx| { - if !idx.is_unique_key() { - return false; - } - let cols = idx.get_index_spec().get_column_names(); - cols.len() == 1 && cols[0] == col_name - }); - if !already_unique { - let table_name = - self.table.get_table_name().expect("table must have a name"); - let tbl_str = table_name.sea_orm_table().to_string(); - let table_ref = table_name.clone(); - db.execute( - Index::create() - .name(format!("idx-{tbl_str}-{col_name}")) - .table(table_ref) - .col(col_name.into_iden()) - .unique() - .if_not_exists(), - ) - .await?; - } - } - } - } - if let Some(existing_table) = existing_table { - // find all unique keys from existing table - // if it no longer exist in new schema, drop it - for existing_index in existing_table.get_indexes() { - if existing_index.is_unique_key() { - let mut has_index = false; - for stmt in self.indexes.iter() { - if existing_index.get_index_spec().get_column_names() - == stmt.get_index_spec().get_column_names() - { - has_index = true; - break; - } - } - // Also check if the unique index corresponds to a column-level UNIQUE - // constraint (from #[sea_orm(unique)]). These are embedded in the CREATE - // TABLE column definition and not tracked in self.indexes, so we must not - // try to drop them during sync. - if !has_index { - let index_cols = existing_index.get_index_spec().get_column_names(); - if index_cols.len() == 1 { - for column_def in self.table.get_columns() { - if column_def.get_column_name() == index_cols[0] - && column_def.get_column_spec().unique - { - has_index = true; - break; - } - } - } - } - if !has_index { - if let Some(drop_existing) = existing_index - .get_index_spec() - .get_name() - .map(|s| s.to_owned()) - { - if db_backend == DbBackend::Postgres { - // On PostgreSQL, unique indexes created via column-level UNIQUE - // (e.g. ADD COLUMN ... UNIQUE) are backed by a named constraint. - // DROP INDEX fails on constraint-owned indexes; use - // ALTER TABLE ... DROP CONSTRAINT instead. - db.execute( - TableAlterStatement::new() - .table( - self.table - .get_table_name() - .expect("Checked above") - .clone(), - ) - .drop_constraint(drop_existing), - ) - .await?; - } else { - db.execute(sea_query::Index::drop().name(drop_existing)) - .await?; - } - } - } - } - } - } - Ok(()) - } - - // better to always compile this function - #[allow(dead_code)] fn discover_changes( &self, db_backend: DbBackend, From ac82ef15b342ea8ce65a55814fbad721da6ab1c4 Mon Sep 17 00:00:00 2001 From: sinder Date: Fri, 20 Mar 2026 05:49:51 +0800 Subject: [PATCH 07/10] add .zed to gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7ec785cd4a..237c030705 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,7 @@ Cargo.lock .vscode .idea/* */.idea/* +.zed/* +*/.zed/* .env.local -.DS_Store \ No newline at end of file +.DS_Store From 1813316feb60751cca2eda175be5a187ffed933a Mon Sep 17 00:00:00 2001 From: sinder Date: Fri, 20 Mar 2026 07:04:05 +0800 Subject: [PATCH 08/10] Add first draft implementation of entity-first migrator --- sea-orm-entity/Cargo.toml | 43 +++++ sea-orm-entity/src/cli.rs | 306 ++++++++++++++++++++++++++++++++++ sea-orm-entity/src/codegen.rs | 66 ++++++++ sea-orm-entity/src/filter.rs | 104 ++++++++++++ sea-orm-entity/src/fs.rs | 98 +++++++++++ sea-orm-entity/src/lib.rs | 26 +++ sea-orm-entity/src/summary.rs | 114 +++++++++++++ 7 files changed, 757 insertions(+) create mode 100644 sea-orm-entity/Cargo.toml create mode 100644 sea-orm-entity/src/cli.rs create mode 100644 sea-orm-entity/src/codegen.rs create mode 100644 sea-orm-entity/src/filter.rs create mode 100644 sea-orm-entity/src/fs.rs create mode 100644 sea-orm-entity/src/lib.rs create mode 100644 sea-orm-entity/src/summary.rs diff --git a/sea-orm-entity/Cargo.toml b/sea-orm-entity/Cargo.toml new file mode 100644 index 0000000000..5f5aaefe3f --- /dev/null +++ b/sea-orm-entity/Cargo.toml @@ -0,0 +1,43 @@ +[workspace] +# A separate workspace + +[package] +name = "sea-orm-entity" +version = "0.1.0" +edition = "2024" +description = "Entity-first migration generator for SeaORM" +license = "MIT OR Apache-2.0" +repository = "https://github.com/SeaQL/sea-orm" +rust-version = "1.85.0" + +[dependencies] +sea-orm = { path = "..", version = "~2.0.0-rc.37", features = ["schema-sync"] } +sea-orm-migration = { path = "../sea-orm-migration", version = "~2.0.0-rc.37", default-features = false } +sea-schema = { version = "0.17.0-rc", default-features = false, features = [ + "discovery", + "writer", + "probe", +] } +async-trait = { version = "0.1", default-features = false } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +regex = { version = "1" } +clap = { version = "4.3", features = ["env", "derive"] } +dotenvy = { version = "0.15", default-features = false } +tokio = { version = "1", features = ["rt-multi-thread", "macros"], optional = true } +tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt"] } + +[features] +default = ["runtime-tokio-native-tls", "sqlx-sqlite"] +runtime-tokio = ["tokio", "sea-orm/runtime-tokio", "sea-schema/runtime-tokio", "sea-orm-migration/runtime-tokio"] +runtime-tokio-native-tls = ["tokio", "sea-orm/runtime-tokio-native-tls", "sea-schema/runtime-tokio-native-tls", "sea-orm-migration/runtime-tokio-native-tls"] +runtime-tokio-rustls = ["tokio", "sea-orm/runtime-tokio-rustls", "sea-schema/runtime-tokio-rustls", "sea-orm-migration/runtime-tokio-rustls"] +runtime-async-std = ["sea-orm/runtime-async-std", "sea-schema/runtime-async-std", "sea-orm-migration/runtime-async-std"] +runtime-async-std-native-tls = ["sea-orm/runtime-async-std-native-tls", "sea-schema/runtime-async-std-native-tls", "sea-orm-migration/runtime-async-std-native-tls"] +runtime-async-std-rustls = ["sea-orm/runtime-async-std-rustls", "sea-schema/runtime-async-std-rustls", "sea-orm-migration/runtime-async-std-rustls"] +sqlx-mysql = ["sea-orm/sqlx-mysql", "sea-schema/sqlx-mysql", "sea-orm-migration/sqlx-mysql"] +sqlx-postgres = ["sea-orm/sqlx-postgres", "sea-schema/sqlx-postgres", "sea-orm-migration/sqlx-postgres"] +sqlx-sqlite = ["sea-orm/sqlx-sqlite", "sea-schema/sqlx-sqlite", "sea-orm-migration/sqlx-sqlite"] + +[dev-dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tempfile = "3" diff --git a/sea-orm-entity/src/cli.rs b/sea-orm-entity/src/cli.rs new file mode 100644 index 0000000000..41782f0bfb --- /dev/null +++ b/sea-orm-entity/src/cli.rs @@ -0,0 +1,306 @@ +use chrono::Utc; +use clap::{Parser, Subcommand}; +use dotenvy::dotenv; +use sea_orm::{ConnectOptions, Database, DbBackend, Schema}; +use sea_orm_migration::MigratorTraitSelf; +use tracing_subscriber::{EnvFilter, prelude::*}; + +use crate::filter::filter_protected_drops; +use crate::{EntitySet, codegen::MigrationMetadata, fs::write_migration, summary::summarize}; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +//TODO: Move this to cli later +#[derive(Parser)] +#[command( + name = "entity", + about = "Entity-first migration tool for SeaORM", + version +)] +struct Cli { + #[arg(short = 'v', long, global = true, help = "Show debug messages")] + verbose: bool, + + #[arg( + global = true, + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL" + )] + database_url: Option, + + #[arg( + global = true, + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL and SQLite, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'.\n" + )] + database_schema: Option, + + #[command(subcommand)] + command: Option, +} + +//TODO: Move this to cli later +#[derive(Subcommand)] +enum Commands { + /// Generate a migration file from entity definitions by diffing against the database. + /// + /// By default uses the live database provided via --database-url or env. + /// In the future, an --ephemeral flag will instead spin up an in-memory database, apply existing + /// migrations, and discover changes without a live connection. + Generate { + /// Path to the migration crate directory + #[arg(long, default_value = "../migration")] + migration_dir: String, + + /// Name for the migration (like `add_users`) + #[arg(required = true, help = "Name of the new migration")] + name: String, + + #[arg( + long, + default_value = "true", + help = "Generate migration file based on Utc time", + conflicts_with = "local_time", + display_order = 1001 + )] + universal_time: bool, + + #[arg( + long, + help = "Generate migration file based on Local time", + conflicts_with = "universal_time", + display_order = 1002 + )] + local_time: bool, + + /// Allow dangerous operations (e.g. dropping tables) + #[arg(long, default_value_t = true)] + allow_dangerous: bool, + // Future: --ephemeral flag will be added here for no-live-db gen + }, + #[command( + about = "Drop all tables from the database, then reapply all migrations", + display_order = 30 + )] + Fresh, + #[command( + about = "Rollback all applied migrations, then reapply all migrations", + display_order = 40 + )] + Refresh, + #[command(about = "Rollback all applied migrations", display_order = 50)] + Reset, + #[command(about = "Check the status of all migrations", display_order = 60)] + Status, + #[command(about = "Apply pending migrations", display_order = 70)] + Up { + #[arg(short, long, help = "Number of pending migrations to apply")] + num: Option, + }, + #[command(about = "Rollback applied migrations", display_order = 80)] + Down { + #[arg( + short, + long, + default_value = "1", + help = "Number of applied migrations to be rolled back", + display_order = 90 + )] + num: u32, + }, +} + +/// Run the entity CLI with the given entity set and migrator +/// +/// Call this from your entity crate's `main.rs`: +/// +/// ```rust,ignore +/// #[tokio::main] +/// async fn main() { +/// sea_orm_entity::cli::run_cli(Entities, migration::Migrator).await; +/// } +/// ``` +pub async fn run_cli(entity_set: E, migrator: M) +where + E: EntitySet, + M: MigratorTraitSelf, +{ + dotenv().ok(); + let cli = Cli::parse(); + + let url = cli + .database_url + .expect("Environment variable 'DATABASE_URL' not set"); + let schema = cli.database_schema.unwrap_or_else(|| "public".to_owned()); + let verbose = cli.verbose; + + match cli.command { + Some(Commands::Generate { + migration_dir, + name, + local_time, + universal_time: _, + allow_dangerous, + }) => { + // Extract the migration tracker table name so we never generate DROP statements for it. + let migration_table = migrator.migration_table_name().to_string(); + + println!("Connecting to database..."); + let db = Database::connect( + ConnectOptions::new(url) + .set_schema_search_path(schema) + .to_owned(), + ) + .await + .expect("Failed to connect to database"); + + // Future: when --ephemeral is added, build an in-memory db here instead, + // apply existing migrations via the migrator, then pass it to generate_from_db. + + if let Err(e) = generate_from_db( + entity_set, + db, + &migration_dir, + &name, + local_time, + allow_dangerous, + &migration_table, + ) + .await + { + eprintln!("Error: {e}"); + std::process::exit(1); + } + } + + migrate_cmd => { + // All migration execution logic lives in MigratorTraitSelf (sea-orm-migration) + init_tracing(verbose); + let db = Database::connect( + ConnectOptions::new(url) + .set_schema_search_path(schema) + .to_owned(), + ) + .await + .expect("Failed to connect to database"); + + let result = match migrate_cmd { + Some(Commands::Up { num }) => migrator.up(&db, num).await, + Some(Commands::Down { num }) => migrator.down(&db, Some(num)).await, + Some(Commands::Fresh) => migrator.fresh(&db).await, + Some(Commands::Refresh) => migrator.refresh(&db).await, + Some(Commands::Reset) => migrator.reset(&db).await, + Some(Commands::Status) => migrator.status(&db).await, + // No subcommand: apply all pending migrations + None => migrator.up(&db, None).await, + Some(Commands::Generate { .. }) => unreachable!(), + }; + + if let Err(e) = result { + eprintln!("Error: {e}"); + std::process::exit(1); + } + } + } +} + +/// Core generation logic +async fn generate_from_db( + entity_set: E, + db: sea_orm::DatabaseConnection, + migration_dir: &str, + name: &str, + local_time: bool, + dangerous: bool, + protected_table: &str, +) -> Result<(), Box> { + if name.contains('-') { + return Err("`-` cannot be used in migration name".into()); + } + + let backend = db.get_database_backend(); + + let schema = Schema::new(backend); + let builder = entity_set.register(schema.builder()); + + println!("Discovering schema changes..."); + let raw = builder.discover(&db, dangerous).await?; + + // Never drop migration tracker table + let stmts = filter_protected_drops(raw, protected_table); + + if stmts.is_empty() { + println!("No schema changes detected. Migration file not generated"); + return Ok(()); + } + + let (timestamp, generated_at) = if local_time { + let now = chrono::Local::now(); + ( + now.format("%Y%m%d_%H%M%S").to_string(), + now.format("%Y-%m-%d %H:%M:%S %Z").to_string(), + ) + } else { + let now = Utc::now(); + ( + now.format("%Y%m%d_%H%M%S").to_string(), + now.format("%Y-%m-%d %H:%M:%S UTC").to_string(), + ) + }; + let name_clean = name.trim().replace(' ', "_"); + let migration_name = format!("m{timestamp}_{name_clean}"); + let backend_name = match backend { + DbBackend::MySql => "MySQL", + DbBackend::Postgres => "PostgreSQL", + DbBackend::Sqlite => "SQLite", + _ => "Unknown", + }; + let changes = summarize(&stmts); + let meta = MigrationMetadata { + version: VERSION, + generated_at: &generated_at, + backend: backend_name, + changes: &changes, + }; + + let filepath = write_migration(migration_dir, &migration_name, &stmts, &meta)?; + println!("Generated migration: {}", filepath.display()); + println!("Changes ({}):", changes.len()); + for change in &changes { + println!(" - {change}"); + } + + Ok(()) +} + +fn init_tracing(verbose: bool) { + let filter = if verbose { + "debug" + } else { + "sea_orm_migration=info" + }; + let filter_layer = EnvFilter::try_new(filter).unwrap(); + let fmt_layer = tracing_subscriber::fmt::layer(); + if verbose { + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .init(); + } else { + tracing_subscriber::registry() + .with(filter_layer) + .with( + fmt_layer + .with_target(false) + .with_level(false) + .without_time(), + ) + .init(); + } +} diff --git a/sea-orm-entity/src/codegen.rs b/sea-orm-entity/src/codegen.rs new file mode 100644 index 0000000000..c30c50bdb5 --- /dev/null +++ b/sea-orm-entity/src/codegen.rs @@ -0,0 +1,66 @@ +use sea_orm::Statement; + +pub struct MigrationMetadata<'a> { + pub version: &'a str, + pub generated_at: &'a str, + pub backend: &'a str, + pub changes: &'a [String], +} + +/// Render a complete migration `.rs` file from a list of SQL statements. +pub fn render_migration_file(stmts: &[Statement], meta: &MigrationMetadata<'_>) -> String { + let comment_header = render_comment_header(meta); + let up_body = render_up_body(stmts); + + format!( + r#"{comment_header} +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration {{ + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {{ + let db = manager.get_connection(); +{up_body} + Ok(()) + }} + + async fn down(&self, _manager: &SchemaManager) -> Result<(), DbErr> {{ + // TODO: implement down migration + todo!() + }} +}} +"# + ) +} + +/// Comments the header of the migration file, including version, timestamp, backend, and changes. +fn render_comment_header(meta: &MigrationMetadata<'_>) -> String { + let mut lines = vec![ + format!("// Generated by sea-orm-entity v{}", meta.version), + format!("// Generated at: {}", meta.generated_at), + format!("// Backend: {}", meta.backend), + "// Changes:".to_string(), + ]; + lines.extend(meta.changes.iter().map(|c| format!("// - {c}"))); + lines.join("\n") +} + +/// Renders the body of the up migration +fn render_up_body(stmts: &[Statement]) -> String { + // Just in case. I don't want to deal with this rn + stmts + .iter() + .map(|stmt| { + if stmt.values.is_some() { + let sql = stmt.sql.replace('\\', r"\\").replace('"', r#"\""#); + format!(r#" db.execute_unprepared("{sql}").await?;"#) + } else { + format!(" db.execute_unprepared(r#\"{}\"#).await?;", stmt.sql) + } + }) + .collect::>() + .join("\n") +} diff --git a/sea-orm-entity/src/filter.rs b/sea-orm-entity/src/filter.rs new file mode 100644 index 0000000000..fe3f2b8f0e --- /dev/null +++ b/sea-orm-entity/src/filter.rs @@ -0,0 +1,104 @@ +use sea_orm::Statement; + +/// Filter out DROP TABLE statements for `protected_table` from a list of discovered statements +pub fn filter_protected_drops(stmts: Vec, protected_table: &str) -> Vec { + let protected_upper = protected_table.to_uppercase(); + stmts + .into_iter() + .filter(|stmt| { + let upper = stmt.sql.to_uppercase(); + if upper.contains("DROP TABLE") { + !is_drop_of(upper.as_str(), &protected_upper) + } else { + true + } + }) + .collect() +} + +/// Returns true if `sql_upper` is a DROP TABLE statement targeting `table_upper`. +/// Handles all three quoting styles (double-quote, backtick, unquoted) and the +/// optional `IF EXISTS` clause so that no backend-specific variant slips through. +fn is_drop_of(sql_upper: &str, table_upper: &str) -> bool { + sql_upper.contains(&format!("\"{}\"", table_upper)) + || sql_upper.contains(&format!("`{}`", table_upper)) + || sql_upper.contains(&format!(" {} ", table_upper)) + || sql_upper.ends_with(&format!(" {}", table_upper)) +} + +#[cfg(test)] +mod tests { + use super::*; + use sea_orm::{DbBackend, Statement}; + + fn stmt(sql: &str) -> Statement { + Statement::from_string(DbBackend::Sqlite, sql.to_owned()) + } + + #[test] + fn test_removes_double_quoted_drop() { + let stmts = vec![ + stmt(r#"DROP TABLE IF EXISTS "seaql_migrations""#), + stmt(r#"DROP TABLE IF EXISTS "fruit""#), + ]; + let filtered = filter_protected_drops(stmts, "seaql_migrations"); + assert_eq!(filtered.len(), 1); + assert!(filtered[0].sql.contains("fruit")); + } + + #[test] + fn test_removes_backtick_quoted_drop() { + let stmts = vec![ + stmt("DROP TABLE IF EXISTS `seaql_migrations`"), + stmt("DROP TABLE IF EXISTS `cake`"), + ]; + let filtered = filter_protected_drops(stmts, "seaql_migrations"); + assert_eq!(filtered.len(), 1); + assert!(filtered[0].sql.contains("cake")); + } + + #[test] + fn test_removes_unquoted_drop() { + let stmts = vec![ + stmt("DROP TABLE IF EXISTS seaql_migrations"), + stmt("DROP TABLE IF EXISTS cake"), + ]; + let filtered = filter_protected_drops(stmts, "seaql_migrations"); + assert_eq!(filtered.len(), 1); + assert!(filtered[0].sql.contains("cake")); + } + + #[test] + fn test_does_not_remove_partial_name_match() { + let stmts = vec![stmt(r#"DROP TABLE IF EXISTS "seaql_migrations_old""#)]; + let filtered = filter_protected_drops(stmts, "seaql_migrations"); + assert_eq!(filtered.len(), 1, "partial name match must not be filtered"); + } + + #[test] + fn test_non_drop_stmts_pass_through() { + let stmts = vec![ + stmt(r#"CREATE TABLE "cake" ( "id" integer NOT NULL )"#), + stmt(r#"ALTER TABLE "fruit" ADD COLUMN "weight" integer"#), + ]; + let filtered = filter_protected_drops(stmts, "seaql_migrations"); + assert_eq!(filtered.len(), 2); + } + + #[test] + fn test_custom_migration_table_name() { + let stmts = vec![ + stmt(r#"DROP TABLE IF EXISTS "my_migrations""#), + stmt(r#"DROP TABLE IF EXISTS "cake""#), + ]; + let filtered = filter_protected_drops(stmts, "my_migrations"); + assert_eq!(filtered.len(), 1); + assert!(filtered[0].sql.contains("cake")); + } + + #[test] + fn test_empty_input() { + let filtered = filter_protected_drops(vec![], "seaql_migrations"); + assert!(filtered.is_empty()); + } +} diff --git a/sea-orm-entity/src/fs.rs b/sea-orm-entity/src/fs.rs new file mode 100644 index 0000000000..1137ba9773 --- /dev/null +++ b/sea-orm-entity/src/fs.rs @@ -0,0 +1,98 @@ +use regex::Regex; +use sea_orm::Statement; +use std::{error::Error, fs, path::PathBuf}; + +use crate::codegen::{MigrationMetadata, render_migration_file}; + +pub fn write_migration( + migration_dir: &str, + migration_name: &str, + stmts: &[Statement], + meta: &MigrationMetadata<'_>, +) -> Result> { + let filepath = write_migration_file(migration_dir, migration_name, stmts, meta)?; + update_migrator(migration_dir, migration_name)?; + Ok(filepath) +} + +fn get_full_migration_dir(migration_dir: &str) -> PathBuf { + let base = PathBuf::from(migration_dir); + let with_src = base.join("src"); + if with_src.is_dir() { with_src } else { base } +} + +fn get_migrator_filepath(migration_dir: &str) -> PathBuf { + let full_dir = get_full_migration_dir(migration_dir); + let with_lib = full_dir.join("lib.rs"); + if with_lib.is_file() { + with_lib + } else { + full_dir.join("mod.rs") + } +} + +fn write_migration_file( + migration_dir: &str, + migration_name: &str, + stmts: &[Statement], + meta: &MigrationMetadata<'_>, +) -> Result> { + let filepath = get_full_migration_dir(migration_dir).join(format!("{migration_name}.rs")); + println!("Creating migration file `{}`", filepath.display()); + let content = render_migration_file(stmts, meta); + fs::write(&filepath, content.as_bytes())?; + Ok(filepath) +} + +fn update_migrator(migration_dir: &str, migration_name: &str) -> Result<(), Box> { + let migrator_filepath = get_migrator_filepath(migration_dir); + println!( + "Adding migration `{migration_name}` to `{}`", + migrator_filepath.display() + ); + let original = fs::read_to_string(&migrator_filepath)?; + + // Find existing mod declarations and get insertion index for a new one + let mod_regex = Regex::new(r"mod\s+(?Pm\d{8}_\d{6}_\w+);")?; + let mods: Vec<_> = mod_regex.captures_iter(&original).collect(); + let insert_pos = if let Some(last_match) = mods.last() { + last_match.get(0).unwrap().end() + 1 + } else { + // Insert at the beginning of the file (before `pub struct Migrator`) + original.find("pub struct").unwrap_or(original.len()) + }; + + // Insert the new mod declaration. + let new_mod_decl = if mods.is_empty() { + //When inserting before the struct, add a blank line to look nicer + format!("mod {migration_name};\n\n") + } else { + format!("mod {migration_name};\n") + }; + let mut updated = original.clone(); + updated.insert_str(insert_pos, &new_mod_decl); + + // Rebuild the migrations vec + let mut migrations: Vec<&str> = mods + .iter() + .map(|cap| cap.name("name").unwrap().as_str()) + .collect(); + migrations.push(migration_name); + let boxed = migrations + .iter() + .map(|m| format!(" Box::new({m}::Migration),")) + .collect::>() + .join("\n") + + "\n"; + let new_vec = format!("vec![\n{boxed} ]"); + + // Match both empty vec![] and vec![...] + let vec_regex = Regex::new(r"vec!\[[\s\S]*?\]")?; + let updated = vec_regex.replace(&updated, new_vec.as_str()); + + // write to a temp file beside the target, then rename + let tmp_path = migrator_filepath.with_extension("rs.tmp"); + fs::write(&tmp_path, updated.as_bytes())?; + fs::rename(&tmp_path, &migrator_filepath)?; + Ok(()) +} diff --git a/sea-orm-entity/src/lib.rs b/sea-orm-entity/src/lib.rs new file mode 100644 index 0000000000..04a42710c6 --- /dev/null +++ b/sea-orm-entity/src/lib.rs @@ -0,0 +1,26 @@ +pub mod cli; +pub mod codegen; +pub mod filter; +pub mod fs; +pub mod summary; + +pub use sea_orm::schema::SchemaBuilder; + +/// Trait for a set of entities to be registered into a [`SchemaBuilder`]. +/// +/// Implement this on a unit struct in your entity crate: +/// +/// ```rust,ignore +/// pub struct Entities; +/// +/// impl sea_orm_entity::EntitySet for Entities { +/// fn register(self, builder: sea_orm_entity::SchemaBuilder) -> sea_orm_entity::SchemaBuilder { +/// builder +/// .register(user::Entity) +/// .register(post::Entity) +/// } +/// } +/// ``` +pub trait EntitySet { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder; +} diff --git a/sea-orm-entity/src/summary.rs b/sea-orm-entity/src/summary.rs new file mode 100644 index 0000000000..67413edeb2 --- /dev/null +++ b/sea-orm-entity/src/summary.rs @@ -0,0 +1,114 @@ +use sea_orm::Statement; + +/// Parse a list of SQL statements into hman-readable descriptions +pub fn summarize(stmts: &[Statement]) -> Vec { + stmts.iter().map(|s| describe(&s.sql)).collect() +} + +fn describe(sql: &str) -> String { + let upper = sql.to_uppercase(); + let sql = sql.trim(); + + if upper.contains("CREATE TABLE") { + if let Some(name) = extract_after(sql, &upper, "CREATE TABLE", Some("IF NOT EXISTS")) { + return format!("Created table: {name}"); + } + } + + if upper.contains("ALTER TABLE") { + let table = extract_after(sql, &upper, "ALTER TABLE", None); + if upper.contains("ADD COLUMN") { + if let (Some(table), Some(col)) = ( + table.as_ref(), + extract_after(sql, &upper, "ADD COLUMN", None), + ) { + return format!("Added column: {table}.{col}"); + } + } else if upper.contains("DROP COLUMN") { + if let (Some(table), Some(col)) = ( + table.as_ref(), + extract_after(sql, &upper, "DROP COLUMN", None), + ) { + return format!("Dropped column: {table}.{col}"); + } + } else if upper.contains("ADD CONSTRAINT") { + if let Some(table) = table { + return format!("Added foreign key on: {table}"); + } + } else if upper.contains("DROP CONSTRAINT") { + if let Some(table) = table { + return format!("Dropped constraint on: {table}"); + } + } else if upper.contains("DROP FOREIGN KEY") { + if let Some(table) = table { + return format!("Dropped foreign key on: {table}"); + } + } + } + + if upper.contains("DROP TABLE") { + if let Some(name) = extract_after(sql, &upper, "DROP TABLE", Some("IF EXISTS")) { + return format!("Dropped table: {name}"); + } + } + + if upper.contains("CREATE INDEX") || upper.contains("CREATE UNIQUE INDEX") { + if let Some(pos) = upper.find(" ON ") { + let after = sql[pos + " ON ".len()..].trim_start(); + let table = extract_identifier(after); + let kind = if upper.contains("UNIQUE") { + "unique index" + } else { + "index" + }; + return format!("Added {kind} on: {table}"); + } + } + + if upper.contains("CREATE TYPE") { + return "Created enum type".to_string(); + } + + // Fallback: first 80 chars of SQL + if sql.len() > 80 { + format!("SQL: {}...", &sql[..80]) + } else { + format!("SQL: {sql}") + } +} + +fn extract_after(sql: &str, upper: &str, keyword: &str, skip: Option<&str>) -> Option { + let pos = upper.find(keyword)?; + let rest = sql[pos + keyword.len()..].trim_start(); + let rest_upper = &upper[pos + keyword.len()..]; + let rest_upper = rest_upper.trim_start(); + let rest = if let Some(skip) = skip { + if rest_upper.starts_with(skip) { + rest[skip.len()..].trim_start() + } else { + rest + } + } else { + rest + }; + Some(extract_identifier(rest)) +} + +fn extract_identifier(s: &str) -> String { + let s = s.trim(); + if s.starts_with('"') { + // Double-quoted identifier + let end = s[1..].find('"').unwrap_or(s.len() - 1); + s[1..end + 1].to_string() + } else if s.starts_with('`') { + // Backtick-quoted identifier (MySQL) + let end = s[1..].find('`').unwrap_or(s.len() - 1); + s[1..end + 1].to_string() + } else { + // Unquoted: take until whitespace or `(` + s.split(|c: char| c.is_whitespace() || c == '(') + .next() + .unwrap_or(s) + .to_string() + } +} From f905ad249a32a79b7e0def6add1a16047fd687c8 Mon Sep 17 00:00:00 2001 From: sinder Date: Fri, 20 Mar 2026 07:59:34 +0800 Subject: [PATCH 09/10] Appease Clippy and Taplo --- .github/workflows/rust.yml | 2 +- sea-orm-entity/Cargo.toml | 91 ++++++++++++++++++++++++++++---------- src/query/helper.rs | 2 +- src/schema/builder.rs | 26 +++++------ 4 files changed, 83 insertions(+), 38 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 297beddb1b..6401118e55 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -141,7 +141,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - uses: mozilla-actions/sccache-action@v0.0.9 - run: cargo install --locked taplo-cli - - run: taplo fmt --check + - run: taplo fmt --check compile: name: Compile (${{ matrix.label }}) diff --git a/sea-orm-entity/Cargo.toml b/sea-orm-entity/Cargo.toml index 5f5aaefe3f..71fe1c4c5d 100644 --- a/sea-orm-entity/Cargo.toml +++ b/sea-orm-entity/Cargo.toml @@ -2,15 +2,20 @@ # A separate workspace [package] -name = "sea-orm-entity" -version = "0.1.0" -edition = "2024" -description = "Entity-first migration generator for SeaORM" -license = "MIT OR Apache-2.0" -repository = "https://github.com/SeaQL/sea-orm" +description = "Entity-first migration generator for SeaORM" +edition = "2024" +license = "MIT OR Apache-2.0" +name = "sea-orm-entity" +repository = "https://github.com/SeaQL/sea-orm" rust-version = "1.85.0" +version = "0.1.0" [dependencies] +async-trait = { version = "0.1", default-features = false } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +clap = { version = "4.3", features = ["env", "derive"] } +dotenvy = { version = "0.15", default-features = false } +regex = { version = "1" } sea-orm = { path = "..", version = "~2.0.0-rc.37", features = ["schema-sync"] } sea-orm-migration = { path = "../sea-orm-migration", version = "~2.0.0-rc.37", default-features = false } sea-schema = { version = "0.17.0-rc", default-features = false, features = [ @@ -18,26 +23,66 @@ sea-schema = { version = "0.17.0-rc", default-features = false, features = [ "writer", "probe", ] } -async-trait = { version = "0.1", default-features = false } -chrono = { version = "0.4", default-features = false, features = ["clock"] } -regex = { version = "1" } -clap = { version = "4.3", features = ["env", "derive"] } -dotenvy = { version = "0.15", default-features = false } -tokio = { version = "1", features = ["rt-multi-thread", "macros"], optional = true } -tracing-subscriber = { version = "0.3", default-features = false, features = ["env-filter", "fmt"] } +tokio = { version = "1", features = [ + "rt-multi-thread", + "macros", +], optional = true } +tracing-subscriber = { version = "0.3", default-features = false, features = [ + "env-filter", + "fmt", +] } [features] default = ["runtime-tokio-native-tls", "sqlx-sqlite"] -runtime-tokio = ["tokio", "sea-orm/runtime-tokio", "sea-schema/runtime-tokio", "sea-orm-migration/runtime-tokio"] -runtime-tokio-native-tls = ["tokio", "sea-orm/runtime-tokio-native-tls", "sea-schema/runtime-tokio-native-tls", "sea-orm-migration/runtime-tokio-native-tls"] -runtime-tokio-rustls = ["tokio", "sea-orm/runtime-tokio-rustls", "sea-schema/runtime-tokio-rustls", "sea-orm-migration/runtime-tokio-rustls"] -runtime-async-std = ["sea-orm/runtime-async-std", "sea-schema/runtime-async-std", "sea-orm-migration/runtime-async-std"] -runtime-async-std-native-tls = ["sea-orm/runtime-async-std-native-tls", "sea-schema/runtime-async-std-native-tls", "sea-orm-migration/runtime-async-std-native-tls"] -runtime-async-std-rustls = ["sea-orm/runtime-async-std-rustls", "sea-schema/runtime-async-std-rustls", "sea-orm-migration/runtime-async-std-rustls"] -sqlx-mysql = ["sea-orm/sqlx-mysql", "sea-schema/sqlx-mysql", "sea-orm-migration/sqlx-mysql"] -sqlx-postgres = ["sea-orm/sqlx-postgres", "sea-schema/sqlx-postgres", "sea-orm-migration/sqlx-postgres"] -sqlx-sqlite = ["sea-orm/sqlx-sqlite", "sea-schema/sqlx-sqlite", "sea-orm-migration/sqlx-sqlite"] +runtime-async-std = [ + "sea-orm/runtime-async-std", + "sea-schema/runtime-async-std", + "sea-orm-migration/runtime-async-std", +] +runtime-async-std-native-tls = [ + "sea-orm/runtime-async-std-native-tls", + "sea-schema/runtime-async-std-native-tls", + "sea-orm-migration/runtime-async-std-native-tls", +] +runtime-async-std-rustls = [ + "sea-orm/runtime-async-std-rustls", + "sea-schema/runtime-async-std-rustls", + "sea-orm-migration/runtime-async-std-rustls", +] +runtime-tokio = [ + "tokio", + "sea-orm/runtime-tokio", + "sea-schema/runtime-tokio", + "sea-orm-migration/runtime-tokio", +] +runtime-tokio-native-tls = [ + "tokio", + "sea-orm/runtime-tokio-native-tls", + "sea-schema/runtime-tokio-native-tls", + "sea-orm-migration/runtime-tokio-native-tls", +] +runtime-tokio-rustls = [ + "tokio", + "sea-orm/runtime-tokio-rustls", + "sea-schema/runtime-tokio-rustls", + "sea-orm-migration/runtime-tokio-rustls", +] +sqlx-mysql = [ + "sea-orm/sqlx-mysql", + "sea-schema/sqlx-mysql", + "sea-orm-migration/sqlx-mysql", +] +sqlx-postgres = [ + "sea-orm/sqlx-postgres", + "sea-schema/sqlx-postgres", + "sea-orm-migration/sqlx-postgres", +] +sqlx-sqlite = [ + "sea-orm/sqlx-sqlite", + "sea-schema/sqlx-sqlite", + "sea-orm-migration/sqlx-sqlite", +] [dev-dependencies] -tokio = { version = "1", features = ["rt-multi-thread", "macros"] } tempfile = "3" +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } diff --git a/src/query/helper.rs b/src/query/helper.rs index 7aef179692..c3b0971722 100644 --- a/src/query/helper.rs +++ b/src/query/helper.rs @@ -909,7 +909,7 @@ pub(crate) fn join_tbl_on_condition( foreign_keys: Identity, ) -> Condition { let mut cond = Condition::all(); - for (owner_key, foreign_key) in owner_keys.into_iter().zip(foreign_keys.into_iter()) { + for (owner_key, foreign_key) in owner_keys.into_iter().zip(foreign_keys) { cond = cond .add(Expr::col((from_tbl.clone(), owner_key)).equals((to_tbl.clone(), foreign_key))); } diff --git a/src/schema/builder.rs b/src/schema/builder.rs index f8a5a0e6a9..e37c943402 100644 --- a/src/schema/builder.rs +++ b/src/schema/builder.rs @@ -1,10 +1,13 @@ use super::{Schema, TopologicalSort}; use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement}; use sea_query::{ - ForeignKeyCreateStatement, Index, IndexCreateStatement, IntoIden, TableAlterStatement, - TableCreateStatement, TableName, TableRef, extension::postgres::TypeCreateStatement, + IndexCreateStatement, TableCreateStatement, TableName, TableRef, + extension::postgres::TypeCreateStatement, }; +#[cfg(feature = "schema-sync")] +use sea_query::{ForeignKeyCreateStatement, Index, IntoIden, TableAlterStatement}; + /// A schema builder that can take a registry of Entities and synchronize it with database. pub struct SchemaBuilder { helper: Schema, @@ -89,6 +92,8 @@ impl SchemaBuilder { /// This function fetches the changes needed to sync the database schema with this builder's entities. /// * `db` - The database connection to use for fetching existing table schema. /// * `allow_dangerous` - If `true`, changes will contain drop statements (drop tables, drop columns, drop constraints). + /// # Panics + /// if #[cfg(feature = "schema-sync")] #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] pub async fn discover(&self, db: &C, allow_dangerous: bool) -> Result, DbErr> @@ -133,16 +138,8 @@ impl SchemaBuilder { .iter() .any(|e| get_table_name(e.table.get_table_name()) == existing_name); if !in_entities { - let stmt = db_backend.build( - sea_query::Table::drop() - .table( - existing_table - .get_table_name() - .expect("table must have a name") - .clone(), - ) - .if_exists(), - ); + let stmt = + db_backend.build(sea_query::Table::drop().table(existing_name).if_exists()); changes.push(stmt); } } @@ -305,6 +302,8 @@ impl SchemaBuilder { } /// Stores the discovered schema from the database, including tables and enums +#[cfg(feature = "schema-sync")] +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] struct DiscoveredSchema { tables: Vec, enums: Vec, @@ -339,7 +338,7 @@ impl EntitySchemaInfo { Ok(()) } - // better to always compile this function + #[cfg(feature = "schema-sync")] fn discover_changes( &self, db_backend: DbBackend, @@ -640,6 +639,7 @@ fn get_table_name(table_ref: Option<&TableRef>) -> TableName { } } +#[cfg(feature = "schema-sync")] fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool { let a = a.get_foreign_key(); let b = b.get_foreign_key(); From 2816d5386e181ae8a81b14166387059a4cf36829 Mon Sep 17 00:00:00 2001 From: sinder Date: Sat, 9 May 2026 18:25:36 +0000 Subject: [PATCH 10/10] End of experiments --- .github/workflows/rust.yml | 2 +- .gitignore | 4 +- Cargo.toml | 5 +- sea-orm-cli/Cargo.toml | 35 +- sea-orm-cli/src/cli.rs | 229 +++- sea-orm-cli/src/commands/entity.rs | 298 ++++ sea-orm-cli/src/commands/generate.rs | 165 ++- sea-orm-cli/src/commands/migrate.rs | 320 +++-- sea-orm-cli/src/commands/mod.rs | 7 +- sea-orm-cli/src/commands/subprocess.rs | 230 ++++ sea-orm-codegen/src/entity/writer.rs | 155 ++- sea-orm-codegen/src/entity/writer/frontend.rs | 159 +++ sea-orm-codegen/tests/compact/imports.rs | 25 + .../tests/compact_with_schema_name/imports.rs | 25 + sea-orm-codegen/tests/dense/imports.rs | 23 + sea-orm-codegen/tests/expanded/imports.rs | 87 ++ .../expanded_with_schema_name/imports.rs | 90 ++ sea-orm-codegen/tests/frontend/imports.rs | 16 + .../tests/frontend_with_imports/imports.rs | 26 + .../frontend_with_schema_name/imports.rs | 16 + sea-orm-codegen/tests/with_seaography/mod.rs | 1 + sea-orm-macros/src/derives/attributes.rs | 1 - sea-orm-macros/src/derives/value_type.rs | 62 - .../src/derives/value_type_match.rs | 23 - sea-orm-migration/Cargo.toml | 30 +- sea-orm-migration/src/cli.rs | 156 ++- sea-orm-migration/src/codegen.rs | 66 + sea-orm-migration/src/entity_cli.rs | 670 +++++++++ sea-orm-migration/src/fs.rs | 98 ++ sea-orm-migration/src/lib.rs | 33 + sea-orm-migration/src/migrator.rs | 66 +- sea-orm-migration/src/migrator/exec.rs | 58 +- sea-orm-migration/src/migrator/with_self.rs | 141 +- sea-orm-migration/src/response.rs | 171 +++ sea-orm-migration/src/summary.rs | 122 ++ .../template/migration/README.md | 41 + .../template/migration/_Cargo.toml | 23 + .../template/migration/_gitignore | 1 + .../template/migration/src/lib.rs | 12 + .../src/m20220101_000001_create_table.rs | 33 + .../template/migration/src/main.rs | 6 + .../tests/common/entity_common/mod.rs | 233 ++++ .../m20250101_000001_create_cake_table.rs | 26 + .../m20250101_000002_create_fruit_table.rs | 44 + .../tests/common/entity_migration/mod.rs | 2 + .../tests/common/entity_migrator/default.rs | 15 + .../tests/common/entity_migrator/mod.rs | 1 + sea-orm-migration/tests/common/mod.rs | 7 + sea-orm-migration/tests/entity_first.rs | 865 ++++++++++++ sea-orm-migration/tests/main.rs | 5 - sea-orm-sync/Cargo.toml | 1 + .../tests/common/features/value_type.rs | 6 - sea-orm-sync/tests/derive_tests.rs | 66 - src/query/helper.rs | 2 +- src/schema/builder.rs | 556 +++----- src/schema/discover/changes.rs | 237 ++++ src/schema/discover/enum_.rs | 92 ++ src/schema/discover/interpret.rs | 510 +++++++ src/schema/discover/mod.rs | 80 ++ src/schema/discover/resolver.rs | 618 +++++++++ src/schema/discover/schema.rs | 109 ++ src/schema/discover/suggestion.rs | 35 + src/schema/discover/table.rs | 378 ++++++ src/schema/discover/warning.rs | 32 + src/schema/entity.rs | 274 ++++ src/schema/mod.rs | 42 + tests/common/features/schema.rs | 2 +- tests/common/features/value_type.rs | 6 - tests/common/fixtures.rs | 404 ++++++ tests/common/helpers.rs | 149 ++ tests/common/mod.rs | 2 + tests/derive_tests.rs | 66 - tests/schema_discover_tests.rs | 1201 +++++++++++++++++ tests/schema_sync_tests.rs | 1 - 74 files changed, 8815 insertions(+), 983 deletions(-) create mode 100644 sea-orm-cli/src/commands/entity.rs create mode 100644 sea-orm-cli/src/commands/subprocess.rs create mode 100644 sea-orm-codegen/tests/compact/imports.rs create mode 100644 sea-orm-codegen/tests/compact_with_schema_name/imports.rs create mode 100644 sea-orm-codegen/tests/dense/imports.rs create mode 100644 sea-orm-codegen/tests/expanded/imports.rs create mode 100644 sea-orm-codegen/tests/expanded_with_schema_name/imports.rs create mode 100644 sea-orm-codegen/tests/frontend/imports.rs create mode 100644 sea-orm-codegen/tests/frontend_with_imports/imports.rs create mode 100644 sea-orm-codegen/tests/frontend_with_schema_name/imports.rs create mode 100644 sea-orm-migration/src/codegen.rs create mode 100644 sea-orm-migration/src/entity_cli.rs create mode 100644 sea-orm-migration/src/fs.rs create mode 100644 sea-orm-migration/src/response.rs create mode 100644 sea-orm-migration/src/summary.rs create mode 100644 sea-orm-migration/template/migration/README.md create mode 100644 sea-orm-migration/template/migration/_Cargo.toml create mode 100644 sea-orm-migration/template/migration/_gitignore create mode 100644 sea-orm-migration/template/migration/src/lib.rs create mode 100644 sea-orm-migration/template/migration/src/m20220101_000001_create_table.rs create mode 100644 sea-orm-migration/template/migration/src/main.rs create mode 100644 sea-orm-migration/tests/common/entity_common/mod.rs create mode 100644 sea-orm-migration/tests/common/entity_migration/m20250101_000001_create_cake_table.rs create mode 100644 sea-orm-migration/tests/common/entity_migration/m20250101_000002_create_fruit_table.rs create mode 100644 sea-orm-migration/tests/common/entity_migration/mod.rs create mode 100644 sea-orm-migration/tests/common/entity_migrator/default.rs create mode 100644 sea-orm-migration/tests/common/entity_migrator/mod.rs create mode 100644 sea-orm-migration/tests/entity_first.rs create mode 100644 src/schema/discover/changes.rs create mode 100644 src/schema/discover/enum_.rs create mode 100644 src/schema/discover/interpret.rs create mode 100644 src/schema/discover/mod.rs create mode 100644 src/schema/discover/resolver.rs create mode 100644 src/schema/discover/schema.rs create mode 100644 src/schema/discover/suggestion.rs create mode 100644 src/schema/discover/table.rs create mode 100644 src/schema/discover/warning.rs create mode 100644 tests/common/fixtures.rs create mode 100644 tests/common/helpers.rs create mode 100644 tests/schema_discover_tests.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 297beddb1b..6401118e55 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -141,7 +141,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - uses: mozilla-actions/sccache-action@v0.0.9 - run: cargo install --locked taplo-cli - - run: taplo fmt --check + - run: taplo fmt --check compile: name: Compile (${{ matrix.label }}) diff --git a/.gitignore b/.gitignore index 7ec785cd4a..237c030705 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,7 @@ Cargo.lock .vscode .idea/* */.idea/* +.zed/* +*/.zed/* .env.local -.DS_Store \ No newline at end of file +.DS_Store diff --git a/Cargo.toml b/Cargo.toml index 31c30b0570..72affab0e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "sea-orm-macros", "sea-orm-codegen", "sea-orm-arrow"] +members = [".", "sea-orm-macros", "sea-orm-codegen", "sea-orm-arrow", "sea-orm-cli", "sea-orm-migration"] [package] authors = ["Chris Tsang "] @@ -238,5 +238,4 @@ with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-sqlx?/with-uuid"] # This allows us to develop using a local version of sea-query [patch.crates-io] -# sea-query = { path = "../sea-query" } -# sea-query = { git = "https://github.com/SeaQL/sea-query", branch = "master" } +sea-query = { git = "https://github.com/SeaQL/sea-query", branch = "master" } diff --git a/sea-orm-cli/Cargo.toml b/sea-orm-cli/Cargo.toml index db9fe1a6d5..c23337045f 100644 --- a/sea-orm-cli/Cargo.toml +++ b/sea-orm-cli/Cargo.toml @@ -1,5 +1,5 @@ -[workspace] -# A separate workspace +# [workspace] +# # A separate workspace [package] authors = [ @@ -24,13 +24,18 @@ name = "sea_orm_cli" path = "src/lib.rs" [[bin]] -name = "sea-orm-cli" -path = "src/bin/main.rs" +name = "sea-orm-cli" +path = "src/bin/main.rs" +required-features = ["cli", "codegen"] + +[[bin]] +name = "sea" +path = "src/bin/main.rs" required-features = ["cli", "codegen"] [[bin]] -name = "sea" -path = "src/bin/main.rs" +name = "sea-dev" +path = "src/bin/main.rs" required-features = ["cli", "codegen"] [dependencies] @@ -38,14 +43,15 @@ async-std = { version = "1.9", default-features = false, features = [ "attributes", "tokio1", ], optional = true } -chrono = { version = "0.4.20", default-features = false, features = [ - "clock", -] } +chrono = { version = "0.4.20", default-features = false, features = ["clock"] } clap = { version = "4.3", features = ["env", "derive"], optional = true } +colored = "2" dotenvy = { version = "0.15", default-features = false, optional = true } glob = { version = "0.3", default-features = false } indoc = "2.0.6" regex = { version = "1.11.2" } +serde = { version = "1", features = ["derive"] } +serde_json = { version = "1" } sea-orm-codegen = { version = "=2.0.0-rc.38", path = "../sea-orm-codegen", default-features = false, optional = true } sea-schema = { version = "0.17.0-rc.1", default-features = false, features = [ "discovery", @@ -78,11 +84,7 @@ default = [ "runtime-tokio-native-tls", ] postgres-vector = ["sea-schema/postgres-vector"] -sqlx-mysql = [ - "sqlx?/sqlx-mysql", - "sea-schema?/sqlx-mysql", - "sea-schema?/mysql", -] +sqlx-mysql = ["sqlx?/sqlx-mysql", "sea-schema?/sqlx-mysql", "sea-schema?/mysql"] sqlx-postgres = [ "sqlx?/sqlx-postgres", "sea-schema?/sqlx-postgres", @@ -94,9 +96,9 @@ sqlx-sqlite = [ "sea-schema?/sqlite", ] -runtime-actix = ["runtime-tokio"] +runtime-actix = ["runtime-tokio"] runtime-actix-native-tls = ["runtime-tokio-native-tls"] -runtime-actix-rustls = ["runtime-tokio-rustls"] +runtime-actix-rustls = ["runtime-tokio-rustls"] runtime-async-std = [ "async-std", @@ -129,4 +131,5 @@ runtime-tokio-rustls = [ # This allows us to develop using an overridden version of sea-query [patch.crates-io] # sea-query = { path = "../sea-query" } +sea-query = { git = "https://github.com/SeaQL/sea-query", branch = "master" } # sea-query = { git = "https://github.com/SeaQL/sea-query", branch = "master" } diff --git a/sea-orm-cli/src/cli.rs b/sea-orm-cli/src/cli.rs index 810a81ba66..dd493e347c 100644 --- a/sea-orm-cli/src/cli.rs +++ b/sea-orm-cli/src/cli.rs @@ -1,10 +1,15 @@ use clap::{ArgAction, ArgGroup, Parser, Subcommand, ValueEnum}; -#[cfg(feature = "codegen")] +use colored::Colorize; +#[cfg(feature = "cli")] use dotenvy::dotenv; use std::ffi::OsStr; +#[cfg(feature = "cli")] +use crate::commands::entity::{run_entity_init, run_entity_schema, run_entity_sync}; #[cfg(feature = "codegen")] -use crate::{handle_error, run_generate_command, run_migrate_command}; +use crate::run_generate_command; +#[cfg(feature = "cli")] +use crate::{handle_error, run_migrate_command}; #[derive(Parser, Debug)] #[command( @@ -63,6 +68,15 @@ pub enum Commands { #[command(subcommand)] command: GenerateSubcommands, }, + #[command( + about = "Entity-first workflow commands", + arg_required_else_help = true, + display_order = 15 + )] + Entity { + #[command(subcommand)] + command: EntitySubcommands, + }, #[command(about = "Migration related commands", display_order = 20)] Migrate { #[arg( @@ -163,6 +177,110 @@ pub enum MigrateSubcommands { }, } +#[derive(Subcommand, PartialEq, Eq, Debug)] +pub enum EntitySubcommands { + #[command( + about = "Diff entity definitions against the live database and interactively generate a migration", + display_order = 10 + )] + Sync { + // TODO: add a help message in case default value fails + #[arg( + short = 'd', + long, + env = "ENTITY_DIR", + help = "Path to the entity crate root directory", + default_value = "./entity" + )] + dir: String, + + #[arg( + long, + env = "MIGRATION_DIR", + help = "Path to the migration crate root directory", + default_value = "./migration" + )] + migration_dir: String, + + #[arg( + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL", + hide_env_values = true + )] + database_url: Option, + + #[arg( + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL and SQLite, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'.\n" + )] + database_schema: Option, + + /// Name for the generated migration (e.g. add_users). Prompted interactively if omitted. + #[arg(long, help = "Name for the generated migration (e.g. add_users)")] + name: Option, + + #[arg( + long, + default_value_t = true, + help = "Allow dangerous operations (e.g. dropping tables) in diff" + )] + allow_dangerous: bool, + + /// Pre-supply rename decisions: `table.old_col:new_col`. + /// May be repeated for multiple renames. If any unresolved rename is + /// not covered by a --rename flag the command will exit with an error. + #[arg(long = "rename", value_name = "TABLE.OLD:NEW")] + renames: Vec, + + /// Skip the Y/n confirmation prompt and generate the migration immediately. + #[arg(long, default_value_t = false)] + no_confirm: bool, + }, + + #[command( + about = "Preview the schema as defined by registered entities, without connecting to a database", + display_order = 15 + )] + Schema { + #[arg( + short = 'd', + long, + env = "ENTITY_DIR", + help = "Path to the entity crate root directory", + default_value = "./entity" + )] + dir: String, + + #[arg( + long, + default_value = "postgres", + help = "Database backend to render SQL for (postgres, mysql, sqlite)" + )] + database_backend: String, + }, + + #[command( + about = "Scaffold a new entity crate (not yet implemented)", + display_order = 20 + )] + Init { + #[arg( + short = 'd', + long, + env = "ENTITY_DIR", + help = "Path where the entity crate should be created", + default_value = "./entity" + )] + dir: String, + }, +} + #[derive(Subcommand, PartialEq, Eq, Debug)] pub enum GenerateSubcommands { #[command(about = "Generate entity")] @@ -392,6 +510,57 @@ pub enum GenerateSubcommands { )] er_diagram: bool, }, + #[command(about = "Preview the current database schema as SQL DDL statements")] + Schema { + #[arg( + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL", + hide_env_values = true + )] + database_url: String, + + #[arg( + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'." + )] + database_schema: Option, + + #[arg( + short = 't', + long, + value_delimiter = ',', + help = "Preview schema for specified tables only (comma separated)" + )] + tables: Vec, + + #[arg( + long, + value_delimiter = ',', + default_value = "seaql_migrations", + help = "Skip tables from schema preview (comma separated)" + )] + ignore_tables: Vec, + + #[arg( + long, + default_value = "1", + help = "The maximum amount of connections to use when connecting to the database." + )] + max_connections: u32, + + #[arg( + long, + default_value = "30", + long_help = "Acquire timeout in seconds of the connection used for schema discovery" + )] + acquire_timeout: u64, + }, } #[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum, Default)] @@ -424,7 +593,7 @@ fn is_deprecated_preserve_user_modifications_flag(arg: &OsStr) -> bool { /// Use this to build a local, version-controlled `sea-orm-cli` in dependent projects /// (see [example use case](https://github.com/SeaQL/sea-orm/discussions/1889)). -#[cfg(feature = "codegen")] +#[cfg(feature = "cli")] pub async fn main() { dotenv().ok(); @@ -435,17 +604,69 @@ pub async fn main() { let cli = Cli::parse(); if deprecated_preserve_user_modifications_flag_used { eprintln!( - "warning: `--preserve-user-modifications` is deprecated; use `--experimental-preserve-user-modifications` instead." + "{}: `--preserve-user-modifications` is deprecated; use `--experimental-preserve-user-modifications` instead.", + "warning".yellow().bold() ); } let verbose = cli.verbose; match cli.command { Commands::Generate { command } => { + #[cfg(feature = "codegen")] run_generate_command(command, verbose) .await .unwrap_or_else(handle_error); + #[cfg(not(feature = "codegen"))] + { + let _ = command; + eprintln!( + "{} `generate` requires the `codegen` feature.", + "Error:".red().bold() + ); + std::process::exit(1); + } } + Commands::Entity { command } => match command { + EntitySubcommands::Sync { + dir, + migration_dir, + database_url, + database_schema, + name, + allow_dangerous, + renames, + no_confirm, + } => { + if let Err(e) = run_entity_sync( + &dir, + &migration_dir, + name.as_deref(), + database_url.as_deref(), + database_schema.as_deref(), + allow_dangerous, + &renames, + no_confirm, + ) { + eprintln!("{} {e}", "Error:".red().bold()); + std::process::exit(1); + } + } + EntitySubcommands::Schema { + dir, + database_backend, + } => { + if let Err(e) = run_entity_schema(&dir, &database_backend) { + eprintln!("{} {e}", "Error:".red().bold()); + std::process::exit(1); + } + } + EntitySubcommands::Init { dir } => { + if let Err(e) = run_entity_init(&dir) { + eprintln!("{} {e}", "Error:".red().bold()); + std::process::exit(1); + } + } + }, Commands::Migrate { migration_dir, database_schema, diff --git a/sea-orm-cli/src/commands/entity.rs b/sea-orm-cli/src/commands/entity.rs new file mode 100644 index 0000000000..36c1912a75 --- /dev/null +++ b/sea-orm-cli/src/commands/entity.rs @@ -0,0 +1,298 @@ +//! Entity-first commands for sea-orm-cli. + +use std::error::Error; +use std::io; + +use colored::Colorize; + +use crate::commands::subprocess::{ + DiffData, GenerateData, SchemaData, manifest_path, run_subprocess_json, +}; + +pub fn run_entity_sync( + dir: &str, + migration_dir: &str, + name: Option<&str>, + database_url: Option<&str>, + database_schema: Option<&str>, + allow_dangerous: bool, + renames: &[String], + no_confirm: bool, +) -> Result<(), Box> { + let manifest = manifest_path(dir); + + let mut diff_args = vec!["diff"]; + if !allow_dangerous { + diff_args.push("--allow-dangerous=false"); + } + + let (_, diff) = + run_subprocess_json::(&manifest, &diff_args, database_url, database_schema) + .map_err(|e| format!("diff failed: {e}"))?; + + let decision = run_sync(diff, name, renames, no_confirm)?; + + match decision { + SyncDecision::Quit => { + println!("{}", "Aborted.".yellow()); + return Ok(()); + } + SyncDecision::Generate { + schema_hash, + renames: resolved_renames, + migration_name: gen_name, + } => { + let mut gen_args = vec![ + "generate".to_string(), + gen_name, + format!("--migration-dir={migration_dir}"), + format!("--schema-hash={schema_hash}"), + ]; + if !allow_dangerous { + gen_args.push("--allow-dangerous=false".to_string()); + } + for (table, old, new) in &resolved_renames { + gen_args.push(format!("--rename={table}.{old}:{new}")); + } + + let gen_args_ref: Vec<&str> = gen_args.iter().map(String::as_str).collect(); + + let (_, result) = run_subprocess_json::( + &manifest, + &gen_args_ref, + database_url, + database_schema, + ) + .map_err(|e| format!("generate failed: {e}"))?; + + print_generate_result(&result); + } + } + + Ok(()) +} + +pub fn run_entity_schema(dir: &str, database_backend: &str) -> Result<(), Box> { + let manifest = manifest_path(dir); + let backend_arg = format!("--database-backend={database_backend}"); + let args = ["schema", backend_arg.as_str()]; + let (_, data) = run_subprocess_json::(&manifest, &args, None, None) + .map_err(|e| format!("schema failed: {e}"))?; + for stmt in &data.statements { + println!("{stmt}"); + } + Ok(()) +} + +pub fn run_entity_init(_dir: &str) -> Result<(), Box> { + println!("Entity crate scaffolding is not yet implemented."); + Ok(()) +} + +enum SyncDecision { + Quit, + Generate { + schema_hash: String, + renames: Vec<(String, String, String)>, // (table, old, new) + migration_name: String, + }, +} + +fn run_sync( + diff: DiffData, + name: Option<&str>, + rename_flags: &[String], + no_confirm: bool, +) -> Result> { + if diff.statements.is_empty() { + println!("{}", "No schema changes detected. Nothing to migrate.".green()); + return Ok(SyncDecision::Quit); + } + + println!("{}", format!("Changes ({}):", diff.changes.len()).bold()); + for change in &diff.changes { + println!(" {} {change}", "-".yellow()); + } + + println!(); + println!("{}", format!("SQL statements ({}):", diff.statements.len()).bold()); + for stmt in &diff.statements { + println!(" {}", stmt.dimmed()); + } + + if !diff.warnings.is_empty() { + println!(); + println!("{}", format!("Warnings ({}):", diff.warnings.len()).yellow().bold()); + for w in &diff.warnings { + println!(" {} {}", format!("[{}]", w.kind).yellow(), w.message); + } + } + + if !diff.suggestions.is_empty() { + println!(); + println!("{}", format!("Suggestions ({}):", diff.suggestions.len()).blue().bold()); + for s in &diff.suggestions { + println!(" {} {}", format!("[{}]", s.kind).blue(), s.message); + } + } + + let mut rename_map: std::collections::HashMap<(String, String), String> = + std::collections::HashMap::new(); + for flag in rename_flags { + let (table_col, new) = flag + .split_once(':') + .ok_or_else(|| format!("invalid --rename value '{flag}': expected table.old:new"))?; + let (table, old) = table_col + .split_once('.') + .ok_or_else(|| format!("invalid --rename value '{flag}': expected table.old:new"))?; + rename_map.insert((table.to_string(), old.to_string()), new.to_string()); + } + + let has_rename_flags = !rename_flags.is_empty(); + let schema_hash = diff.schema_hash.clone(); + let mut resolved_renames: Vec<(String, String, String)> = Vec::new(); + + if !diff.unresolved.is_empty() { + println!(); + println!( + "{}", + format!("Unresolved renames ({}):", diff.unresolved.len()) + .yellow() + .bold() + ); + } + + for unresolved in &diff.unresolved { + let key = (unresolved.table.clone(), unresolved.removed.clone()); + + if let Some(new_col) = rename_map.get(&key) { + if !unresolved.candidates.contains(new_col) { + return Err(format!( + "--rename {}.{}:{} is invalid: '{}' is not among the candidates: {}", + unresolved.table, + unresolved.removed, + new_col, + new_col, + unresolved.candidates.join(", ") + ) + .into()); + } + resolved_renames.push(( + unresolved.table.clone(), + unresolved.removed.clone(), + new_col.clone(), + )); + } else if has_rename_flags { + return Err(format!( + "unresolved rename for {}.{} (candidates: {}): provide --rename={}.{}:", + unresolved.table, + unresolved.removed, + unresolved.candidates.join(", "), + unresolved.table, + unresolved.removed, + ) + .into()); + } else { + println!( + " Table {}: column {} was removed.", + format!("'{}'", unresolved.table).bold(), + format!("'{}'", unresolved.removed).yellow() + ); + println!(" {}", "Candidates for rename:".bold()); + for (i, c) in unresolved.candidates.iter().enumerate() { + println!(" {}) {}", i + 1, c.cyan()); + } + println!( + " {}) {}", + (unresolved.candidates.len() + 1).to_string().red(), + "drop (treat as a plain column drop)".red() + ); + + let choice = prompt_rename_choice(&unresolved.candidates)?; + if let Some(new_col) = choice { + resolved_renames.push(( + unresolved.table.clone(), + unresolved.removed.clone(), + new_col, + )); + } + } + } + + let migration_name = match name { + Some(n) => n.to_string(), + None => { + print!("{}", "Migration name (e.g. add_users): ".bold()); + io::Write::flush(&mut io::stdout())?; + let mut input = String::new(); + io::BufRead::read_line(&mut io::stdin().lock(), &mut input)?; + let input = input.trim().to_string(); + if input.is_empty() { + return Err("migration name cannot be empty".into()); + } + input + } + }; + + if !no_confirm { + print!( + "{}", + format!("Generate migration '{migration_name}'? [Y/n]: ").bold() + ); + io::Write::flush(&mut io::stdout())?; + let mut input = String::new(); + io::BufRead::read_line(&mut io::stdin().lock(), &mut input)?; + let input = input.trim().to_lowercase(); + if input == "n" || input == "no" { + return Ok(SyncDecision::Quit); + } + } + + Ok(SyncDecision::Generate { + schema_hash, + renames: resolved_renames, + migration_name, + }) +} + +fn prompt_rename_choice(candidates: &[String]) -> Result, Box> { + let drop_option = candidates.len() + 1; + loop { + print!("{}", format!(" Choice [1-{drop_option}]: ").bold()); + io::Write::flush(&mut io::stdout())?; + let mut input = String::new(); + io::BufRead::read_line(&mut io::stdin().lock(), &mut input)?; + let input = input.trim(); + match input.parse::() { + Ok(n) if n >= 1 && n <= candidates.len() => { + return Ok(Some(candidates[n - 1].clone())); + } + Ok(n) if n == drop_option => { + return Ok(None); + } + _ => { + println!( + " {}", + format!("Please enter a number between 1 and {drop_option}.").yellow() + ); + } + } + } +} + +fn print_generate_result(result: &GenerateData) { + println!(); + println!( + " {} {}", + "Migration generated:".green().bold(), + result.migration_name.bold() + ); + println!(" File: {}", result.filepath.dimmed()); + if !result.changes.is_empty() { + println!(" {}", format!("Changes ({}):", result.changes.len()).bold()); + for change in &result.changes { + println!(" {} {change}", "+".green()); + } + } + println!(); +} diff --git a/sea-orm-cli/src/commands/generate.rs b/sea-orm-cli/src/commands/generate.rs index 7e9c5916e2..89f71e8e39 100644 --- a/sea-orm-cli/src/commands/generate.rs +++ b/sea-orm-cli/src/commands/generate.rs @@ -1,10 +1,12 @@ use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands}; +use colored::Colorize; use core::time; use sea_orm_codegen::{ BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType, DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext, MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files, }; +use sea_schema::sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, SqliteQueryBuilder}; use std::{error::Error, fs, path::Path, process::Command, str::FromStr}; use tracing_subscriber::{EnvFilter, prelude::*}; use url::Url; @@ -129,7 +131,7 @@ pub async fn run_generate_command( use sea_schema::mysql::discovery::SchemaDiscovery; use sqlx::MySql; - println!("Connecting to MySQL ..."); + println!("{}", "Connecting to MySQL ...".cyan()); let connection = sqlx_connect::( max_connections, acquire_timeout, @@ -137,7 +139,7 @@ pub async fn run_generate_command( None, ) .await?; - println!("Discovering schema ..."); + println!("{}", "Discovering schema ...".cyan().dimmed()); let schema_discovery = SchemaDiscovery::new(connection, _database_name); let schema = schema_discovery.discover().await?; let table_stmts = schema @@ -161,7 +163,7 @@ pub async fn run_generate_command( use sea_schema::sqlite::discovery::SchemaDiscovery; use sqlx::Sqlite; - println!("Connecting to SQLite ..."); + println!("{}", "Connecting to SQLite ...".cyan()); let connection = sqlx_connect::( max_connections, acquire_timeout, @@ -169,7 +171,7 @@ pub async fn run_generate_command( None, ) .await?; - println!("Discovering schema ..."); + println!("{}", "Discovering schema ...".cyan().dimmed()); let schema_discovery = SchemaDiscovery::new(connection); let schema = schema_discovery .discover() @@ -196,7 +198,7 @@ pub async fn run_generate_command( use sea_schema::postgres::discovery::SchemaDiscovery; use sqlx::Postgres; - println!("Connecting to Postgres ..."); + println!("{}", "Connecting to Postgres ...".cyan()); let schema = database_schema.as_deref().unwrap_or("public"); let connection = sqlx_connect::( max_connections, @@ -205,7 +207,7 @@ pub async fn run_generate_command( Some(schema), ) .await?; - println!("Discovering schema ..."); + println!("{}", "Discovering schema ...".cyan().dimmed()); let schema_discovery = SchemaDiscovery::new(connection, schema); let schema = schema_discovery.discover().await?; let table_stmts = schema @@ -221,7 +223,7 @@ pub async fn run_generate_command( } _ => unimplemented!("{} is not supported", url.scheme()), }; - println!("... discovered."); + println!("{}", "... discovered.".green()); let writer_context = EntityWriterContext::new( if expanded_format { @@ -260,7 +262,7 @@ pub async fn run_generate_command( let diagram = entity_writer.generate_er_diagram(); let diagram_path = dir.join("entities.mermaid"); fs::write(&diagram_path, &diagram)?; - println!("Writing {}", diagram_path.display()); + println!("Writing {}", diagram_path.display().to_string().dimmed()); } let output = entity_writer.generate(&writer_context); @@ -269,7 +271,7 @@ pub async fn run_generate_command( for OutputFile { name, content } in output.files.iter() { let file_path = dir.join(name); - println!("Writing {}", file_path.display()); + println!("Writing {}", file_path.display().to_string().dimmed()); if !matches!( name.as_str(), @@ -288,7 +290,7 @@ pub async fn run_generate_command( fallback_applied, }) => { for message in warnings { - eprintln!("{message}"); + eprintln!("{}", message.yellow()); } fs::write(file_path, output)?; if fallback_applied { @@ -311,7 +313,7 @@ pub async fn run_generate_command( } if merge_fallback_files.is_empty() { - println!("... Done."); + println!("{}", "... Done.".green().bold()); } else { return Err(format!( "Merge fallback applied for {} file(s): \n{}", @@ -321,6 +323,147 @@ pub async fn run_generate_command( .into()); } } + GenerateSubcommands::Schema { + database_url, + database_schema, + tables, + ignore_tables, + max_connections, + acquire_timeout, + } => { + let url = Url::parse(&database_url)?; + let is_sqlite = url.scheme() == "sqlite"; + + let filter_tables = + |table: &String| -> bool { tables.is_empty() || tables.contains(table) }; + let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) }; + + if !is_sqlite { + let database_name = url + .path_segments() + .unwrap_or_else(|| { + panic!( + "There is no database name as part of the url path: {}", + url.as_str() + ) + }) + .next() + .unwrap(); + if database_name.is_empty() { + panic!( + "There is no database name as part of the url path: {}", + url.as_str() + ); + } + } + + match url.scheme() { + "mysql" => { + #[cfg(not(feature = "sqlx-mysql"))] + { + panic!("mysql feature is off") + } + #[cfg(feature = "sqlx-mysql")] + { + use sea_schema::mysql::discovery::SchemaDiscovery; + use sqlx::MySql; + + let database_name = url.path_segments().unwrap().next().unwrap(); + println!("{}", "Connecting to MySQL ...".cyan()); + let connection = sqlx_connect::( + max_connections, + acquire_timeout, + url.as_str(), + None, + ) + .await?; + println!("{}", "Discovering schema ...".cyan().dimmed()); + let schema_discovery = SchemaDiscovery::new(connection, database_name); + let schema = schema_discovery.discover().await?; + let stmts: Vec<_> = schema + .tables + .into_iter() + .filter(|s| filter_tables(&s.info.name)) + .filter(|s| filter_skip_tables(&s.info.name)) + .map(|s| s.write()) + .collect(); + for stmt in stmts { + println!("{};", stmt.build(MysqlQueryBuilder)); + } + } + } + "sqlite" => { + #[cfg(not(feature = "sqlx-sqlite"))] + { + panic!("sqlite feature is off") + } + #[cfg(feature = "sqlx-sqlite")] + { + use sea_schema::sqlite::discovery::SchemaDiscovery; + use sqlx::Sqlite; + + println!("{}", "Connecting to SQLite ...".cyan()); + let connection = sqlx_connect::( + max_connections, + acquire_timeout, + url.as_str(), + None, + ) + .await?; + println!("{}", "Discovering schema ...".cyan().dimmed()); + let schema = SchemaDiscovery::new(connection) + .discover() + .await? + .merge_indexes_into_table(); + let stmts: Vec<_> = schema + .tables + .into_iter() + .filter(|s| filter_tables(&s.name)) + .filter(|s| filter_skip_tables(&s.name)) + .map(|s| s.write()) + .collect(); + for stmt in stmts { + println!("{};", stmt.build(SqliteQueryBuilder)); + } + } + } + "postgres" | "postgresql" => { + #[cfg(not(feature = "sqlx-postgres"))] + { + panic!("postgres feature is off") + } + #[cfg(feature = "sqlx-postgres")] + { + use sea_schema::postgres::discovery::SchemaDiscovery; + use sqlx::Postgres; + + let schema = database_schema.as_deref().unwrap_or("public"); + println!("{}", "Connecting to Postgres ...".cyan()); + let connection = sqlx_connect::( + max_connections, + acquire_timeout, + url.as_str(), + Some(schema), + ) + .await?; + println!("{}", "Discovering schema ...".cyan().dimmed()); + let schema_discovery = SchemaDiscovery::new(connection, schema); + let discovered = schema_discovery.discover().await?; + let stmts: Vec<_> = discovered + .tables + .into_iter() + .filter(|s| filter_tables(&s.info.name)) + .filter(|s| filter_skip_tables(&s.info.name)) + .map(|s| s.write()) + .collect(); + for stmt in stmts { + println!("{};", stmt.build(PostgresQueryBuilder)); + } + } + } + _ => unimplemented!("{} is not supported", url.scheme()), + } + } } Ok(()) diff --git a/sea-orm-cli/src/commands/migrate.rs b/sea-orm-cli/src/commands/migrate.rs index 11978b4dd8..6f21109cf6 100644 --- a/sea-orm-cli/src/commands/migrate.rs +++ b/sea-orm-cli/src/commands/migrate.rs @@ -1,4 +1,5 @@ use chrono::{Local, Utc}; +use colored::Colorize; use regex::Regex; use std::{ error::Error, @@ -6,11 +7,14 @@ use std::{ fs, io::Write, path::{Path, PathBuf}, - process::Command, }; #[cfg(feature = "cli")] use crate::MigrateSubcommands; +use crate::commands::subprocess::{ + AppliedData, LifecycleData, RolledBackData, StatusData, SubprocessError, manifest_path, + run_subprocess_json, +}; #[cfg(feature = "cli")] pub fn run_migrate_command( @@ -27,64 +31,237 @@ pub fn run_migrate_command( universal_time: _, local_time, }) => run_migrate_generate(migration_dir, &migration_name, !local_time)?, - _ => { - let (subcommand, migration_dir, steps, verbose) = match command { - Some(MigrateSubcommands::Fresh) => ("fresh", migration_dir, None, verbose), - Some(MigrateSubcommands::Refresh) => ("refresh", migration_dir, None, verbose), - Some(MigrateSubcommands::Reset) => ("reset", migration_dir, None, verbose), - Some(MigrateSubcommands::Status) => ("status", migration_dir, None, verbose), - Some(MigrateSubcommands::Up { num }) => ("up", migration_dir, num, verbose), - Some(MigrateSubcommands::Down { num }) => { - ("down", migration_dir, Some(num), verbose) - } - _ => ("up", migration_dir, None, verbose), - }; + cmd => run_migrate_json( + cmd, + migration_dir, + database_url.as_deref(), + database_schema.as_deref(), + verbose, + )?, + } - // Construct the `--manifest-path` - let manifest_path = if migration_dir.ends_with('/') { - format!("{migration_dir}Cargo.toml") - } else { - format!("{migration_dir}/Cargo.toml") - }; - // Construct the arguments that will be supplied to `cargo` command - let mut args = vec!["run", "--manifest-path", &manifest_path, "--", subcommand]; - let mut envs = vec![]; + Ok(()) +} + +fn run_migrate_json( + command: Option, + migration_dir: &str, + database_url: Option<&str>, + database_schema: Option<&str>, + _verbose: bool, +) -> Result<(), Box> { + let manifest = manifest_path(migration_dir); - let mut num: String = "".to_string(); - if let Some(steps) = steps { - num = steps.to_string(); + match command { + Some(MigrateSubcommands::Status) | None => { + match run_subprocess_json::( + &manifest, + &["status"], + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_status(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), + } + } + Some(MigrateSubcommands::Up { num }) => { + let mut args = vec!["up".to_string()]; + if let Some(n) = num { + args.push(format!("-n={n}")); } - if !num.is_empty() { - args.extend(["-n", &num]) + let args_ref: Vec<&str> = args.iter().map(String::as_str).collect(); + match run_subprocess_json::( + &manifest, + &args_ref, + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_applied(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), } - if let Some(database_url) = &database_url { - envs.push(("DATABASE_URL", database_url)); + } + Some(MigrateSubcommands::Down { num }) => { + let n_str = num.to_string(); + let args = ["down", "-n", &n_str]; + match run_subprocess_json::( + &manifest, + &args, + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_rolled_back(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), } - if let Some(database_schema) = &database_schema { - envs.push(("DATABASE_SCHEMA", database_schema)); + } + Some(MigrateSubcommands::Fresh) => { + match run_subprocess_json::( + &manifest, + &["fresh"], + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_applied(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), } - if verbose { - args.push("-v"); + } + Some(MigrateSubcommands::Refresh) => { + match run_subprocess_json::( + &manifest, + &["refresh"], + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_lifecycle(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), } - // Run migrator CLI on user's behalf - println!("Running `cargo {}`", args.join(" ")); - let exit_status = Command::new("cargo").args(args).envs(envs).status()?; // Get the status code - if !exit_status.success() { - // Propagate the error if any - return Err("Fail to run migration".into()); + } + Some(MigrateSubcommands::Reset) => { + match run_subprocess_json::( + &manifest, + &["reset"], + database_url, + database_schema, + ) { + Ok((_, data)) => { + print_rolled_back(&data); + return Ok(()); + } + Err(e) => return Err(render_subprocess_error(e).into()), } } + Some(MigrateSubcommands::Init) | Some(MigrateSubcommands::Generate { .. }) => { + unreachable!("init/generate handled before reaching this function") + } } +} - Ok(()) +fn render_subprocess_error(e: SubprocessError) -> String { + match e { + SubprocessError::VersionMismatch { expected, got } => { + format!( + "Version mismatch: CLI is {expected} but crate returned {got}.\n \ + Rebuild your migration crate with a matching sea-orm-migration version." + ) + } + other => other.to_string(), + } +} + +fn print_status(data: &StatusData) { + println!(); + if data.migrations.is_empty() { + println!(" No migrations found."); + } else { + let name_w = data + .migrations + .iter() + .map(|m| m.name.len()) + .max() + .unwrap_or(10); + println!( + " {}", + format!("{: Result<(), Box> { let migration_dir = match migration_dir.ends_with('/') { true => migration_dir.to_string(), false => format!("{migration_dir}/"), }; - println!("Initializing migration directory..."); + println!("{}", "Initializing migration directory...".cyan()); macro_rules! write_file { ($filename: literal) => { let fn_content = |content: String| content; @@ -96,7 +273,7 @@ pub fn run_migrate_init(migration_dir: &str) -> Result<(), Box> { }; ($filename: literal, $template: literal, $fn_content: expr) => { let filepath = [&migration_dir, $filename].join(""); - println!("Creating file `{}`", filepath); + println!("Creating file `{}`", filepath.dimmed()); let path = Path::new(&filepath); let prefix = path.parent().unwrap(); fs::create_dir_all(prefix).unwrap(); @@ -121,7 +298,7 @@ pub fn run_migrate_init(migration_dir: &str) -> Result<(), Box> { if glob::glob(&format!("{migration_dir}**/.git"))?.count() > 0 { write_file!(".gitignore", "_gitignore"); } - println!("Done!"); + println!("{}", "Done!".green().bold()); Ok(()) } @@ -131,17 +308,14 @@ pub fn run_migrate_generate( migration_name: &str, universal_time: bool, ) -> Result<(), Box> { - // Make sure the migration name doesn't contain any characters that - // are invalid module names in Rust. if migration_name.contains('-') { return Err(Box::new(MigrationCommandError::InvalidName( "Hyphen `-` cannot be used in migration name".to_string(), ))); } - println!("Generating new migration..."); + println!("{}", "Generating new migration...".cyan()); - // build new migration filename const FMT: &str = "%Y%m%d_%H%M%S"; let formatted_now = if universal_time { Utc::now().format(FMT) @@ -158,15 +332,6 @@ pub fn run_migrate_generate( Ok(()) } -/// `get_full_migration_dir` looks for a `src` directory -/// inside of `migration_dir` and appends that to the returned path if found. -/// -/// Otherwise, `migration_dir` can point directly to a directory containing the -/// migrations. In that case, nothing is appended. -/// -/// This way, `src` doesn't need to be appended in the standard case where -/// migrations are in their own crate. If the migrations are in a submodule -/// of another crate, `migration_dir` can point directly to that module. fn get_full_migration_dir(migration_dir: &str) -> PathBuf { let without_src = Path::new(migration_dir).to_owned(); let with_src = without_src.join("src"); @@ -176,11 +341,22 @@ fn get_full_migration_dir(migration_dir: &str) -> PathBuf { } } +fn get_migrator_filepath(migration_dir: &str) -> PathBuf { + let full_migration_dir = get_full_migration_dir(migration_dir); + let with_lib = full_migration_dir.join("lib.rs"); + match () { + _ if with_lib.is_file() => with_lib, + _ => full_migration_dir.join("mod.rs"), + } +} + fn create_new_migration(migration_name: &str, migration_dir: &str) -> Result<(), Box> { let migration_filepath = get_full_migration_dir(migration_dir).join(format!("{}.rs", &migration_name)); - println!("Creating migration file `{}`", migration_filepath.display()); - // TODO: make OS agnostic + println!( + "Creating migration file `{}`", + migration_filepath.display().to_string().dimmed() + ); let migration_template = include_str!("../../template/migration/src/m20220101_000001_create_table.rs"); let mut migration_file = fs::File::create(migration_filepath)?; @@ -188,43 +364,20 @@ fn create_new_migration(migration_name: &str, migration_dir: &str) -> Result<(), Ok(()) } -/// `get_migrator_filepath` looks for a file `migration_dir/src/lib.rs` -/// and returns that path if found. -/// -/// If `src` is not found, it will look directly in `migration_dir` for `lib.rs`. -/// -/// If `lib.rs` is not found, it will look for `mod.rs` instead, -/// e.g. `migration_dir/mod.rs`. -/// -/// This way, `src` doesn't need to be appended in the standard case where -/// migrations are in their own crate (with a file `lib.rs`). If the -/// migrations are in a submodule of another crate (with a file `mod.rs`), -/// `migration_dir` can point directly to that module. -fn get_migrator_filepath(migration_dir: &str) -> PathBuf { - let full_migration_dir = get_full_migration_dir(migration_dir); - let with_lib = full_migration_dir.join("lib.rs"); - match () { - _ if with_lib.is_file() => with_lib, - _ => full_migration_dir.join("mod.rs"), - } -} - fn update_migrator(migration_name: &str, migration_dir: &str) -> Result<(), Box> { let migrator_filepath = get_migrator_filepath(migration_dir); println!( "Adding migration `{}` to `{}`", - migration_name, - migrator_filepath.display() + migration_name.cyan(), + migrator_filepath.display().to_string().dimmed() ); let migrator_content = fs::read_to_string(&migrator_filepath)?; let mut updated_migrator_content = migrator_content.clone(); - // create a backup of the migrator file in case something goes wrong let migrator_backup_filepath = migrator_filepath.with_extension("rs.bak"); fs::copy(&migrator_filepath, &migrator_backup_filepath)?; let mut migrator_file = fs::File::create(&migrator_filepath)?; - // find existing mod declarations, add new line let mod_regex = Regex::new(r"mod\s+(?Pm\d{8}_\d{6}_\w+);")?; let mods: Vec<_> = mod_regex.captures_iter(&migrator_content).collect(); let mods_end = if let Some(last_match) = mods.last() { @@ -234,7 +387,6 @@ fn update_migrator(migration_name: &str, migration_dir: &str) -> Result<(), Box< }; updated_migrator_content.insert_str(mods_end, format!("mod {migration_name};\n").as_str()); - // build new vector from declared migration modules let mut migrations: Vec<&str> = mods .iter() .map(|cap| cap.name("name").unwrap().as_str()) diff --git a/sea-orm-cli/src/commands/mod.rs b/sea-orm-cli/src/commands/mod.rs index 4ef1dd2e64..9c96b35459 100644 --- a/sea-orm-cli/src/commands/mod.rs +++ b/sea-orm-cli/src/commands/mod.rs @@ -1,8 +1,13 @@ use std::fmt::Display; +use colored::Colorize; + #[cfg(feature = "codegen")] pub mod generate; +#[cfg(feature = "cli")] +pub mod entity; pub mod migrate; +pub mod subprocess; #[cfg(feature = "codegen")] pub use generate::*; @@ -12,6 +17,6 @@ pub fn handle_error(error: E) where E: Display, { - eprintln!("{error}"); + eprintln!("{} {error}", "Error:".red().bold()); ::std::process::exit(1); } diff --git a/sea-orm-cli/src/commands/subprocess.rs b/sea-orm-cli/src/commands/subprocess.rs new file mode 100644 index 0000000000..0f0867e0b8 --- /dev/null +++ b/sea-orm-cli/src/commands/subprocess.rs @@ -0,0 +1,230 @@ +//! Shared helper for invoking user crates via `cargo run` and parsing their +//! JSON API responses. + +use std::{ + error::Error, + process::{Command, Stdio}, +}; + +use serde::Deserialize; +use serde::de::DeserializeOwned; + +// --------------------------------------------------------------------------- +// Mirror of sea_orm_migration::response types +// We re-declare them here so sea-orm-cli does not need to depend on +// sea-orm-migration as a library — the contract is the JSON wire format. +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub struct ApiResponse { + pub ok: bool, + pub error: Option, + pub meta: ApiMeta, + pub data: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ApiMeta { + pub version: String, + pub migrations_hash: Option, + pub schema_hash: Option, +} + +// --- entity-first --- + +#[derive(Debug, Deserialize)] +pub struct DiffData { + pub changes: Vec, + pub statements: Vec, + pub warnings: Vec, + pub suggestions: Vec, + pub unresolved: Vec, + pub schema_hash: String, +} + +#[derive(Debug, Deserialize)] +pub struct GenerateData { + pub migration_name: String, + pub filepath: String, + pub changes: Vec, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct WarningJson { + pub kind: String, + pub message: String, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct SuggestionJson { + pub kind: String, + pub message: String, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct UnresolvedRenameJson { + pub table: String, + pub removed: String, + pub candidates: Vec, +} + +/// Output of `schema` — entity-defined schema as SQL DDL, no DB connection needed. +#[derive(Debug, Deserialize)] +pub struct SchemaData { + pub statements: Vec, +} + +// --- migration-first --- + +#[derive(Debug, Deserialize)] +pub struct StatusData { + pub migrations: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct MigrationEntry { + pub name: String, + pub status: String, +} + +#[derive(Debug, Deserialize)] +pub struct AppliedData { + pub applied: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct RolledBackData { + pub rolled_back: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct LifecycleData { + pub rolled_back: Vec, + pub applied: Vec, +} + +/// Build the manifest path from a crate root directory. +pub fn manifest_path(dir: &str) -> String { + if dir.ends_with('/') { + format!("{dir}Cargo.toml") + } else { + format!("{dir}/Cargo.toml") + } +} + +/// Error from calling a user crate subprocess. +#[derive(Debug)] +pub enum SubprocessError { + /// The cargo invocation itself failed (non-zero exit, or IO error). + Spawn(String), + /// The subprocess produced no output on stdout. + NoOutput, + /// stdout was not valid JSON. + InvalidJson(String), + /// The JSON parsed but `ok = false`. + ApiError(ApiMeta, String), + /// `ok = true` but `data` was null. + MissingData, + /// The API version returned by the subprocess does not match what we expect. + VersionMismatch { expected: String, got: String }, +} + +impl std::fmt::Display for SubprocessError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Spawn(e) => write!(f, "Failed to run subprocess: {e}"), + Self::NoOutput => write!(f, "Subprocess produced no output"), + Self::InvalidJson(e) => write!(f, "Invalid JSON from subprocess: {e}"), + Self::ApiError(_, msg) => write!(f, "API error: {msg}"), + Self::MissingData => write!(f, "API returned ok=true but no data"), + Self::VersionMismatch { expected, got } => write!( + f, + "Version mismatch: CLI expects {expected}, subprocess returned {got}. \ + Rebuild your crate with the matching sea-orm-migration version." + ), + } + } +} + +impl Error for SubprocessError {} + +const EXPECTED_VERSION: &str = env!("CARGO_PKG_VERSION"); + +/// Run `cargo run --manifest-path -- ` with the given env +/// vars, capture stdout, parse as `ApiResponse`, and return the data. +/// +/// Performs a version check on `meta.version` and returns +/// [`SubprocessError::VersionMismatch`] if they differ. +pub fn run_subprocess_json( + manifest: &str, + args: &[&str], + env_database_url: Option<&str>, + env_database_schema: Option<&str>, +) -> Result<(ApiMeta, T), SubprocessError> { + let mut cmd = Command::new("cargo"); + cmd.args(["run", "--manifest-path", manifest, "--"]); + cmd.args(args); + cmd.stdout(Stdio::piped()); + cmd.stderr(Stdio::inherit()); + + if let Some(url) = env_database_url { + cmd.env("DATABASE_URL", url); + } + if let Some(schema) = env_database_schema { + cmd.env("DATABASE_SCHEMA", schema); + } + + if true { + // Append to any existing RUSTFLAGS rather than clobbering them + let existing = std::env::var("RUSTFLAGS").unwrap_or_default(); + let new_flags = if existing.is_empty() { + "-A warnings".to_string() + } else { + format!("{existing} -A warnings") + }; + cmd.env("RUSTFLAGS", new_flags); + } + + let output = cmd + .output() + .map_err(|e| SubprocessError::Spawn(e.to_string()))?; + + let stdout = String::from_utf8_lossy(&output.stdout); + let line = stdout.lines().last().unwrap_or("").trim(); + + if line.is_empty() { + return Err(SubprocessError::NoOutput); + } + + // Parse the envelope first (without T) to get meta even on error + let raw: serde_json::Value = + serde_json::from_str(line).map_err(|e| SubprocessError::InvalidJson(e.to_string()))?; + + let meta_val = raw.get("meta").cloned().unwrap_or(serde_json::Value::Null); + let meta: ApiMeta = serde_json::from_value(meta_val) + .map_err(|e| SubprocessError::InvalidJson(e.to_string()))?; + + // Version check + if meta.version != EXPECTED_VERSION { + return Err(SubprocessError::VersionMismatch { + expected: EXPECTED_VERSION.to_string(), + got: meta.version, + }); + } + + let ok = raw.get("ok").and_then(|v| v.as_bool()).unwrap_or(false); + if !ok { + let error = raw + .get("error") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error") + .to_string(); + return Err(SubprocessError::ApiError(meta, error)); + } + + let data_val = raw.get("data").cloned().unwrap_or(serde_json::Value::Null); + let data: T = serde_json::from_value(data_val) + .map_err(|e| SubprocessError::InvalidJson(e.to_string()))?; + + Ok((meta, data)) +} diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index b94b9aa5e0..abcf4bf8e3 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -879,7 +879,10 @@ mod tests { use proc_macro2::TokenStream; use quote::quote; use sea_query::{Alias, ColumnType, ForeignKeyAction, RcOrArc, SeaRc, StringLen}; - use std::io::{self, BufRead, BufReader, Read}; + use std::{ + io::{self, BufRead, BufReader, Read}, + sync::Arc, + }; fn default_column_option() -> ColumnOption { Default::default() @@ -1576,6 +1579,106 @@ mod tests { name: "id".to_owned(), }], }, + Entity { + table_name: "imports".to_owned(), + columns: vec![ + Column { + name: "a".to_owned(), + col_type: ColumnType::Json, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "b".to_owned(), + col_type: ColumnType::Date, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "c".to_owned(), + col_type: ColumnType::Time, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "d".to_owned(), + col_type: ColumnType::DateTime, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "e".to_owned(), + col_type: ColumnType::TimestampWithTimeZone, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "f".to_owned(), + col_type: ColumnType::Decimal(None), + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "g".to_owned(), + col_type: ColumnType::Uuid, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "h".to_owned(), + col_type: ColumnType::Vector(None), + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "i".to_owned(), + col_type: ColumnType::Inet, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "j".to_owned(), + col_type: ColumnType::Array(Arc::new(ColumnType::Json)), + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "k".to_owned(), + col_type: ColumnType::Array(Arc::new(ColumnType::Array(Arc::new( + ColumnType::Cidr, + )))), + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + ], + relations: vec![], + conjunct_relations: vec![], + primary_keys: vec![PrimaryKey { + name: "a".to_owned(), + }], + }, ] } @@ -1618,7 +1721,7 @@ mod tests { #[test] fn test_gen_expanded_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 13] = [ + const ENTITY_FILES: [&str; 14] = [ include_str!("../../tests/expanded/cake.rs"), include_str!("../../tests/expanded/cake_filling.rs"), include_str!("../../tests/expanded/cake_filling_price.rs"), @@ -1632,8 +1735,9 @@ mod tests { include_str!("../../tests/expanded/collection_float.rs"), include_str!("../../tests/expanded/parent.rs"), include_str!("../../tests/expanded/child.rs"), + include_str!("../../tests/expanded/imports.rs"), ]; - const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 13] = [ + const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 14] = [ include_str!("../../tests/expanded_with_schema_name/cake.rs"), include_str!("../../tests/expanded_with_schema_name/cake_filling.rs"), include_str!("../../tests/expanded_with_schema_name/cake_filling_price.rs"), @@ -1647,6 +1751,7 @@ mod tests { include_str!("../../tests/expanded_with_schema_name/collection_float.rs"), include_str!("../../tests/expanded_with_schema_name/parent.rs"), include_str!("../../tests/expanded_with_schema_name/child.rs"), + include_str!("../../tests/expanded_with_schema_name/imports.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); @@ -1706,7 +1811,7 @@ mod tests { #[test] fn test_gen_compact_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 13] = [ + const ENTITY_FILES: [&str; 14] = [ include_str!("../../tests/compact/cake.rs"), include_str!("../../tests/compact/cake_filling.rs"), include_str!("../../tests/compact/cake_filling_price.rs"), @@ -1720,8 +1825,9 @@ mod tests { include_str!("../../tests/compact/collection_float.rs"), include_str!("../../tests/compact/parent.rs"), include_str!("../../tests/compact/child.rs"), + include_str!("../../tests/compact/imports.rs"), ]; - const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 13] = [ + const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 14] = [ include_str!("../../tests/compact_with_schema_name/cake.rs"), include_str!("../../tests/compact_with_schema_name/cake_filling.rs"), include_str!("../../tests/compact_with_schema_name/cake_filling_price.rs"), @@ -1735,6 +1841,7 @@ mod tests { include_str!("../../tests/compact_with_schema_name/collection_float.rs"), include_str!("../../tests/compact_with_schema_name/parent.rs"), include_str!("../../tests/compact_with_schema_name/child.rs"), + include_str!("../../tests/compact_with_schema_name/imports.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); @@ -1794,7 +1901,7 @@ mod tests { #[test] fn test_gen_frontend_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 13] = [ + const ENTITY_FILES: [&str; 14] = [ include_str!("../../tests/frontend/cake.rs"), include_str!("../../tests/frontend/cake_filling.rs"), include_str!("../../tests/frontend/cake_filling_price.rs"), @@ -1808,8 +1915,9 @@ mod tests { include_str!("../../tests/frontend/collection_float.rs"), include_str!("../../tests/frontend/parent.rs"), include_str!("../../tests/frontend/child.rs"), + include_str!("../../tests/frontend/imports.rs"), ]; - const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 13] = [ + const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 14] = [ include_str!("../../tests/frontend_with_schema_name/cake.rs"), include_str!("../../tests/frontend_with_schema_name/cake_filling.rs"), include_str!("../../tests/frontend_with_schema_name/cake_filling_price.rs"), @@ -1823,6 +1931,7 @@ mod tests { include_str!("../../tests/frontend_with_schema_name/collection_float.rs"), include_str!("../../tests/frontend_with_schema_name/parent.rs"), include_str!("../../tests/frontend_with_schema_name/child.rs"), + include_str!("../../tests/frontend_with_schema_name/imports.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); @@ -1879,6 +1988,35 @@ mod tests { Ok(()) } + #[test] + fn test_gen_frontend_imports() -> io::Result<()> { + let imports_entity = setup() + .into_iter() + .find(|e| e.get_table_name_snake_case() == "imports") + .unwrap(); + + assert_eq!(imports_entity.get_table_name_snake_case(), "imports"); + + assert_eq!( + comparable_file_string(include_str!("../../tests/frontend_with_imports/imports.rs"))?, + generated_to_string(EntityWriter::gen_frontend_code_blocks( + &imports_entity, + &WithSerde::None, + &default_column_option(), + &None, + true, + false, + &TokenStream::new(), + &TokenStream::new(), + &TokenStream::new(), + false, + true, + )) + ); + + Ok(()) + } + #[test] fn test_gen_with_serde() -> io::Result<()> { let cake_entity = setup().get(0).unwrap().clone(); @@ -3032,7 +3170,7 @@ mod tests { #[test] fn test_gen_dense_code_blocks() -> io::Result<()> { let entities = setup(); - const ENTITY_FILES: [&str; 13] = [ + const ENTITY_FILES: [&str; 14] = [ include_str!("../../tests/dense/cake.rs"), include_str!("../../tests/dense/cake_filling.rs"), include_str!("../../tests/dense/cake_filling_price.rs"), @@ -3046,6 +3184,7 @@ mod tests { include_str!("../../tests/dense/collection_float.rs"), include_str!("../../tests/dense/parent.rs"), include_str!("../../tests/dense/child.rs"), + include_str!("../../tests/dense/imports.rs"), ]; assert_eq!(entities.len(), ENTITY_FILES.len()); diff --git a/sea-orm-codegen/src/entity/writer/frontend.rs b/sea-orm-codegen/src/entity/writer/frontend.rs index 6b2fa2507c..7d262377df 100644 --- a/sea-orm-codegen/src/entity/writer/frontend.rs +++ b/sea-orm-codegen/src/entity/writer/frontend.rs @@ -1,5 +1,41 @@ +use sea_query::ColumnType; +use std::collections::HashSet; + use super::*; +// seperate enum so `ColumnType` doesnt need to derive `Hash` or `Eq` +#[derive(Hash, PartialEq, Eq)] +enum ExternalTypes { + JsonOrJsonBinary, + Date, + Time, + DateTime, + Timestamp, + TimestampWithTimeZone, + DecimalOrMoney, + Uuid, + Vector, + CidrOrInet, +} + +impl ExternalTypes { + fn from_column_type(col_type: &ColumnType) -> Option { + Some(match col_type { + ColumnType::Json | ColumnType::JsonBinary => Self::JsonOrJsonBinary, + ColumnType::Date => Self::Date, + ColumnType::Time => Self::Time, + ColumnType::DateTime => Self::DateTime, + ColumnType::Timestamp => Self::Timestamp, + ColumnType::TimestampWithTimeZone => Self::TimestampWithTimeZone, + ColumnType::Decimal(..) | ColumnType::Money(..) => Self::DecimalOrMoney, + ColumnType::Uuid => Self::Uuid, + ColumnType::Vector(..) => Self::Vector, + ColumnType::Cidr | ColumnType::Inet => Self::CidrOrInet, + _ => return None, + }) + } +} + impl EntityWriter { #[allow(clippy::too_many_arguments)] pub fn gen_frontend_code_blocks( @@ -17,6 +53,7 @@ impl EntityWriter { ) -> Vec { let mut imports = Self::gen_import_serde(with_serde); imports.extend(Self::gen_import_active_enum(entity)); + imports.extend(Self::gen_import_frontend(entity, column_option)); let code_blocks = vec![ imports, Self::gen_frontend_model_struct( @@ -77,4 +114,126 @@ impl EntityWriter { } } } + + pub fn gen_import_frontend(entity: &Entity, opt: &ColumnOption) -> TokenStream { + fn collect( + col_type: &ColumnType, + opt: &ColumnOption, + date_time: &mut Vec, + aliases: &mut Vec, + plain_uses: &mut Vec, + encountered: &mut HashSet, + ) { + // skip column types we have already generated imports for + if let Some(ty) = ExternalTypes::from_column_type(col_type) { + if !encountered.insert(ty) { + return; + } + } + + match col_type { + ColumnType::Json | ColumnType::JsonBinary => { + plain_uses.push(quote! { use serde_json::Value as Json; }); + } + ColumnType::Date => match opt.date_time_crate { + DateTimeCrate::Chrono => { + date_time.push(quote! { NaiveDate as Date }); + } + DateTimeCrate::Time => { + date_time.push(quote! { Date as TimeDate }); + } + }, + ColumnType::Time => match opt.date_time_crate { + DateTimeCrate::Chrono => { + date_time.push(quote! { NaiveTime as Time }); + } + DateTimeCrate::Time => { + date_time.push(quote! { Time as TimeTime }); + } + }, + ColumnType::DateTime => match opt.date_time_crate { + DateTimeCrate::Chrono => { + date_time.push(quote! { NaiveDateTime as DateTime }); + } + DateTimeCrate::Time => { + date_time.push(quote! { PrimitiveDateTime as TimeDateTime }); + } + }, + ColumnType::Timestamp => match opt.date_time_crate { + DateTimeCrate::Chrono => { + aliases.push(quote! { + type DateTimeUtc = chrono::DateTime; + }); + } + DateTimeCrate::Time => { + date_time.push(quote! { PrimitiveDateTime as TimeDateTime }); + } + }, + ColumnType::TimestampWithTimeZone => match opt.date_time_crate { + DateTimeCrate::Chrono => { + aliases.push(quote! { + type DateTimeWithTimeZone = chrono::DateTime; + }); + } + DateTimeCrate::Time => { + date_time.push(quote! { OffsetDateTime as TimeDateTimeWithTimeZone }); + } + }, + ColumnType::Decimal(_) | ColumnType::Money(_) => { + plain_uses.push(quote! { use rust_decimal::Decimal; }) + } + ColumnType::Uuid => { + plain_uses.push(quote! { use uuid::Uuid; }); + } + ColumnType::Vector(_) => { + plain_uses.push(quote! { use pgvector::Vector as PgVector; }); + } + ColumnType::Cidr | ColumnType::Inet => { + plain_uses.push(quote! { use ipnetwork::IpNetwork; }); + } + ColumnType::Array(inner) => { + collect( + inner.as_ref(), + opt, + date_time, + aliases, + plain_uses, + encountered, + ); + } + _ => {} + } + } + + let mut date_time_uses = Vec::new(); + let mut aliases = Vec::new(); + let mut plain_uses = Vec::new(); + let mut encountered = HashSet::new(); + + for col in &entity.columns { + collect( + &col.col_type, + opt, + &mut date_time_uses, + &mut aliases, + &mut plain_uses, + &mut encountered, + ); + } + + let time_use = if date_time_uses.is_empty() { + quote! {} + } else { + match opt.date_time_crate { + DateTimeCrate::Chrono => quote! { use chrono::{ #(#date_time_uses),* }; }, + DateTimeCrate::Time => quote! { use time::{ #(#date_time_uses),* }; }, + } + }; + + quote! { + #time_use + #(#plain_uses)* + #(#aliases)* + } + } } diff --git a/sea-orm-codegen/tests/compact/imports.rs b/sea-orm-codegen/tests/compact/imports.rs new file mode 100644 index 0000000000..fa6c356f2d --- /dev/null +++ b/sea-orm-codegen/tests/compact/imports.rs @@ -0,0 +1,25 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "imports")] +pub struct Model { + #[sea_orm(primary_key)] + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/compact_with_schema_name/imports.rs b/sea-orm-codegen/tests/compact_with_schema_name/imports.rs new file mode 100644 index 0000000000..93ff2ef927 --- /dev/null +++ b/sea-orm-codegen/tests/compact_with_schema_name/imports.rs @@ -0,0 +1,25 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(schema_name = "schema_name", table_name = "imports")] +pub struct Model { + #[sea_orm(primary_key)] + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/dense/imports.rs b/sea-orm-codegen/tests/dense/imports.rs new file mode 100644 index 0000000000..4f88bbe423 --- /dev/null +++ b/sea-orm-codegen/tests/dense/imports.rs @@ -0,0 +1,23 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[sea_orm::model] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "imports")] +pub struct Model { + #[sea_orm(primary_key)] + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/expanded/imports.rs b/sea-orm-codegen/tests/expanded/imports.rs new file mode 100644 index 0000000000..8f17a58e89 --- /dev/null +++ b/sea-orm-codegen/tests/expanded/imports.rs @@ -0,0 +1,87 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn table_name(&self) -> & 'static str { + "imports" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + A, + B, + C, + D, + E, + F, + G, + H, + I, + J, + K, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + A, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = Json; + fn auto_increment() -> bool { + true + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + fn def(&self) -> ColumnDef { + match self { + Self::A => ColumnType::Json.def(), + Self::B => ColumnType::Date.def(), + Self::C => ColumnType::Time.def(), + Self::D => ColumnType::DateTime.def(), + Self::E => ColumnType::TimestampWithTimeZone.def(), + Self::F => ColumnType::Decimal(None).def(), + Self::G => ColumnType::Uuid.def(), + Self::H => ColumnType::Vector(None).def(), + Self::I => ColumnType::Inet.def(), + Self::J => ColumnType::Array(RcOrArc::new(ColumnType::Json)).def(), + Self::K => ColumnType::Array(RcOrArc::new(ColumnType::Array(RcOrArc::new( + ColumnType::Cidr + )))) + .def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + panic!("No RelationDef") + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/expanded_with_schema_name/imports.rs b/sea-orm-codegen/tests/expanded_with_schema_name/imports.rs new file mode 100644 index 0000000000..b98d5b8426 --- /dev/null +++ b/sea-orm-codegen/tests/expanded_with_schema_name/imports.rs @@ -0,0 +1,90 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.10.0 + +use sea_orm::entity::prelude::*; + +#[derive(Copy, Clone, Default, Debug, DeriveEntity)] +pub struct Entity; + +impl EntityName for Entity { + fn schema_name(&self) -> Option< &str> { + Some("schema_name") + } + fn table_name(&self) -> & 'static str { + "imports" + } +} + +#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel)] +pub struct Model { + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] +pub enum Column { + A, + B, + C, + D, + E, + F, + G, + H, + I, + J, + K, +} + +#[derive(Copy, Clone, Debug, EnumIter, DerivePrimaryKey)] +pub enum PrimaryKey { + A, +} + +impl PrimaryKeyTrait for PrimaryKey { + type ValueType = Json; + fn auto_increment() -> bool { + true + } +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl ColumnTrait for Column { + type EntityName = Entity; + fn def(&self) -> ColumnDef { + match self { + Self::A => ColumnType::Json.def(), + Self::B => ColumnType::Date.def(), + Self::C => ColumnType::Time.def(), + Self::D => ColumnType::DateTime.def(), + Self::E => ColumnType::TimestampWithTimeZone.def(), + Self::F => ColumnType::Decimal(None).def(), + Self::G => ColumnType::Uuid.def(), + Self::H => ColumnType::Vector(None).def(), + Self::I => ColumnType::Inet.def(), + Self::J => ColumnType::Array(RcOrArc::new(ColumnType::Json)).def(), + Self::K => ColumnType::Array(RcOrArc::new(ColumnType::Array(RcOrArc::new( + ColumnType::Cidr + )))) + .def(), + } + } +} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + panic!("No RelationDef") + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/sea-orm-codegen/tests/frontend/imports.rs b/sea-orm-codegen/tests/frontend/imports.rs new file mode 100644 index 0000000000..7aad07095e --- /dev/null +++ b/sea-orm-codegen/tests/frontend/imports.rs @@ -0,0 +1,16 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +#[derive(Clone, Debug, PartialEq)] +pub struct Model { + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} diff --git a/sea-orm-codegen/tests/frontend_with_imports/imports.rs b/sea-orm-codegen/tests/frontend_with_imports/imports.rs new file mode 100644 index 0000000000..6f53966fe8 --- /dev/null +++ b/sea-orm-codegen/tests/frontend_with_imports/imports.rs @@ -0,0 +1,26 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +use chrono::{NaiveDate as Date, NaiveTime as Time, NaiveDateTime as DateTime}; +use serde_json::Value as Json; +use rust_decimal::Decimal; +use uuid::Uuid; +use pgvector::Vector as PgVector; +use ipnetwork::IpNetwork; +type DateTimeWithTimeZone = chrono::DateTime ; + +#[derive(Clone, Debug, PartialEq)] +pub struct Model { + #[serde(skip_deserializing)] + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} + diff --git a/sea-orm-codegen/tests/frontend_with_schema_name/imports.rs b/sea-orm-codegen/tests/frontend_with_schema_name/imports.rs new file mode 100644 index 0000000000..7aad07095e --- /dev/null +++ b/sea-orm-codegen/tests/frontend_with_schema_name/imports.rs @@ -0,0 +1,16 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.1.0 + +#[derive(Clone, Debug, PartialEq)] +pub struct Model { + pub a: Json, + pub b: Date, + pub c: Time, + pub d: DateTime, + pub e: DateTimeWithTimeZone, + pub f: Decimal, + pub g: Uuid, + pub h: PgVector, + pub i: IpNetwork, + pub j: Vec , + pub k: Vec> , +} diff --git a/sea-orm-codegen/tests/with_seaography/mod.rs b/sea-orm-codegen/tests/with_seaography/mod.rs index 59445e8d7a..0897f49a24 100644 --- a/sea-orm-codegen/tests/with_seaography/mod.rs +++ b/sea-orm-codegen/tests/with_seaography/mod.rs @@ -14,6 +14,7 @@ seaography::register_entity_modules!([ collection_float, parent, child, + imports, ]); seaography::register_active_enums!([ diff --git a/sea-orm-macros/src/derives/attributes.rs b/sea-orm-macros/src/derives/attributes.rs index 12d6b837ad..a5cab499af 100644 --- a/sea-orm-macros/src/derives/attributes.rs +++ b/sea-orm-macros/src/derives/attributes.rs @@ -77,7 +77,6 @@ pub mod value_type_attr { pub from_str: Option, pub to_str: Option, pub try_from_u64: Option<()>, - pub no_vec_impl: Option<()>, } } diff --git a/sea-orm-macros/src/derives/value_type.rs b/sea-orm-macros/src/derives/value_type.rs index 9033cdbcaf..f140eaeeae 100644 --- a/sea-orm-macros/src/derives/value_type.rs +++ b/sea-orm-macros/src/derives/value_type.rs @@ -1,5 +1,3 @@ -use crate::derives::value_type_match::omit_vec_impl; - use super::attributes::value_type_attr; use super::value_type_match::{array_type_expr, can_try_from_u64, column_type_expr}; use proc_macro2::TokenStream; @@ -17,8 +15,6 @@ struct DeriveValueTypeStruct { ty: Type, column_type: TokenStream, array_type: TokenStream, - /// Do not implement `sea_orm::TryGetableArray` for this type. Default: false. - no_vec_impl: bool, can_try_from_u64: bool, } @@ -27,7 +23,6 @@ struct DeriveValueTypeStructAttrs { column_type: Option, array_type: Option, try_from_u64: bool, - no_vec_impl: bool, } impl TryFrom for DeriveValueTypeStructAttrs { @@ -38,7 +33,6 @@ impl TryFrom for DeriveValueTypeStructAttrs { column_type: attrs.column_type.map(|s| s.parse()).transpose()?, array_type: attrs.array_type.map(|s| s.parse()).transpose()?, try_from_u64: attrs.try_from_u64.is_some(), - no_vec_impl: attrs.no_vec_impl.is_some(), }) } } @@ -158,14 +152,12 @@ impl DeriveValueTypeStruct { let column_type = column_type_expr(attrs.column_type, field_type, field_span); let array_type = array_type_expr(attrs.array_type, field_type, field_span); let can_try_from_u64 = attrs.try_from_u64 || can_try_from_u64(field_type); - let no_vec_impl = attrs.no_vec_impl || omit_vec_impl(field_type); Ok(Self { name, ty, column_type, array_type, - no_vec_impl, can_try_from_u64, }) } @@ -194,25 +186,6 @@ impl DeriveValueTypeStruct { quote!() }; - let impl_try_getable_array = if cfg!(feature = "postgres-array") && !self.no_vec_impl { - quote!( - #[automatically_derived] - impl sea_orm::TryGetableArray for #name { - fn try_get_by( - res: &sea_orm::QueryResult, - index: I, - ) -> std::result::Result, sea_orm::TryGetError> { - Ok( as sea_orm::TryGetable>::try_get_by(res, index)? - .into_iter() - .map(|value| Self(value)) - .collect()) - } - } - ) - } else { - quote!() - }; - let impl_not_u8 = if cfg!(feature = "postgres-array") { quote!( #[automatically_derived] @@ -226,7 +199,6 @@ impl DeriveValueTypeStruct { #[automatically_derived] impl std::convert::From<#name> for sea_orm::Value { fn from(source: #name) -> Self { - println!("Struct"); source.0.into() } } @@ -272,8 +244,6 @@ impl DeriveValueTypeStruct { } } - #impl_try_getable_array - #try_from_u64_impl #impl_not_u8 @@ -306,36 +276,6 @@ impl DeriveValueTypeString { None => "e!(String(sea_orm::sea_query::StringLen::None)), }; - let impl_try_getable_array = if cfg!(feature = "postgres-array") { - quote!( - #[automatically_derived] - impl sea_orm::TryGetableArray for #name { - fn try_get_by( - res: &sea_orm::QueryResult, - index: I, - ) -> std::result::Result, sea_orm::TryGetError> { - let mut result = Vec::new(); - for string in as sea_orm::TryGetable>::try_get_by(res, index)?.into_iter() { - result.push(#from_str(&string) - .map_err(|err| - { - sea_orm::TryGetError::DbErr( - sea_orm::DbErr::TryIntoErr { - from: "String", - into: stringify!(#name), - source: std::sync::Arc::new(err), - }) - } - )?); - } - Ok(result) - } - } - ) - } else { - quote!() - }; - let impl_not_u8 = if cfg!(feature = "postgres-array") { quote!( #[automatically_derived] @@ -403,8 +343,6 @@ impl DeriveValueTypeString { } #impl_not_u8 - - #impl_try_getable_array ) } } diff --git a/sea-orm-macros/src/derives/value_type_match.rs b/sea-orm-macros/src/derives/value_type_match.rs index 26360838f5..cfba21d052 100644 --- a/sea-orm-macros/src/derives/value_type_match.rs +++ b/sea-orm-macros/src/derives/value_type_match.rs @@ -151,29 +151,6 @@ pub fn can_try_from_u64(field_type: &str) -> bool { ) } -/// Maximum depth of vector nesting allowed INSIDE NEW TYPE before omitting sea_orm::TryGetableArray -/// For example, `struct A (Vec>)` has dimensionality of 2 -/// Abosolute maximum would be 5, because of Postgres limit of 6 -const MAX_VEC_DIMENSIONALITY: u8 = 0; - -/// Determines whether to omit `sea_orm::TryGetableArray` implementation for a given field type -/// based on the vector dimensionality. -pub fn omit_vec_impl(field_type: &str) -> bool { - let mut depth = 0u8; - let mut current = field_type.trim(); - - while let Some(inner) = current.strip_prefix("Vec<") { - #[allow(clippy::absurd_extreme_comparisons)] - if depth >= MAX_VEC_DIMENSIONALITY { - return true; - } - depth += 1; - current = inner.trim_start(); - } - - false -} - /// Return whether it is nullable fn trim_option(s: &str) -> (bool, &str) { if s.starts_with("Option<") { diff --git a/sea-orm-migration/Cargo.toml b/sea-orm-migration/Cargo.toml index 6dbc69e38e..3efb6dd165 100644 --- a/sea-orm-migration/Cargo.toml +++ b/sea-orm-migration/Cargo.toml @@ -1,5 +1,5 @@ -[workspace] -# A separate workspace +# [workspace] +# # A separate workspace [package] authors = ["Billy Chan "] @@ -21,29 +21,34 @@ path = "src/lib.rs" [dependencies] async-trait = { version = "0.1", default-features = false } +chrono = { version = "0.4", default-features = false, features = ["clock"], optional = true } clap = { version = "4.3", features = ["env", "derive"], optional = true } +colored = "2" dotenvy = { version = "0.15", default-features = false, optional = true } -sea-orm = { version = "~2.0.0-rc.38", path = "../", features = [ - "schema-sync", -] } -sea-orm-cli = { version = "~2.0.0-rc.38", path = "../sea-orm-cli", default-features = false, optional = true } +glob = { version = "0.3", default-features = false } +regex = { version = "1", optional = true } +serde = { version = "1", features = ["derive"] } +serde_json = { version = "1" } +sea-orm = { version = "~2.0.0-rc.38", path = "../", features = ["schema-sync",] } sea-schema = { version = "0.17.0-rc", default-features = false, features = [ "discovery", "writer", "probe", ] } tracing = { version = "0.1", default-features = false, features = ["log"] } -tracing-subscriber = { version = "0.3.17", default-features = false, features = [ - "env-filter", - "fmt", -] } [dev-dependencies] +tempfile = "3" tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +[[test]] +name = "entity_first" +required-features = ["entity-first", "sqlx-sqlite", "runtime-tokio-native-tls"] + [features] -cli = ["clap", "dotenvy", "sea-orm-cli/cli"] +cli = ["clap", "dotenvy", "chrono", "regex"] default = ["cli"] +entity-first = ["chrono", "clap", "dotenvy", "regex"] entity-registry = ["sea-orm/entity-registry"] runtime-async-std = [ "sea-orm/runtime-async-std", @@ -92,6 +97,3 @@ with-json = ["sea-orm/with-json"] with-rust_decimal = ["sea-orm/with-rust_decimal"] with-time = ["sea-orm/with-time"] with-uuid = ["sea-orm/with-uuid"] - -[patch.crates-io] -# sea-query = { path = "../sea-query" } diff --git a/sea-orm-migration/src/cli.rs b/sea-orm-migration/src/cli.rs index 6222009c21..d98b60c896 100644 --- a/sea-orm-migration/src/cli.rs +++ b/sea-orm-migration/src/cli.rs @@ -2,15 +2,16 @@ use std::future::Future; use clap::Parser; use dotenvy::dotenv; -use std::{error::Error, fmt::Display, process::exit}; -use tracing_subscriber::{EnvFilter, prelude::*}; +use std::process::exit; use sea_orm::{ConnectOptions, Database, DbConn, DbErr}; use sea_orm_cli::{MigrateSubcommands, run_migrate_generate, run_migrate_init}; +use crate::response::{ApiMeta, ApiResponse}; use super::MigratorTraitSelf; const MIGRATION_DIR: &str = "./"; +const VERSION: &str = env!("CARGO_PKG_VERSION"); pub async fn run_cli(migrator: M) where @@ -20,10 +21,6 @@ where } /// Same as [`run_cli`] where you provide the function to create the [`DbConn`]. -/// -/// This allows configuring the database connection as you see fit. -/// E.g. you can change settings in [`ConnectOptions`] or you can load sqlite -/// extensions. pub async fn run_cli_with_connection(migrator: M, make_connection: F) where M: MigratorTraitSelf, @@ -42,73 +39,106 @@ where .set_schema_search_path(schema) .to_owned(); - let db = make_connection(connect_options) - .await - .expect("Fail to acquire database connection"); + let db = match make_connection(connect_options).await { + Ok(db) => db, + Err(e) => { + let meta = migrator_meta(&migrator); + emit_err::<()>(meta, e); + exit(1); + } + }; - run_migrate(migrator, &db, cli.command, cli.verbose) - .await - .unwrap_or_else(handle_error); + run_migrate(migrator, &db, cli.command).await; } -pub async fn run_migrate( - migrator: M, - db: &DbConn, - command: Option, - verbose: bool, -) -> Result<(), Box> +pub async fn run_migrate(migrator: M, db: &DbConn, command: Option) where M: MigratorTraitSelf, { - let filter = match verbose { - true => "debug", - false => "sea_orm_migration=info", - }; - - let filter_layer = EnvFilter::try_new(filter).unwrap(); - - if verbose { - let fmt_layer = tracing_subscriber::fmt::layer(); - tracing_subscriber::registry() - .with(filter_layer) - .with(fmt_layer) - .init() - } else { - let fmt_layer = tracing_subscriber::fmt::layer() - .with_target(false) - .with_level(false) - .without_time(); - tracing_subscriber::registry() - .with(filter_layer) - .with(fmt_layer) - .init() - }; + let meta = migrator_meta(&migrator); match command { - Some(MigrateSubcommands::Fresh) => migrator.fresh(db).await?, - Some(MigrateSubcommands::Refresh) => migrator.refresh(db).await?, - Some(MigrateSubcommands::Reset) => migrator.reset(db).await?, - Some(MigrateSubcommands::Status) => migrator.status(db).await?, - Some(MigrateSubcommands::Up { num }) => migrator.up(db, num).await?, - Some(MigrateSubcommands::Down { num }) => migrator.down(db, Some(num)).await?, - Some(MigrateSubcommands::Init) => run_migrate_init(MIGRATION_DIR)?, - Some(MigrateSubcommands::Generate { - migration_name, - universal_time: _, - local_time, - }) => run_migrate_generate(MIGRATION_DIR, &migration_name, !local_time)?, - _ => migrator.up(db, None).await?, - }; + Some(MigrateSubcommands::Status) => { + match migrator.status(db).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Up { num }) => { + match migrator.up(db, num).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Down { num }) => { + match migrator.down(db, Some(num)).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Fresh) => { + match migrator.fresh(db).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Refresh) => { + match migrator.refresh(db).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Reset) => { + match migrator.reset(db).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Init) => { + match run_migrate_init(MIGRATION_DIR) { + Ok(()) => { + #[derive(serde::Serialize)] + struct InitData { migration_dir: &'static str } + println!("{}", serde_json::to_string(&ApiResponse::ok(meta, InitData { migration_dir: MIGRATION_DIR })).unwrap()); + } + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + Some(MigrateSubcommands::Generate { migration_name, universal_time: _, local_time }) => { + match run_migrate_generate(MIGRATION_DIR, &migration_name, !local_time) { + Ok(()) => { + #[derive(serde::Serialize)] + struct GenData<'a> { migration_name: &'a str, migration_dir: &'static str } + println!("{}", serde_json::to_string(&ApiResponse::ok(meta, GenData { migration_name: &migration_name, migration_dir: MIGRATION_DIR })).unwrap()); + } + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + // No subcommand: apply all pending migrations + None => { + match migrator.up(db, None).await { + Ok(data) => println!("{}", serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap()), + Err(e) => { emit_err::<()>(meta, e); exit(1); } + } + } + } +} - Ok(()) +fn migrator_meta(migrator: &M) -> ApiMeta { + ApiMeta { + version: VERSION.to_string(), + migrations_hash: Some(migrator.migrations_hash()), + schema_hash: None, + } +} + +fn emit_err(meta: ApiMeta, error: impl std::fmt::Display) { + println!("{}", serde_json::to_string(&ApiResponse::::err(meta, error.to_string())).unwrap()); } #[derive(Parser)] #[command(version)] pub struct Cli { - #[arg(short = 'v', long, global = true, help = "Show debug messages")] - verbose: bool, - #[arg( global = true, short = 's', @@ -132,11 +162,3 @@ pub struct Cli { #[command(subcommand)] command: Option, } - -fn handle_error(error: E) -where - E: Display, -{ - eprintln!("{error}"); - exit(1); -} diff --git a/sea-orm-migration/src/codegen.rs b/sea-orm-migration/src/codegen.rs new file mode 100644 index 0000000000..c30c50bdb5 --- /dev/null +++ b/sea-orm-migration/src/codegen.rs @@ -0,0 +1,66 @@ +use sea_orm::Statement; + +pub struct MigrationMetadata<'a> { + pub version: &'a str, + pub generated_at: &'a str, + pub backend: &'a str, + pub changes: &'a [String], +} + +/// Render a complete migration `.rs` file from a list of SQL statements. +pub fn render_migration_file(stmts: &[Statement], meta: &MigrationMetadata<'_>) -> String { + let comment_header = render_comment_header(meta); + let up_body = render_up_body(stmts); + + format!( + r#"{comment_header} +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration {{ + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {{ + let db = manager.get_connection(); +{up_body} + Ok(()) + }} + + async fn down(&self, _manager: &SchemaManager) -> Result<(), DbErr> {{ + // TODO: implement down migration + todo!() + }} +}} +"# + ) +} + +/// Comments the header of the migration file, including version, timestamp, backend, and changes. +fn render_comment_header(meta: &MigrationMetadata<'_>) -> String { + let mut lines = vec![ + format!("// Generated by sea-orm-entity v{}", meta.version), + format!("// Generated at: {}", meta.generated_at), + format!("// Backend: {}", meta.backend), + "// Changes:".to_string(), + ]; + lines.extend(meta.changes.iter().map(|c| format!("// - {c}"))); + lines.join("\n") +} + +/// Renders the body of the up migration +fn render_up_body(stmts: &[Statement]) -> String { + // Just in case. I don't want to deal with this rn + stmts + .iter() + .map(|stmt| { + if stmt.values.is_some() { + let sql = stmt.sql.replace('\\', r"\\").replace('"', r#"\""#); + format!(r#" db.execute_unprepared("{sql}").await?;"#) + } else { + format!(" db.execute_unprepared(r#\"{}\"#).await?;", stmt.sql) + } + }) + .collect::>() + .join("\n") +} diff --git a/sea-orm-migration/src/entity_cli.rs b/sea-orm-migration/src/entity_cli.rs new file mode 100644 index 0000000000..f418714e26 --- /dev/null +++ b/sea-orm-migration/src/entity_cli.rs @@ -0,0 +1,670 @@ +use chrono::Utc; +use clap::{Parser, Subcommand}; +use dotenvy::dotenv; +use sea_orm::{ConnectOptions, Database, DbBackend, InterpretConfig, Schema, interpret_changes}; + +use crate::{ + EntitySet, MigratorTraitSelf, + codegen::MigrationMetadata, + fs::write_migration, + response::{ + ApiMeta, ApiResponse, DiffData, GenerateData, SchemaData, SuggestionJson, + UnresolvedRenameJson, WarningJson, fnv64_hex, + }, + summary::summarize, +}; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +#[derive(Parser)] +#[command( + name = "entity", + about = "Entity-first migration tool for SeaORM", + version +)] +struct Cli { + #[arg( + global = true, + short = 'u', + long, + env = "DATABASE_URL", + help = "Database URL" + )] + database_url: Option, + + #[arg( + global = true, + short = 's', + long, + env = "DATABASE_SCHEMA", + long_help = "Database schema\n \ + - For MySQL and SQLite, this argument is ignored.\n \ + - For PostgreSQL, this argument is optional with default value 'public'.\n" + )] + database_schema: Option, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand)] +enum Commands { + /// Discover schema changes between entity definitions and the live database. + /// + /// Returns JSON with discovered SQL statements, warnings, suggestions, and + /// any ambiguous renames that must be resolved before calling `generate`. + /// Never writes any files. + Diff { + /// Allow dangerous operations (e.g. dropping tables) in the diff + #[arg(long, default_value_t = true)] + allow_dangerous: bool, + }, + + /// Generate a migration file from entity definitions. + /// + /// Requires that `diff` was run first. Pass the `schema_hash` from the diff + /// output via `--schema-hash` so stale calls are rejected. All ambiguous + /// renames reported by `diff` must be resolved via `--rename` flags. + Generate { + /// Path to the migration crate directory + #[arg(long, default_value = "../migration")] + migration_dir: String, + + /// Name for the migration (e.g. `add_users`) + #[arg(required = true)] + name: String, + + /// Schema hash from the preceding `diff` output — used to detect staleness + #[arg(long, required = true)] + schema_hash: String, + + #[arg( + long, + default_value = "true", + help = "Generate migration file based on Utc time", + conflicts_with = "local_time", + display_order = 1001 + )] + universal_time: bool, + + #[arg( + long, + help = "Generate migration file based on Local time", + conflicts_with = "universal_time", + display_order = 1002 + )] + local_time: bool, + + /// Allow dangerous operations (must match the value used in `diff`) + #[arg(long, default_value_t = true)] + allow_dangerous: bool, + + /// Resolve an ambiguous rename in the format TABLE.OLD_COL:NEW_COL + #[arg(long = "rename", value_name = "TABLE.OLD:NEW")] + renames: Vec, + }, + + /// Preview the schema as defined by the registered entities, as SQL DDL statements. + /// + /// Does not connect to any database. Returns a JSON object with a `statements` array + /// of CREATE TABLE / CREATE TYPE / CREATE INDEX SQL strings, rendered for the + /// target database backend (specified via `--database-backend`). + Schema { + /// Database backend to render SQL for (postgres, mysql, sqlite) + #[arg(long, default_value = "postgres")] + database_backend: String, + }, + + #[command( + about = "Drop all tables from the database, then reapply all migrations", + display_order = 30 + )] + Fresh, + #[command( + about = "Rollback all applied migrations, then reapply all migrations", + display_order = 40 + )] + Refresh, + #[command(about = "Rollback all applied migrations", display_order = 50)] + Reset, + #[command(about = "Check the status of all migrations", display_order = 60)] + Status, + #[command(about = "Apply pending migrations", display_order = 70)] + Up { + #[arg(short, long, help = "Number of pending migrations to apply")] + num: Option, + }, + #[command(about = "Rollback applied migrations", display_order = 80)] + Down { + #[arg( + short, + long, + default_value = "1", + help = "Number of applied migrations to roll back", + display_order = 90 + )] + num: u32, + }, +} + +/// Run the entity-first CLI with the given entity set and migrator. +/// +/// Call this from your entity crate's `main.rs`: +/// +/// ```rust,ignore +/// #[tokio::main] +/// async fn main() { +/// sea_orm_migration::entity_cli::run_cli(Entities, migration::Migrator).await; +/// } +/// ``` +pub async fn run_cli(entity_set: E, migrator: M) +where + E: EntitySet, + M: MigratorTraitSelf, +{ + dotenv().ok(); + let cli = Cli::parse(); + + // Handle commands that don't need a DB connection first. + if let Some(Commands::Schema { database_backend }) = cli.command { + let meta = build_meta(&migrator, None); + match run_schema(entity_set, &database_backend) { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + } + return; + } + + let url = cli + .database_url + .expect("Environment variable 'DATABASE_URL' not set"); + let schema_path = cli.database_schema.unwrap_or_else(|| "public".to_owned()); + + let db = match Database::connect( + ConnectOptions::new(url) + .set_schema_search_path(schema_path) + .to_owned(), + ) + .await + { + Ok(db) => db, + Err(e) => { + let meta = build_meta(&migrator, None); + emit_err::<()>(meta, e); + std::process::exit(1); + } + }; + + let migration_table = migrator.migration_table_name().to_string(); + + match cli.command { + Some(Commands::Schema { .. }) => unreachable!("handled above"), + Some(Commands::Diff { allow_dangerous }) => { + let meta = build_meta(&migrator, None); + let pending = match migrator.get_pending_migrations(&db).await { + Ok(p) => p, + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }; + if !pending.is_empty() { + let names: Vec = pending.iter().map(|m| m.name().to_owned()).collect(); + emit_err::<()>( + meta, + format!( + "{} pending migration(s) must be applied before running diff:\n {}\nRun `migrate up` first.", + names.len(), + names.join("\n ") + ), + ); + std::process::exit(1); + } + match run_diff(entity_set, &db, allow_dangerous, &migration_table).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + } + } + + Some(Commands::Generate { + migration_dir, + name, + schema_hash, + local_time, + universal_time: _, + allow_dangerous, + renames, + }) => { + let meta = build_meta(&migrator, None); + match run_generate( + entity_set, + &db, + &migration_dir, + &name, + &schema_hash, + local_time, + allow_dangerous, + &renames, + &migration_table, + ) + .await + { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + } + } + + migrate_cmd => { + let meta = build_meta(&migrator, None); + match migrate_cmd { + Some(Commands::Up { num }) => match migrator.up(&db, num).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Down { num }) => match migrator.down(&db, Some(num)).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Fresh) => match migrator.fresh(&db).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Refresh) => match migrator.refresh(&db).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Reset) => match migrator.reset(&db).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Status) => match migrator.status(&db).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + None => match migrator.up(&db, None).await { + Ok(data) => println!( + "{}", + serde_json::to_string(&ApiResponse::ok(meta, data)).unwrap() + ), + Err(e) => { + emit_err::<()>(meta, e); + std::process::exit(1); + } + }, + Some(Commands::Diff { .. }) + | Some(Commands::Generate { .. }) + | Some(Commands::Schema { .. }) => unreachable!(), + } + } + } +} + +/// Discover schema changes. Never writes anything. +async fn run_diff( + entity_set: E, + db: &sea_orm::DatabaseConnection, + dangerous: bool, + protected_table: &str, +) -> Result> { + let backend = db.get_database_backend(); + let schema = Schema::new(backend); + let builder = entity_set + .register(schema.builder()) + .exclude(protected_table); + + let change_set = builder.discover(db, dangerous).await?; + let result = interpret_changes( + change_set, + &InterpretConfig { + db_backend: backend, + assumptions: true, + allow_dangerous: dangerous, + }, + ); + + let statements: Vec = result + .statements + .iter() + .map(|(_, s)| s.sql.clone()) + .collect(); + let schema_hash = fnv64_hex(statements.iter().map(String::as_str)); + let changes = summarize( + &result + .statements + .iter() + .map(|(_, s)| s.clone()) + .collect::>(), + ); + + let warnings = result + .warnings + .iter() + .map(|w| WarningJson { + kind: format!("{:?}", w.kind), + message: w.message.clone(), + }) + .collect(); + + let suggestions = result + .suggestions + .iter() + .map(|s| SuggestionJson { + kind: format!("{:?}", s.kind), + message: s.message.clone(), + }) + .collect(); + + let unresolved = result + .unresolved + .iter() + .map(|u| UnresolvedRenameJson { + table: u.table.clone(), + removed: u.removed.clone(), + candidates: u.candidates.iter().map(|c| c.added.clone()).collect(), + }) + .collect(); + + Ok(DiffData { + changes, + statements, + warnings, + suggestions, + unresolved, + schema_hash, + }) +} + +/// Build the entity-defined schema as SQL DDL without connecting to a database. +fn run_schema( + entity_set: E, + database_backend: &str, +) -> Result> { + let backend = match database_backend { + "postgres" | "postgresql" => DbBackend::Postgres, + "mysql" => DbBackend::MySql, + "sqlite" => DbBackend::Sqlite, + other => { + return Err(format!( + "Unknown database backend: {other}. Use postgres, mysql, or sqlite." + ) + .into()); + } + }; + let schema = Schema::new(backend); + let builder = entity_set.register(schema.builder()); + let statements = builder + .schema_statements() + .into_iter() + .map(|s| s.sql) + .collect(); + Ok(SchemaData { statements }) +} + +/// Generate and write a migration file. +async fn run_generate( + entity_set: E, + db: &sea_orm::DatabaseConnection, + migration_dir: &str, + name: &str, + expected_schema_hash: &str, + local_time: bool, + dangerous: bool, + renames: &[String], + protected_table: &str, +) -> Result> { + if name.contains('-') { + return Err("`-` cannot be used in migration name".into()); + } + + let backend = db.get_database_backend(); + let schema = Schema::new(backend); + let builder = entity_set + .register(schema.builder()) + .exclude(protected_table); + + let change_set = builder.discover(db, dangerous).await?; + let mut result = interpret_changes( + change_set, + &InterpretConfig { + db_backend: backend, + assumptions: true, + allow_dangerous: dangerous, + }, + ); + + // Validate schema hash before proceeding + let current_stmts: Vec = result + .statements + .iter() + .map(|(_, s)| s.sql.clone()) + .collect(); + let current_hash = fnv64_hex(current_stmts.iter().map(String::as_str)); + if current_hash != expected_schema_hash { + return Err(format!( + "Schema hash mismatch: expected {expected_schema_hash}, got {current_hash}. \ + Re-run `diff` to get a fresh schema hash." + ) + .into()); + } + + // Error if there are still unresolved renames + if !result.unresolved.is_empty() { + // Try to apply provided --rename flags + let cli_renames: Vec<(String, String, String)> = + renames.iter().filter_map(|s| parse_rename_arg(s)).collect(); + + let decisions = resolve_renames(&result.unresolved, &cli_renames)?; + result.apply_rename_decisions(&decisions, backend); + } else if !renames.is_empty() { + // --rename flags provided but nothing to resolve — harmless, ignore + } + + // After applying decisions, check if any unresolved remain + if !result.unresolved.is_empty() { + let remaining: Vec = result + .unresolved + .iter() + .map(|u| { + let candidates = u + .candidates + .iter() + .map(|c| c.added.as_str()) + .collect::>() + .join(", "); + format!("{}.{} -> [{}]", u.table, u.removed, candidates) + }) + .collect(); + return Err(format!( + "Unresolved ambiguous renames remain. Provide --rename for each:\n {}", + remaining.join("\n ") + ) + .into()); + } + + let stmts: Vec<_> = result.statements.into_iter().map(|(_, s)| s).collect(); + + if stmts.is_empty() { + return Err("No schema changes detected. Migration file not generated.".into()); + } + + let (timestamp, generated_at) = if local_time { + let now = chrono::Local::now(); + ( + now.format("%Y%m%d_%H%M%S").to_string(), + now.format("%Y-%m-%d %H:%M:%S %Z").to_string(), + ) + } else { + let now = Utc::now(); + ( + now.format("%Y%m%d_%H%M%S").to_string(), + now.format("%Y-%m-%d %H:%M:%S UTC").to_string(), + ) + }; + + let name_clean = name.trim().replace(' ', "_"); + let migration_name = format!("m{timestamp}_{name_clean}"); + let backend_name = match backend { + DbBackend::MySql => "MySQL", + DbBackend::Postgres => "PostgreSQL", + DbBackend::Sqlite => "SQLite", + _ => "Unknown", + }; + + let changes = summarize(&stmts); + let meta = MigrationMetadata { + version: VERSION, + generated_at: &generated_at, + backend: backend_name, + changes: &changes, + }; + + let filepath = write_migration(migration_dir, &migration_name, &stmts, &meta)?; + + Ok(GenerateData { + migration_name, + filepath: filepath.display().to_string(), + changes, + }) +} + +/// Apply `--rename` CLI overrides to the unresolved list, returning decisions. +/// Errors if any ambiguity is left unresolved. +fn resolve_renames( + unresolved: &[sea_orm::schema::resolver::AmbiguousRename], + cli_renames: &[(String, String, String)], +) -> Result, Box> { + use sea_orm::schema::RenameDecision; + + let mut decisions = Vec::new(); + let mut missing = Vec::new(); + + for ambiguous in unresolved { + if let Some((_, _, new_name)) = cli_renames + .iter() + .find(|(table, old, _)| *table == ambiguous.table && *old == ambiguous.removed) + { + if ambiguous.candidates.iter().any(|c| c.added == *new_name) { + decisions.push(RenameDecision::Rename { + from: ambiguous.removed.clone(), + to: new_name.clone(), + }); + } else { + return Err(format!( + "--rename {}.{}:{} is invalid: '{}' is not among candidates [{}]", + ambiguous.table, + ambiguous.removed, + new_name, + new_name, + ambiguous + .candidates + .iter() + .map(|c| c.added.as_str()) + .collect::>() + .join(", ") + ) + .into()); + } + } else { + missing.push(format!( + "{}.{} (candidates: {})", + ambiguous.table, + ambiguous.removed, + ambiguous + .candidates + .iter() + .map(|c| c.added.as_str()) + .collect::>() + .join(", ") + )); + } + } + + if !missing.is_empty() { + return Err(format!( + "Ambiguous renames require --rename flags:\n {}", + missing.join("\n ") + ) + .into()); + } + + Ok(decisions) +} + +/// Parse a `--rename TABLE.OLD:NEW` string into `(table, old, new)`. +fn parse_rename_arg(s: &str) -> Option<(String, String, String)> { + let (table_old, new) = s.split_once(':')?; + let (table, old) = table_old.split_once('.')?; + if table.is_empty() || old.is_empty() || new.is_empty() { + return None; + } + Some((table.to_string(), old.to_string(), new.to_string())) +} + +fn build_meta(migrator: &M, schema_hash: Option) -> ApiMeta { + ApiMeta { + version: VERSION.to_string(), + migrations_hash: Some(migrator.migrations_hash()), + schema_hash, + } +} + +fn emit_err(meta: ApiMeta, error: impl std::fmt::Display) { + println!( + "{}", + serde_json::to_string(&ApiResponse::::err(meta, error.to_string())).unwrap() + ); +} diff --git a/sea-orm-migration/src/fs.rs b/sea-orm-migration/src/fs.rs new file mode 100644 index 0000000000..1137ba9773 --- /dev/null +++ b/sea-orm-migration/src/fs.rs @@ -0,0 +1,98 @@ +use regex::Regex; +use sea_orm::Statement; +use std::{error::Error, fs, path::PathBuf}; + +use crate::codegen::{MigrationMetadata, render_migration_file}; + +pub fn write_migration( + migration_dir: &str, + migration_name: &str, + stmts: &[Statement], + meta: &MigrationMetadata<'_>, +) -> Result> { + let filepath = write_migration_file(migration_dir, migration_name, stmts, meta)?; + update_migrator(migration_dir, migration_name)?; + Ok(filepath) +} + +fn get_full_migration_dir(migration_dir: &str) -> PathBuf { + let base = PathBuf::from(migration_dir); + let with_src = base.join("src"); + if with_src.is_dir() { with_src } else { base } +} + +fn get_migrator_filepath(migration_dir: &str) -> PathBuf { + let full_dir = get_full_migration_dir(migration_dir); + let with_lib = full_dir.join("lib.rs"); + if with_lib.is_file() { + with_lib + } else { + full_dir.join("mod.rs") + } +} + +fn write_migration_file( + migration_dir: &str, + migration_name: &str, + stmts: &[Statement], + meta: &MigrationMetadata<'_>, +) -> Result> { + let filepath = get_full_migration_dir(migration_dir).join(format!("{migration_name}.rs")); + println!("Creating migration file `{}`", filepath.display()); + let content = render_migration_file(stmts, meta); + fs::write(&filepath, content.as_bytes())?; + Ok(filepath) +} + +fn update_migrator(migration_dir: &str, migration_name: &str) -> Result<(), Box> { + let migrator_filepath = get_migrator_filepath(migration_dir); + println!( + "Adding migration `{migration_name}` to `{}`", + migrator_filepath.display() + ); + let original = fs::read_to_string(&migrator_filepath)?; + + // Find existing mod declarations and get insertion index for a new one + let mod_regex = Regex::new(r"mod\s+(?Pm\d{8}_\d{6}_\w+);")?; + let mods: Vec<_> = mod_regex.captures_iter(&original).collect(); + let insert_pos = if let Some(last_match) = mods.last() { + last_match.get(0).unwrap().end() + 1 + } else { + // Insert at the beginning of the file (before `pub struct Migrator`) + original.find("pub struct").unwrap_or(original.len()) + }; + + // Insert the new mod declaration. + let new_mod_decl = if mods.is_empty() { + //When inserting before the struct, add a blank line to look nicer + format!("mod {migration_name};\n\n") + } else { + format!("mod {migration_name};\n") + }; + let mut updated = original.clone(); + updated.insert_str(insert_pos, &new_mod_decl); + + // Rebuild the migrations vec + let mut migrations: Vec<&str> = mods + .iter() + .map(|cap| cap.name("name").unwrap().as_str()) + .collect(); + migrations.push(migration_name); + let boxed = migrations + .iter() + .map(|m| format!(" Box::new({m}::Migration),")) + .collect::>() + .join("\n") + + "\n"; + let new_vec = format!("vec![\n{boxed} ]"); + + // Match both empty vec![] and vec![...] + let vec_regex = Regex::new(r"vec!\[[\s\S]*?\]")?; + let updated = vec_regex.replace(&updated, new_vec.as_str()); + + // write to a temp file beside the target, then rename + let tmp_path = migrator_filepath.with_extension("rs.tmp"); + fs::write(&tmp_path, updated.as_bytes())?; + fs::rename(&tmp_path, &migrator_filepath)?; + Ok(()) +} diff --git a/sea-orm-migration/src/lib.rs b/sea-orm-migration/src/lib.rs index 4834393c97..47ae72f935 100644 --- a/sea-orm-migration/src/lib.rs +++ b/sea-orm-migration/src/lib.rs @@ -4,10 +4,20 @@ pub mod connection; pub mod manager; pub mod migrator; pub mod prelude; +pub mod response; pub mod schema; pub mod seaql_migrations; pub mod util; +#[cfg(feature = "entity-first")] +pub mod codegen; +#[cfg(feature = "entity-first")] +pub mod entity_cli; +#[cfg(feature = "entity-first")] +pub mod fs; +#[cfg(feature = "entity-first")] +pub mod summary; + pub use connection::*; pub use manager::*; pub use migrator::*; @@ -17,6 +27,29 @@ pub use sea_orm; pub use sea_orm::DbErr; pub use sea_orm::sea_query; +#[cfg(feature = "entity-first")] +pub use sea_orm::schema::SchemaBuilder; + +/// Trait for a set of entities to be registered into a [`SchemaBuilder`]. +/// +/// Implement this on a unit struct in your entity crate: +/// +/// ```rust,ignore +/// pub struct Entities; +/// +/// impl sea_orm_migration::EntitySet for Entities { +/// fn register(self, builder: sea_orm_migration::SchemaBuilder) -> sea_orm_migration::SchemaBuilder { +/// builder +/// .register(user::Entity) +/// .register(post::Entity) +/// } +/// } +/// ``` +#[cfg(feature = "entity-first")] +pub trait EntitySet { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder; +} + pub trait MigrationName { fn name(&self) -> &str; } diff --git a/sea-orm-migration/src/migrator.rs b/sea-orm-migration/src/migrator.rs index bdae6073d0..39e9d2a2a4 100644 --- a/sea-orm-migration/src/migrator.rs +++ b/sea-orm-migration/src/migrator.rs @@ -7,9 +7,9 @@ mod with_self; pub use with_self::*; use std::fmt::Display; -use tracing::info; use super::{IntoSchemaManagerConnection, MigrationTrait, SchemaManager, seaql_migrations}; +use crate::response::{AppliedData, LifecycleData, MigrationEntry, RolledBackData, StatusData}; use sea_orm::sea_query::IntoIden; use sea_orm::{ConnectionTrait, DbErr, DynIden}; @@ -127,23 +127,24 @@ pub trait MigratorTrait: Send { } /// Check the status of all migrations - async fn status(db: &C) -> Result<(), DbErr> + async fn status(db: &C) -> Result where C: ConnectionTrait, { Self::install(db).await?; - - info!("Checking migration status"); - - for Migration { migration, status } in Self::get_migration_with_status(db).await? { - info!("Migration '{}'... {}", migration.name(), status); - } - - Ok(()) + let migrations = Self::get_migration_with_status(db) + .await? + .into_iter() + .map(|m| MigrationEntry { + name: m.migration.name().to_owned(), + status: m.status.to_string(), + }) + .collect(); + Ok(StatusData { migrations }) } /// Drop all tables from the database, then reapply all migrations - async fn fresh<'c, C>(db: C) -> Result<(), DbErr> + async fn fresh<'c, C>(db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { @@ -153,25 +154,30 @@ pub trait MigratorTrait: Send { } /// Rollback all applied migrations, then reapply all migrations - async fn refresh<'c, C>(db: C) -> Result<(), DbErr> + async fn refresh<'c, C>(db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down::(&manager, None).await?; - exec_up::(&manager, None).await + let rolled_back = exec_down::(&manager, None).await?; + let applied = exec_up::(&manager, None).await?; + Ok::<_, DbErr>(LifecycleData { + rolled_back, + applied, + }) } /// Rollback all applied migrations - async fn reset<'c, C>(db: C) -> Result<(), DbErr> + async fn reset<'c, C>(db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down::(&manager, None).await?; - uninstall(&manager, Self::migration_table_name()).await + let rolled_back = exec_down::(&manager, None).await?; + uninstall(&manager, Self::migration_table_name()).await?; + Ok::<_, DbErr>(RolledBackData { rolled_back }) } /// Uninstall migration tracking table only (non-destructive) @@ -186,47 +192,45 @@ pub trait MigratorTrait: Send { } /// Apply pending migrations - async fn up<'c, C>(db: C, steps: Option) -> Result<(), DbErr> + async fn up<'c, C>(db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_up::(&manager, steps).await + let applied = exec_up::(&manager, steps).await?; + Ok(AppliedData { applied }) } /// Rollback applied migrations - async fn down<'c, C>(db: C, steps: Option) -> Result<(), DbErr> + async fn down<'c, C>(db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down::(&manager, steps).await + let rolled_back = exec_down::(&manager, steps).await?; + Ok(RolledBackData { rolled_back }) } } -async fn exec_fresh(manager: &SchemaManager<'_>) -> Result<(), DbErr> +async fn exec_fresh(manager: &SchemaManager<'_>) -> Result where M: MigratorTrait + ?Sized, { let db = manager.get_connection(); - M::install(db).await?; - drop_everything(db).await?; - - exec_up::(manager, None).await + let applied = exec_up::(manager, None).await?; + Ok(AppliedData { applied }) } -async fn exec_up(manager: &SchemaManager<'_>, steps: Option) -> Result<(), DbErr> +async fn exec_up(manager: &SchemaManager<'_>, steps: Option) -> Result, DbErr> where M: MigratorTrait + ?Sized, { let db = manager.get_connection(); - M::install(db).await?; - exec_up_with( manager, steps, @@ -236,14 +240,12 @@ where .await } -async fn exec_down(manager: &SchemaManager<'_>, steps: Option) -> Result<(), DbErr> +async fn exec_down(manager: &SchemaManager<'_>, steps: Option) -> Result, DbErr> where M: MigratorTrait + ?Sized, { let db = manager.get_connection(); - M::install(db).await?; - exec_down_with( manager, steps, diff --git a/sea-orm-migration/src/migrator/exec.rs b/sea-orm-migration/src/migrator/exec.rs index 8fdde4208a..76f9a689ab 100644 --- a/sea-orm-migration/src/migrator/exec.rs +++ b/sea-orm-migration/src/migrator/exec.rs @@ -1,7 +1,6 @@ use std::collections::HashSet; #[cfg(not(feature = "with-time"))] use std::time::SystemTime; -use tracing::info; use super::{Migration, MigrationStatus, queries::*}; use crate::{SchemaManager, seaql_migrations}; @@ -106,76 +105,55 @@ pub async fn drop_everything(db: &C) -> R async fn drop_everything_impl(db: &C) -> Result<(), DbErr> { let db_backend = db.get_database_backend(); - // Temporarily disable the foreign key check if db_backend == DbBackend::Sqlite { - info!("Disabling foreign key check"); db.execute_raw(Statement::from_string( db_backend, "PRAGMA foreign_keys = OFF".to_owned(), )) .await?; - info!("Foreign key check disabled"); } - // Drop all foreign keys if db_backend == DbBackend::MySql { - info!("Dropping all foreign keys"); let stmt = query_mysql_foreign_keys(db); let rows = db.query_all(&stmt).await?; for row in rows.into_iter() { let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?; let table_name: String = row.try_get("", "TABLE_NAME")?; - info!( - "Dropping foreign key '{}' from table '{}'", - constraint_name, table_name - ); let mut stmt = ForeignKey::drop(); stmt.table(Alias::new(table_name.as_str())) .name(constraint_name.as_str()); db.execute(&stmt).await?; - info!("Foreign key '{}' has been dropped", constraint_name); } - info!("All foreign keys dropped"); } - // Drop all tables let stmt = query_tables(db)?; let rows = db.query_all(&stmt).await?; for row in rows.into_iter() { let table_name: String = row.try_get("", "table_name")?; - info!("Dropping table '{}'", table_name); let mut stmt = Table::drop(); stmt.table(Alias::new(table_name.as_str())) .if_exists() .cascade(); db.execute(&stmt).await?; - info!("Table '{}' has been dropped", table_name); } - // Drop all types if db_backend == DbBackend::Postgres { - info!("Dropping all types"); let stmt = query_pg_types(db); let rows = db.query_all(&stmt).await?; for row in rows { let type_name: String = row.try_get("", "typname")?; - info!("Dropping type '{}'", type_name); let mut stmt = Type::drop(); stmt.name(Alias::new(&type_name)); db.execute(&stmt).await?; - info!("Type '{}' has been dropped", type_name); } } - // Restore the foreign key check if db_backend == DbBackend::Sqlite { - info!("Restoring foreign key check"); db.execute_raw(Statement::from_string( db_backend, "PRAGMA foreign_keys = ON".to_owned(), )) .await?; - info!("Foreign key check restored"); } Ok(()) @@ -228,17 +206,9 @@ pub async fn exec_up_with( mut steps: Option, pending_migrations: Vec, migration_table_name: DynIden, -) -> Result<(), DbErr> { +) -> Result, DbErr> { let db = manager.get_connection(); - - if let Some(steps) = steps { - info!("Applying {} pending migrations", steps); - } else { - info!("Applying all pending migrations"); - } - if pending_migrations.is_empty() { - info!("No pending migrations"); - } + let mut applied = Vec::new(); for Migration { migration, .. } in pending_migrations { if let Some(steps) = steps.as_mut() { @@ -249,24 +219,22 @@ pub async fn exec_up_with( } let use_txn = should_use_transaction(migration.as_ref(), db.get_database_backend()); - info!("Applying migration '{}'", migration.name()); if use_txn { let transaction = db.begin().await?; let txn_manager = SchemaManager::new(&transaction); migration.up(&txn_manager).await?; - info!("Migration '{}' has been applied", migration.name()); insert_migration_record(&transaction, migration.name(), migration_table_name.clone()) .await?; transaction.commit().await?; } else { migration.up(manager).await?; - info!("Migration '{}' has been applied", migration.name()); insert_migration_record(db, migration.name(), migration_table_name.clone()).await?; } + applied.push(migration.name().to_owned()); } - Ok(()) + Ok(applied) } pub async fn exec_down_with( @@ -274,17 +242,9 @@ pub async fn exec_down_with( mut steps: Option, applied_migrations: Vec, migration_table_name: DynIden, -) -> Result<(), DbErr> { +) -> Result, DbErr> { let db = manager.get_connection(); - - if let Some(steps) = steps { - info!("Rolling back {} applied migrations", steps); - } else { - info!("Rolling back all applied migrations"); - } - if applied_migrations.is_empty() { - info!("No applied migrations"); - } + let mut rolled_back = Vec::new(); for Migration { migration, .. } in applied_migrations.into_iter().rev() { if let Some(steps) = steps.as_mut() { @@ -295,22 +255,20 @@ pub async fn exec_down_with( } let use_txn = should_use_transaction(migration.as_ref(), db.get_database_backend()); - info!("Rolling back migration '{}'", migration.name()); if use_txn { let transaction = db.begin().await?; let txn_manager = SchemaManager::new(&transaction); migration.down(&txn_manager).await?; - info!("Migration '{}' has been rolled back", migration.name()); delete_migration_record(&transaction, migration.name(), migration_table_name.clone()) .await?; transaction.commit().await?; } else { migration.down(manager).await?; - info!("Migration '{}' has been rolled back", migration.name()); delete_migration_record(db, migration.name(), migration_table_name.clone()).await?; } + rolled_back.push(migration.name().to_owned()); } - Ok(()) + Ok(rolled_back) } diff --git a/sea-orm-migration/src/migrator/with_self.rs b/sea-orm-migration/src/migrator/with_self.rs index d679bb0f85..8e440405d0 100644 --- a/sea-orm-migration/src/migrator/with_self.rs +++ b/sea-orm-migration/src/migrator/with_self.rs @@ -1,10 +1,12 @@ use super::{Migration, MigrationStatus, exec::*}; -use crate::{IntoSchemaManagerConnection, MigrationTrait, SchemaManager, seaql_migrations}; +use crate::{ + IntoSchemaManagerConnection, MigrationTrait, SchemaManager, + response::{AppliedData, LifecycleData, MigrationEntry, RolledBackData, StatusData, fnv64_hex}, + seaql_migrations, +}; use sea_orm::sea_query::IntoIden; use sea_orm::{ConnectionTrait, DbErr, DynIden}; -use tracing::info; - /// Performing migrations on a database #[async_trait::async_trait] pub trait MigratorTraitSelf: Sized + Send + Sync { @@ -16,6 +18,18 @@ pub trait MigratorTraitSelf: Sized + Send + Sync { seaql_migrations::Entity.into_iden() } + /// FNV64 hex digest of the ordered list of migration names. + /// Used as the `migrations_hash` in JSON responses so callers can detect + /// binary/config drift. + fn migrations_hash(&self) -> String { + let names: Vec = self + .migrations() + .iter() + .map(|m| m.name().to_owned()) + .collect(); + fnv64_hex(names.iter().map(String::as_str)) + } + /// Get list of migrations wrapped in `Migration` struct fn get_migration_files(&self) -> Vec { self.migrations() @@ -85,23 +99,25 @@ pub trait MigratorTraitSelf: Sized + Send + Sync { } /// Check the status of all migrations - async fn status(&self, db: &C) -> Result<(), DbErr> + async fn status(&self, db: &C) -> Result where C: ConnectionTrait, { self.install(db).await?; - - info!("Checking migration status"); - - for Migration { migration, status } in self.get_migration_with_status(db).await? { - info!("Migration '{}'... {}", migration.name(), status); - } - - Ok(()) + let migrations = self + .get_migration_with_status(db) + .await? + .into_iter() + .map(|m| MigrationEntry { + name: m.migration.name().to_owned(), + status: m.status.to_string(), + }) + .collect(); + Ok(StatusData { migrations }) } /// Drop all tables from the database, then reapply all migrations - async fn fresh<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn fresh<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { @@ -111,29 +127,33 @@ pub trait MigratorTraitSelf: Sized + Send + Sync { } /// Rollback all applied migrations, then reapply all migrations - async fn refresh<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn refresh<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down(self, &manager, None).await?; - exec_up(self, &manager, None).await + let rolled_back = exec_down(self, &manager, None).await?; + let applied = exec_up(self, &manager, None).await?; + Ok::<_, DbErr>(LifecycleData { + rolled_back, + applied, + }) } /// Rollback all applied migrations - async fn reset<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn reset<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down(self, &manager, None).await?; - uninstall(&manager, self.migration_table_name()).await + let rolled_back = exec_down(self, &manager, None).await?; + uninstall(&manager, self.migration_table_name()).await?; + Ok::<_, DbErr>(RolledBackData { rolled_back }) } /// Uninstall migration tracking table only (non-destructive) - /// This will drop the `seaql_migrations` table but won't rollback other schema changes. async fn uninstall<'c, C>(&self, db: C) -> Result<(), DbErr> where C: IntoSchemaManagerConnection<'c>, @@ -144,23 +164,25 @@ pub trait MigratorTraitSelf: Sized + Send + Sync { } /// Apply pending migrations - async fn up<'c, C>(&self, db: C, steps: Option) -> Result<(), DbErr> + async fn up<'c, C>(&self, db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_up(self, &manager, steps).await + let applied = exec_up(self, &manager, steps).await?; + Ok(AppliedData { applied }) } /// Rollback applied migrations - async fn down<'c, C>(&self, db: C, steps: Option) -> Result<(), DbErr> + async fn down<'c, C>(&self, db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { let db = db.into_database_executor(); let manager = SchemaManager::new(db); - exec_down(self, &manager, steps).await + let rolled_back = exec_down(self, &manager, steps).await?; + Ok(RolledBackData { rolled_back }) } } @@ -216,33 +238,56 @@ where M::install(db).await } - /// Check the status of all migrations - async fn status(&self, db: &C) -> Result<(), DbErr> + async fn status(&self, db: &C) -> Result where C: ConnectionTrait, { - M::status(db).await + self.install(db).await?; + let migrations = self + .get_migration_with_status(db) + .await? + .into_iter() + .map(|m| MigrationEntry { + name: m.migration.name().to_owned(), + status: m.status.to_string(), + }) + .collect(); + Ok(StatusData { migrations }) } - async fn fresh<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn fresh<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { - M::fresh(db).await + let db = db.into_database_executor(); + let manager = SchemaManager::new(db); + let applied = exec_fresh(self, &manager).await?; + Ok::<_, DbErr>(applied) } - async fn refresh<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn refresh<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { - M::refresh(db).await + let db = db.into_database_executor(); + let manager = SchemaManager::new(db); + let rolled_back = exec_down(self, &manager, None).await?; + let applied = exec_up(self, &manager, None).await?; + Ok::<_, DbErr>(LifecycleData { + rolled_back, + applied, + }) } - async fn reset<'c, C>(&self, db: C) -> Result<(), DbErr> + async fn reset<'c, C>(&self, db: C) -> Result where C: IntoSchemaManagerConnection<'c>, { - M::reset(db).await + let db = db.into_database_executor(); + let manager = SchemaManager::new(db); + let rolled_back = exec_down(self, &manager, None).await?; + uninstall(&manager, self.migration_table_name()).await?; + Ok::<_, DbErr>(RolledBackData { rolled_back }) } async fn uninstall<'c, C>(&self, db: C) -> Result<(), DbErr> @@ -252,46 +297,48 @@ where M::uninstall(db).await } - async fn up<'c, C>(&self, db: C, steps: Option) -> Result<(), DbErr> + async fn up<'c, C>(&self, db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { - M::up(db, steps).await + let db = db.into_database_executor(); + let manager = SchemaManager::new(db); + let applied = exec_up(self, &manager, steps).await?; + Ok(AppliedData { applied }) } - async fn down<'c, C>(&self, db: C, steps: Option) -> Result<(), DbErr> + async fn down<'c, C>(&self, db: C, steps: Option) -> Result where C: IntoSchemaManagerConnection<'c>, { - M::down(db, steps).await + let db = db.into_database_executor(); + let manager = SchemaManager::new(db); + let rolled_back = exec_down(self, &manager, steps).await?; + Ok(RolledBackData { rolled_back }) } } -async fn exec_fresh(migrator: &M, manager: &SchemaManager<'_>) -> Result<(), DbErr> +async fn exec_fresh(migrator: &M, manager: &SchemaManager<'_>) -> Result where M: MigratorTraitSelf, { let db = manager.get_connection(); - migrator.install(db).await?; - drop_everything(db).await?; - - exec_up(migrator, manager, None).await + let applied = exec_up(migrator, manager, None).await?; + Ok(AppliedData { applied }) } async fn exec_up( migrator: &M, manager: &SchemaManager<'_>, steps: Option, -) -> Result<(), DbErr> +) -> Result, DbErr> where M: MigratorTraitSelf, { let db = manager.get_connection(); - migrator.install(db).await?; - exec_up_with( manager, steps, @@ -305,14 +352,12 @@ async fn exec_down( migrator: &M, manager: &SchemaManager<'_>, steps: Option, -) -> Result<(), DbErr> +) -> Result, DbErr> where M: MigratorTraitSelf, { let db = manager.get_connection(); - migrator.install(db).await?; - exec_down_with( manager, steps, diff --git a/sea-orm-migration/src/response.rs b/sea-orm-migration/src/response.rs new file mode 100644 index 0000000000..7112ae752e --- /dev/null +++ b/sea-orm-migration/src/response.rs @@ -0,0 +1,171 @@ +//! JSON response types for the sea-orm-migration machine API. +//! +//! Every command writes exactly one JSON object to stdout. The envelope is +//! [`ApiResponse`] which carries [`ApiMeta`] for version/sync tracking plus a +//! command-specific `data` payload. On error, `data` is `null` and `error` +//! contains the message. + +use serde::{Deserialize, Serialize}; + +// --------------------------------------------------------------------------- +// Envelope +// --------------------------------------------------------------------------- + +/// Emitted for every command. Serialized as a single JSON line to stdout. +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiResponse { + pub ok: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + pub meta: ApiMeta, + pub data: Option, +} + +impl ApiResponse { + pub fn ok(meta: ApiMeta, data: T) -> Self { + Self { + ok: true, + error: None, + meta, + data: Some(data), + } + } + + pub fn err(meta: ApiMeta, error: impl Into) -> ApiResponse { + ApiResponse { + ok: false, + error: Some(error.into()), + meta, + data: None, + } + } +} + +// --------------------------------------------------------------------------- +// Meta +// --------------------------------------------------------------------------- + +/// Versioning and sync-tracking fields present in every response. +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiMeta { + /// Semver of the sea-orm-migration crate that produced this response. + pub version: String, + + /// For migration-first commands: FNV64 hex digest of the sorted list of + /// migration names registered in the binary. Changes whenever migrations + /// are added or removed. + #[serde(skip_serializing_if = "Option::is_none")] + pub migrations_hash: Option, + + /// For entity-first commands: FNV64 hex digest of the SQL statements + /// produced by the entity set's schema builder (backend-independent + /// representation). Changes whenever entity definitions change. + #[serde(skip_serializing_if = "Option::is_none")] + pub schema_hash: Option, +} + +/// Compute a stable hex string from an iterator of string slices. +/// Uses FNV-1a 64-bit — no extra dependency, deterministic, fast. +pub fn fnv64_hex<'a>(items: impl Iterator) -> String { + const OFFSET: u64 = 0xcbf29ce484222325; + const PRIME: u64 = 0x00000100000001b3; + let mut hash = OFFSET; + for item in items { + for byte in item.bytes() { + hash ^= byte as u64; + hash = hash.wrapping_mul(PRIME); + } + // Separator so ["ab","c"] != ["a","bc"] + hash ^= 0xff; + hash = hash.wrapping_mul(PRIME); + } + format!("{hash:016x}") +} + +// --------------------------------------------------------------------------- +// Entity-first data types +// --------------------------------------------------------------------------- + +/// Output of `diff` — discovered schema changes, never writes anything. +#[derive(Debug, Serialize, Deserialize)] +pub struct DiffData { + /// Human-readable summaries of each SQL statement that would be generated. + pub changes: Vec, + /// Raw SQL statements that would be applied. + pub statements: Vec, + /// Always-on warnings requiring manual attention. + pub warnings: Vec, + /// Heuristic suggestions (renames etc.). + pub suggestions: Vec, + /// Ambiguous renames that the caller must resolve before calling `generate`. + pub unresolved: Vec, + /// FNV64 hex digest of the discovered SQL — must be passed back to `generate` + /// unchanged so stale calls are rejected. + pub schema_hash: String, +} + +/// Output of `schema` — entity-defined schema as SQL DDL, no DB connection needed. +#[derive(Debug, Serialize, Deserialize)] +pub struct SchemaData { + /// SQL DDL statements for all registered entities (CREATE TABLE, CREATE TYPE, CREATE INDEX). + pub statements: Vec, +} + +/// Output of `generate` — writes a migration file and updates `lib.rs`. +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateData { + pub migration_name: String, + pub filepath: String, + pub changes: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct WarningJson { + pub kind: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SuggestionJson { + pub kind: String, + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UnresolvedRenameJson { + pub table: String, + pub removed: String, + pub candidates: Vec, +} + +// --------------------------------------------------------------------------- +// Migration-first data types +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize, Deserialize)] +pub struct MigrationEntry { + pub name: String, + pub status: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusData { + pub migrations: Vec, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct AppliedData { + pub applied: Vec, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct RolledBackData { + pub rolled_back: Vec, +} + +/// Used for fresh/refresh/reset where we report both applied and rolled-back. +#[derive(Debug, Serialize, Deserialize)] +pub struct LifecycleData { + pub rolled_back: Vec, + pub applied: Vec, +} diff --git a/sea-orm-migration/src/summary.rs b/sea-orm-migration/src/summary.rs new file mode 100644 index 0000000000..e83d2242b5 --- /dev/null +++ b/sea-orm-migration/src/summary.rs @@ -0,0 +1,122 @@ +use sea_orm::Statement; + +/// Parse a list of SQL statements into hman-readable descriptions +pub fn summarize(stmts: &[Statement]) -> Vec { + stmts.iter().map(|s| describe(&s.sql)).collect() +} + +fn describe(sql: &str) -> String { + let upper = sql.to_uppercase(); + let sql = sql.trim(); + + if upper.contains("CREATE TABLE") { + if let Some(name) = extract_after(sql, &upper, "CREATE TABLE", Some("IF NOT EXISTS")) { + return format!("Created table: {name}"); + } + } + + if upper.contains("ALTER TABLE") { + let table = extract_after(sql, &upper, "ALTER TABLE", None); + if upper.contains("ADD COLUMN") { + if let (Some(table), Some(col)) = ( + table.as_ref(), + extract_after(sql, &upper, "ADD COLUMN", None), + ) { + return format!("Added column: {table}.{col}"); + } + } else if upper.contains("RENAME COLUMN") { + if let Some(table) = table { + return format!("Renamed column on: {table}"); + } + } else if upper.contains("DROP COLUMN") { + if let (Some(table), Some(col)) = ( + table.as_ref(), + extract_after(sql, &upper, "DROP COLUMN", None), + ) { + return format!("Dropped column: {table}.{col}"); + } + } else if upper.contains("ADD CONSTRAINT") { + if let Some(table) = table { + return format!("Added foreign key on: {table}"); + } + } else if upper.contains("DROP CONSTRAINT") { + if let Some(table) = table { + return format!("Dropped constraint on: {table}"); + } + } else if upper.contains("DROP FOREIGN KEY") { + if let Some(table) = table { + return format!("Dropped foreign key on: {table}"); + } + } + } + + if upper.contains("ALTER TYPE") && upper.contains("ADD VALUE") { + return "Added enum variant".to_string(); + } + + if upper.contains("DROP TABLE") { + if let Some(name) = extract_after(sql, &upper, "DROP TABLE", Some("IF EXISTS")) { + return format!("Dropped table: {name}"); + } + } + + if upper.contains("CREATE INDEX") || upper.contains("CREATE UNIQUE INDEX") { + if let Some(pos) = upper.find(" ON ") { + let after = sql[pos + " ON ".len()..].trim_start(); + let table = extract_identifier(after); + let kind = if upper.contains("UNIQUE") { + "unique index" + } else { + "index" + }; + return format!("Added {kind} on: {table}"); + } + } + + if upper.contains("CREATE TYPE") { + return "Created enum type".to_string(); + } + + // Fallback: first 80 chars of SQL + if sql.len() > 80 { + format!("SQL: {}...", &sql[..80]) + } else { + format!("SQL: {sql}") + } +} + +fn extract_after(sql: &str, upper: &str, keyword: &str, skip: Option<&str>) -> Option { + let pos = upper.find(keyword)?; + let rest = sql[pos + keyword.len()..].trim_start(); + let rest_upper = &upper[pos + keyword.len()..]; + let rest_upper = rest_upper.trim_start(); + let rest = if let Some(skip) = skip { + if rest_upper.starts_with(skip) { + rest[skip.len()..].trim_start() + } else { + rest + } + } else { + rest + }; + Some(extract_identifier(rest)) +} + +fn extract_identifier(s: &str) -> String { + let s = s.trim(); + if s.starts_with('"') { + // Double-quoted identifier + let end = s[1..].find('"').unwrap_or(s.len() - 1); + s[1..end + 1].to_string() + } else if s.starts_with('`') { + // Backtick-quoted identifier (MySQL) + let end = s[1..].find('`').unwrap_or(s.len() - 1); + s[1..end + 1].to_string() + } else { + // Unquoted: take until whitespace or `(` + s.split(|c: char| c.is_whitespace() || c == '(') + .next() + .unwrap_or(s) + .to_string() + } +} diff --git a/sea-orm-migration/template/migration/README.md b/sea-orm-migration/template/migration/README.md new file mode 100644 index 0000000000..3b438d89e3 --- /dev/null +++ b/sea-orm-migration/template/migration/README.md @@ -0,0 +1,41 @@ +# Running Migrator CLI + +- Generate a new migration file + ```sh + cargo run -- generate MIGRATION_NAME + ``` +- Apply all pending migrations + ```sh + cargo run + ``` + ```sh + cargo run -- up + ``` +- Apply first 10 pending migrations + ```sh + cargo run -- up -n 10 + ``` +- Rollback last applied migrations + ```sh + cargo run -- down + ``` +- Rollback last 10 applied migrations + ```sh + cargo run -- down -n 10 + ``` +- Drop all tables from the database, then reapply all migrations + ```sh + cargo run -- fresh + ``` +- Rollback all applied migrations, then reapply all migrations + ```sh + cargo run -- refresh + ``` +- Rollback all applied migrations + ```sh + cargo run -- reset + ``` +- Check the status of all migrations + ```sh + cargo run -- status + ``` diff --git a/sea-orm-migration/template/migration/_Cargo.toml b/sea-orm-migration/template/migration/_Cargo.toml new file mode 100644 index 0000000000..57797a39dc --- /dev/null +++ b/sea-orm-migration/template/migration/_Cargo.toml @@ -0,0 +1,23 @@ +[package] +edition = "2024" +name = "migration" +publish = false +rust-version = "1.85.0" +version = "0.1.0" + +[lib] +name = "migration" +path = "src/lib.rs" + +[dependencies] +tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] } + +[dependencies.sea-orm-migration] +features = [ + # Enable at least one `ASYNC_RUNTIME` and `DATABASE_DRIVER` feature if you want to run migration via CLI. + # View the list of supported features at https://www.sea-ql.org/SeaORM/docs/install-and-config/database-and-async-runtime. + # e.g. + # "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature + # "sqlx-postgres", # `DATABASE_DRIVER` feature +] +version = "" diff --git a/sea-orm-migration/template/migration/_gitignore b/sea-orm-migration/template/migration/_gitignore new file mode 100644 index 0000000000..c41cc9e35e --- /dev/null +++ b/sea-orm-migration/template/migration/_gitignore @@ -0,0 +1 @@ +/target \ No newline at end of file diff --git a/sea-orm-migration/template/migration/src/lib.rs b/sea-orm-migration/template/migration/src/lib.rs new file mode 100644 index 0000000000..2c605afb94 --- /dev/null +++ b/sea-orm-migration/template/migration/src/lib.rs @@ -0,0 +1,12 @@ +pub use sea_orm_migration::prelude::*; + +mod m20220101_000001_create_table; + +pub struct Migrator; + +#[async_trait::async_trait] +impl MigratorTrait for Migrator { + fn migrations() -> Vec> { + vec![Box::new(m20220101_000001_create_table::Migration)] + } +} diff --git a/sea-orm-migration/template/migration/src/m20220101_000001_create_table.rs b/sea-orm-migration/template/migration/src/m20220101_000001_create_table.rs new file mode 100644 index 0000000000..280e98489b --- /dev/null +++ b/sea-orm-migration/template/migration/src/m20220101_000001_create_table.rs @@ -0,0 +1,33 @@ +use sea_orm_migration::{prelude::*, schema::*}; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + // Replace the sample below with your own migration scripts + todo!(); + + manager + .create_table( + Table::create() + .table("post") + .if_not_exists() + .col(pk_auto("id")) + .col(string("title")) + .col(string("text")) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + // Replace the sample below with your own migration scripts + todo!(); + + manager + .drop_table(Table::drop().table("post").to_owned()) + .await + } +} diff --git a/sea-orm-migration/template/migration/src/main.rs b/sea-orm-migration/template/migration/src/main.rs new file mode 100644 index 0000000000..f054deaf88 --- /dev/null +++ b/sea-orm-migration/template/migration/src/main.rs @@ -0,0 +1,6 @@ +use sea_orm_migration::prelude::*; + +#[tokio::main] +async fn main() { + cli::run_cli(migration::Migrator).await; +} diff --git a/sea-orm-migration/tests/common/entity_common/mod.rs b/sea-orm-migration/tests/common/entity_common/mod.rs new file mode 100644 index 0000000000..b494f531ff --- /dev/null +++ b/sea-orm-migration/tests/common/entity_common/mod.rs @@ -0,0 +1,233 @@ +// --------------------------------------------------------------------------- +// Entity definitions — realistic schema with relations, unique constraints, +// foreign keys, and versioned variants for testing discover() scenarios. +// --------------------------------------------------------------------------- + +/// `cake` — has a unique name, owns many `fruit`s. +pub mod cake { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(unique)] + pub name: String, + #[sea_orm(has_many)] + pub fruits: HasMany, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `fruit` — belongs to a `cake` via foreign key. +pub mod fruit { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "fruit")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub cake_id: i32, + #[sea_orm(belongs_to, from = "cake_id", to = "id")] + pub cake: HasOne, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `cake_v1` — initial version of cake without a unique name (for diff tests). +pub mod cake_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `cake_v2` — adds a `description` column and a `category` column. +pub mod cake_v2 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + #[sea_orm(column_type = "Text")] + pub description: String, + pub category: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `fruit_v1` — initial fruit without a `weight` column. +pub mod fruit_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "fruit")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub cake_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `fruit_v2` — adds a `weight_grams` column (integer) and a `unique` constraint on `name`. +pub mod fruit_v2 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "fruit")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(unique)] + pub name: String, + pub cake_id: i32, + pub weight_grams: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `cake_renamed` — same schema as `cake_v1` but `name` is renamed to `title` (same String type). +/// Used for testing rename detection heuristic. +pub mod cake_renamed { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub title: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `cake_type_change` — same schema as `cake_v1` but `name` removed and `count` (i32) added. +/// The type differs, so rename detection should NOT trigger. +pub mod cake_type_change { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub count: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// `cake_ambiguous` — same schema as `cake_v1` but `name` removed, `title` and `label` +/// both added (same String type). This creates an ambiguous rename scenario. +pub mod cake_ambiguous { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "cake")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub title: String, + pub label: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// EntitySet implementations +// --------------------------------------------------------------------------- + +use sea_orm_migration::{EntitySet, SchemaBuilder}; + +/// Full schema: cake + fruit with FK relation. +pub struct FullSchema; +impl EntitySet for FullSchema { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder + .register(cake::Entity) + .register(fruit::Entity) + } +} + +/// Only cake v1 (no unique, no extra columns). +pub struct CakeV1Only; +impl EntitySet for CakeV1Only { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder.register(cake_v1::Entity) + } +} + +/// Cake v2 + fruit v1 (cake gains columns, fruit gains nothing yet). +pub struct CakeV2FruitV1; +impl EntitySet for CakeV2FruitV1 { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder + .register(cake_v2::Entity) + .register(fruit_v1::Entity) + } +} + +/// Cake v1 + fruit v2 (fruit gains a column and a unique index). +pub struct CakeV1FruitV2; +impl EntitySet for CakeV1FruitV2 { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder + .register(cake_v1::Entity) + .register(fruit_v2::Entity) + } +} + +/// Cake with `name` renamed to `title` (same String type) — for rename detection tests. +pub struct CakeRenamedOnly; +impl EntitySet for CakeRenamedOnly { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder.register(cake_renamed::Entity) + } +} + +/// Cake with `name` removed and `count` (i32) added — type mismatch, no rename. +pub struct CakeTypeChangeOnly; +impl EntitySet for CakeTypeChangeOnly { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder.register(cake_type_change::Entity) + } +} + +/// Cake with `name` removed and both `title` + `label` added — ambiguous rename. +pub struct CakeAmbiguousOnly; +impl EntitySet for CakeAmbiguousOnly { + fn register(self, builder: SchemaBuilder) -> SchemaBuilder { + builder.register(cake_ambiguous::Entity) + } +} diff --git a/sea-orm-migration/tests/common/entity_migration/m20250101_000001_create_cake_table.rs b/sea-orm-migration/tests/common/entity_migration/m20250101_000001_create_cake_table.rs new file mode 100644 index 0000000000..80e7ac4443 --- /dev/null +++ b/sea-orm-migration/tests/common/entity_migration/m20250101_000001_create_cake_table.rs @@ -0,0 +1,26 @@ +use sea_orm_migration::{prelude::*, schema::*}; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Alias::new("cake")) + .if_not_exists() + .col(pk_auto(Alias::new("id"))) + .col(string_uniq(Alias::new("name"))) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(Alias::new("cake")).to_owned()) + .await + } +} diff --git a/sea-orm-migration/tests/common/entity_migration/m20250101_000002_create_fruit_table.rs b/sea-orm-migration/tests/common/entity_migration/m20250101_000002_create_fruit_table.rs new file mode 100644 index 0000000000..8c0453c10a --- /dev/null +++ b/sea-orm-migration/tests/common/entity_migration/m20250101_000002_create_fruit_table.rs @@ -0,0 +1,44 @@ +use sea_orm_migration::{prelude::*, schema::*}; +use sea_orm_migration::sea_orm::DbBackend; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Alias::new("fruit")) + .if_not_exists() + .col(pk_auto(Alias::new("id"))) + .col(string(Alias::new("name"))) + .col(integer(Alias::new("cake_id"))) + .foreign_key( + ForeignKey::create() + .name("fk-fruit-cake_id") + .from(Alias::new("fruit"), Alias::new("cake_id")) + .to(Alias::new("cake"), Alias::new("id")), + ) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + if manager.get_database_backend() != DbBackend::Sqlite { + manager + .drop_foreign_key( + ForeignKey::drop() + .table(Alias::new("fruit")) + .name("fk-fruit-cake_id") + .to_owned(), + ) + .await?; + } + manager + .drop_table(Table::drop().table(Alias::new("fruit")).to_owned()) + .await + } +} diff --git a/sea-orm-migration/tests/common/entity_migration/mod.rs b/sea-orm-migration/tests/common/entity_migration/mod.rs new file mode 100644 index 0000000000..4ec839df98 --- /dev/null +++ b/sea-orm-migration/tests/common/entity_migration/mod.rs @@ -0,0 +1,2 @@ +pub mod m20250101_000001_create_cake_table; +pub mod m20250101_000002_create_fruit_table; diff --git a/sea-orm-migration/tests/common/entity_migrator/default.rs b/sea-orm-migration/tests/common/entity_migrator/default.rs new file mode 100644 index 0000000000..092f31338f --- /dev/null +++ b/sea-orm-migration/tests/common/entity_migrator/default.rs @@ -0,0 +1,15 @@ +use sea_orm_migration::prelude::*; + +use crate::common::entity_migration::*; + +pub struct Migrator; + +#[async_trait::async_trait] +impl MigratorTrait for Migrator { + fn migrations() -> Vec> { + vec![ + Box::new(m20250101_000001_create_cake_table::Migration), + Box::new(m20250101_000002_create_fruit_table::Migration), + ] + } +} diff --git a/sea-orm-migration/tests/common/entity_migrator/mod.rs b/sea-orm-migration/tests/common/entity_migrator/mod.rs new file mode 100644 index 0000000000..1be8d340b8 --- /dev/null +++ b/sea-orm-migration/tests/common/entity_migrator/mod.rs @@ -0,0 +1 @@ +pub mod default; diff --git a/sea-orm-migration/tests/common/mod.rs b/sea-orm-migration/tests/common/mod.rs index f4ab07bae8..0cd2aaf593 100644 --- a/sea-orm-migration/tests/common/mod.rs +++ b/sea-orm-migration/tests/common/mod.rs @@ -1,2 +1,9 @@ pub mod migration; pub mod migrator; + +#[cfg(feature = "entity-first")] +pub mod entity_common; +#[cfg(feature = "entity-first")] +pub mod entity_migration; +#[cfg(feature = "entity-first")] +pub mod entity_migrator; diff --git a/sea-orm-migration/tests/entity_first.rs b/sea-orm-migration/tests/entity_first.rs new file mode 100644 index 0000000000..df534a3567 --- /dev/null +++ b/sea-orm-migration/tests/entity_first.rs @@ -0,0 +1,865 @@ +mod common; + +use std::fs; +use tempfile::TempDir; + +use common::entity_migrator::default::Migrator; +use common::entity_common::{ + CakeAmbiguousOnly, CakeRenamedOnly, CakeTypeChangeOnly, CakeV1FruitV2, CakeV1Only, + CakeV2FruitV1, FullSchema, +}; +use sea_orm::{Database, DbErr, Schema}; +use sea_orm_migration::{EntitySet, MigratorTraitSelf, SchemaManager, prelude::*}; + +async fn connect() -> Result { + Database::connect("sqlite::memory:").await +} + +/// Temporary directory with a skeleton migration `lib.rs` ready for writing into. +fn temp_migration_dir() -> TempDir { + let dir = tempfile::tempdir().expect("tempdir"); + let src = dir.path().join("src"); + fs::create_dir(&src).unwrap(); + fs::write( + src.join("lib.rs"), + r#"pub use sea_orm_migration::prelude::*; + +pub struct Migrator; + +#[async_trait::async_trait] +impl MigratorTrait for Migrator { + fn migrations() -> Vec> { + vec![] + } +} +"#, + ) + .unwrap(); + dir +} + +// --------------------------------------------------------------------------- +// summary tests — pure unit tests, no DB required +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod summary_tests { + use sea_orm::{DbBackend, Statement}; + use sea_orm_migration::summary::summarize; + + fn stmt(sql: &str) -> Statement { + Statement::from_string(DbBackend::Sqlite, sql.to_owned()) + } + + #[test] + fn test_create_table() { + assert_eq!( + summarize(&[stmt(r#"CREATE TABLE "cake" ( "id" integer NOT NULL )"#)]), + vec!["Created table: cake"] + ); + } + + #[test] + fn test_create_table_if_not_exists() { + assert_eq!( + summarize(&[stmt( + r#"CREATE TABLE IF NOT EXISTS "fruit" ( "id" integer NOT NULL )"# + )]), + vec!["Created table: fruit"] + ); + } + + #[test] + fn test_add_column() { + assert_eq!( + summarize(&[stmt(r#"ALTER TABLE "cake" ADD COLUMN "description" text"#)]), + vec!["Added column: cake.description"] + ); + } + + #[test] + fn test_drop_column() { + assert_eq!( + summarize(&[stmt(r#"ALTER TABLE "fruit" DROP COLUMN "weight_grams""#)]), + vec!["Dropped column: fruit.weight_grams"] + ); + } + + #[test] + fn test_drop_table() { + assert_eq!( + summarize(&[stmt(r#"DROP TABLE IF EXISTS "cake""#)]), + vec!["Dropped table: cake"] + ); + } + + #[test] + fn test_create_unique_index() { + assert_eq!( + summarize(&[stmt( + r#"CREATE UNIQUE INDEX "idx_cake_name" ON "cake" ("name")"# + )]), + vec!["Added unique index on: cake"] + ); + } + + #[test] + fn test_add_foreign_key() { + assert_eq!( + summarize(&[stmt( + r#"ALTER TABLE "fruit" ADD CONSTRAINT "fk_cake_id" FOREIGN KEY ("cake_id") REFERENCES "cake" ("id")"# + )]), + vec!["Added foreign key on: fruit"] + ); + } + + #[test] + fn test_multiple_stmts_ordering() { + let stmts = vec![ + stmt(r#"CREATE TABLE "cake" ( "id" integer NOT NULL )"#), + stmt(r#"CREATE TABLE "fruit" ( "id" integer NOT NULL )"#), + stmt(r#"ALTER TABLE "fruit" ADD COLUMN "weight_grams" integer"#), + ]; + assert_eq!( + summarize(&stmts), + vec![ + "Created table: cake", + "Created table: fruit", + "Added column: fruit.weight_grams", + ] + ); + } +} + +// --------------------------------------------------------------------------- +// codegen tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod codegen_tests { + use sea_orm::{DbBackend, Statement}; + use sea_orm_migration::codegen::{MigrationMetadata, render_migration_file}; + + fn cake_create_stmt() -> Statement { + Statement::from_string( + DbBackend::Sqlite, + r#"CREATE TABLE "cake" ( "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar NOT NULL )"#.to_owned(), + ) + } + + fn fruit_create_stmt() -> Statement { + Statement::from_string( + DbBackend::Sqlite, + r#"CREATE TABLE "fruit" ( "id" integer NOT NULL, "name" varchar NOT NULL, "cake_id" integer NOT NULL )"#.to_owned(), + ) + } + + fn meta<'a>(changes: &'a [String]) -> MigrationMetadata<'a> { + MigrationMetadata { + version: "0.1.0", + generated_at: "2026-01-01 00:00:00 UTC", + backend: "SQLite", + changes, + } + } + + #[test] + fn test_renders_header_metadata() { + let changes = vec![ + "Created table: cake".to_string(), + "Created table: fruit".to_string(), + ]; + let out = + render_migration_file(&[cake_create_stmt(), fruit_create_stmt()], &meta(&changes)); + assert!(out.contains("// Generated by sea-orm-entity v0.1.0")); + assert!(out.contains("// Generated at: 2026-01-01 00:00:00 UTC")); + assert!(out.contains("// Backend: SQLite")); + assert!(out.contains("// - Created table: cake")); + assert!(out.contains("// - Created table: fruit")); + } + + #[test] + fn test_renders_boilerplate() { + let changes = vec![]; + let out = render_migration_file(&[cake_create_stmt()], &meta(&changes)); + assert!(out.contains("#[derive(DeriveMigrationName)]")); + assert!(out.contains("pub struct Migration;")); + assert!(out.contains("impl MigrationTrait for Migration")); + assert!(out.contains("async fn up(&self, manager: &SchemaManager)")); + assert!(out.contains("async fn down(&self, _manager: &SchemaManager)")); + assert!(out.contains("// TODO: implement down migration")); + } + + #[test] + fn test_renders_all_stmts_as_execute_unprepared() { + let stmts = vec![cake_create_stmt(), fruit_create_stmt()]; + let changes = vec![]; + let out = render_migration_file(&stmts, &meta(&changes)); + assert!(out.contains(r#"CREATE TABLE "cake""#)); + assert!(out.contains(r#"CREATE TABLE "fruit""#)); + assert_eq!(out.matches("execute_unprepared").count(), 2); + } + + #[test] + fn test_sql_embedded_as_raw_string_literal() { + let out = render_migration_file(&[cake_create_stmt()], &meta(&[])); + assert!(out.contains(r#"r#""#), "should use raw string literal"); + } +} + +// --------------------------------------------------------------------------- +// fs tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod fs_tests { + use super::temp_migration_dir; + use sea_orm::{DbBackend, Statement}; + use sea_orm_migration::codegen::MigrationMetadata; + use std::fs; + + fn stmts() -> Vec { + vec![ + Statement::from_string( + DbBackend::Sqlite, + r#"CREATE TABLE "cake" ( "id" integer NOT NULL )"#.to_owned(), + ), + Statement::from_string( + DbBackend::Sqlite, + r#"CREATE TABLE "fruit" ( "id" integer NOT NULL )"#.to_owned(), + ), + ] + } + + fn meta(changes: &[String]) -> MigrationMetadata<'_> { + MigrationMetadata { + version: "0.1.0", + generated_at: "2026-01-01 00:00:00 UTC", + backend: "SQLite", + changes, + } + } + + #[test] + fn test_write_migration_creates_file_and_updates_lib() { + let dir = temp_migration_dir(); + let name = "m20260101_000001_create_schema"; + let changes = vec![ + "Created table: cake".to_string(), + "Created table: fruit".to_string(), + ]; + + sea_orm_migration::fs::write_migration( + dir.path().to_str().unwrap(), + name, + &stmts(), + &meta(&changes), + ) + .expect("write_migration failed"); + + let file = dir.path().join("src").join(format!("{name}.rs")); + assert!(file.exists()); + let content = fs::read_to_string(&file).unwrap(); + assert!(content.contains("DeriveMigrationName")); + assert!(content.contains("Created table: cake")); + assert!(content.contains("Created table: fruit")); + + let lib = fs::read_to_string(dir.path().join("src/lib.rs")).unwrap(); + assert!(lib.contains(&format!("mod {name};"))); + assert!(lib.contains(&format!("Box::new({name}::Migration)"))); + } + + #[test] + fn test_second_migration_appends_to_lib() { + let dir = temp_migration_dir(); + let changes: Vec = vec![]; + + sea_orm_migration::fs::write_migration( + dir.path().to_str().unwrap(), + "m20260101_000001_first", + &stmts(), + &meta(&changes), + ) + .unwrap(); + sea_orm_migration::fs::write_migration( + dir.path().to_str().unwrap(), + "m20260101_000002_second", + &stmts(), + &meta(&changes), + ) + .unwrap(); + + let lib = fs::read_to_string(dir.path().join("src/lib.rs")).unwrap(); + assert!(lib.contains("mod m20260101_000001_first;")); + assert!(lib.contains("mod m20260101_000002_second;")); + assert!(lib.contains("Box::new(m20260101_000001_first::Migration)")); + assert!(lib.contains("Box::new(m20260101_000002_second::Migration)")); + } + + #[test] + fn test_generated_file_does_not_reference_removed_migration() { + let dir = temp_migration_dir(); + let changes: Vec = vec![]; + + sea_orm_migration::fs::write_migration( + dir.path().to_str().unwrap(), + "m20260101_000001_only", + &stmts(), + &meta(&changes), + ) + .unwrap(); + + let lib = fs::read_to_string(dir.path().join("src/lib.rs")).unwrap(); + assert_eq!(lib.matches("mod m20260101_000001_only;").count(), 1); + assert_eq!(lib.matches("m20260101_000001_only::Migration").count(), 1); + } +} + +// --------------------------------------------------------------------------- +// Integration tests — full discover → generate → apply → lifecycle +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_discover_full_schema_on_empty_db() -> Result<(), DbErr> { + let db = connect().await?; + let builder = FullSchema.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + + assert!(!stmts.is_empty()); + + let sql_all: String = stmts + .iter() + .map(|s| s.sql.to_uppercase()) + .collect::>() + .join(" "); + assert!(sql_all.contains("CREATE TABLE"), "should create tables"); + + let table_names: Vec<_> = stmts + .iter() + .filter(|s| s.sql.to_uppercase().contains("CREATE TABLE")) + .collect(); + assert_eq!(table_names.len(), 2, "should create exactly cake + fruit"); + + Ok(()) +} + +#[tokio::test] +async fn test_no_diff_when_schema_matches_entities() -> Result<(), DbErr> { + let db = connect().await?; + Migrator.up(&db, None).await?; + + let builder = FullSchema.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + + assert!( + stmts.is_empty(), + "no changes expected on synced DB, got: {:?}", + stmts.iter().map(|s| &s.sql).collect::>() + ); + + Ok(()) +} + +#[tokio::test] +async fn test_discover_detects_added_columns() -> Result<(), DbErr> { + let db = connect().await?; + + sea_orm::Schema::new(db.get_database_backend()) + .builder() + .register(common::entity_common::cake_v1::Entity) + .sync(&db) + .await?; + + let builder = CakeV2FruitV1.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + + let sql_all: String = stmts + .iter() + .map(|s| s.sql.to_uppercase()) + .collect::>() + .join(" "); + assert!( + sql_all.contains("ADD COLUMN") || sql_all.contains("CREATE TABLE"), + "should detect schema additions, got stmts: {:?}", + stmts.iter().map(|s| &s.sql).collect::>() + ); + + let all_sql: String = stmts + .iter() + .map(|s| s.sql.clone()) + .collect::>() + .join(" "); + assert!( + all_sql.contains("description") || all_sql.contains("fruit"), + "should reference new columns or missing tables" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_discover_detects_added_column_and_unique_index() -> Result<(), DbErr> { + let db = connect().await?; + + sea_orm::Schema::new(db.get_database_backend()) + .builder() + .register(common::entity_common::cake_v1::Entity) + .register(common::entity_common::fruit_v1::Entity) + .sync(&db) + .await?; + + let builder = CakeV1FruitV2.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + + assert!(!stmts.is_empty(), "should detect changes"); + + let sql_all: String = stmts + .iter() + .map(|s| s.sql.clone()) + .collect::>() + .join(" "); + assert!( + sql_all.contains("weight_grams"), + "should ADD COLUMN weight_grams, got: {sql_all}" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_discover_dangerous_drops_orphaned_tables_but_not_tracker() -> Result<(), DbErr> { + let db = connect().await?; + + Migrator.up(&db, None).await?; + + let builder = CakeV1Only.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, true).await?; + let result = sea_orm::interpret_changes(change_set, &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }); + let stmts: Vec<_> = result.statements.iter().map(|(_, s)| s).collect(); + + let protected = Migrator.migration_table_name().to_string(); + let raw_drops: Vec<_> = stmts + .iter() + .filter(|s| s.sql.to_uppercase().contains("DROP TABLE")) + .collect(); + assert!( + raw_drops.iter().any(|s| s.sql.contains("fruit")), + "fruit should appear in DROP TABLE statements" + ); + + let protected_upper = protected.to_uppercase(); + let filtered: Vec<_> = stmts + .iter() + .filter(|s| { + let upper = s.sql.to_uppercase(); + if upper.contains("DROP TABLE") { + !upper.contains(&format!("\"{}\"", protected_upper)) + && !upper.contains(&format!("`{}`", protected_upper)) + && !upper.contains(&format!(" {} ", protected_upper)) + && !upper.ends_with(&format!(" {}", protected_upper)) + } else { + true + } + }) + .collect(); + + assert!( + filtered + .iter() + .any(|s| s.sql.to_uppercase().contains("DROP TABLE") && s.sql.contains("fruit")), + "fruit should still be in filtered DROP statements" + ); + assert!( + !filtered + .iter() + .any(|s| s.sql.to_lowercase().contains(&protected.to_lowercase()) + && s.sql.to_uppercase().contains("DROP TABLE")), + "seaql_migrations must not appear in filtered DROP statements" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_discover_safe_never_drops() -> Result<(), DbErr> { + let db = connect().await?; + Migrator.up(&db, None).await?; + + let builder = CakeV1Only.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + + assert!( + !stmts.iter().any(|s| s.sql.to_uppercase().contains("DROP")), + "safe discover must not emit any DROP statements" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_full_migration_lifecycle() -> Result<(), DbErr> { + let db = connect().await?; + let manager = SchemaManager::new(&db); + + let pending = Migrator.get_pending_migrations(&db).await?; + assert_eq!(pending.len(), 2); + assert_eq!(pending[0].name(), "m20250101_000001_create_cake_table"); + assert_eq!(pending[1].name(), "m20250101_000002_create_fruit_table"); + + Migrator.up(&db, Some(1)).await?; + assert!(manager.has_table("cake").await?); + assert!(!manager.has_table("fruit").await?); + assert!(manager.has_column("cake", "id").await?); + assert!(manager.has_column("cake", "name").await?); + + let pending = Migrator.get_pending_migrations(&db).await?; + assert_eq!(pending.len(), 1); + + Migrator.up(&db, None).await?; + assert!(manager.has_table("fruit").await?); + assert!(manager.has_column("fruit", "cake_id").await?); + + let applied = Migrator.get_applied_migrations(&db).await?; + assert_eq!(applied.len(), 2); + + let builder = FullSchema.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + assert!(stmts.is_empty(), "no changes after full apply, got: {:?}", stmts.iter().map(|s| &s.sql).collect::>()); + + Migrator.down(&db, Some(1)).await?; + assert!(!manager.has_table("fruit").await?); + assert!(manager.has_table("cake").await?); + + Migrator.fresh(&db).await?; + assert!(manager.has_table("cake").await?); + assert!(manager.has_table("fruit").await?); + + Migrator.reset(&db).await?; + assert!(!manager.has_table("cake").await?); + assert!(!manager.has_table("fruit").await?); + + Ok(()) +} + +#[tokio::test] +async fn test_generate_pipeline_for_full_schema() -> Result<(), DbErr> { + let db = connect().await?; + let dir = temp_migration_dir(); + + let builder = FullSchema.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, false).await?; + let stmts = change_set.statements(); + assert!(!stmts.is_empty()); + + let changes = sea_orm_migration::summary::summarize(&stmts); + assert!(changes.iter().any(|c| c.contains("cake"))); + assert!(changes.iter().any(|c| c.contains("fruit"))); + + let meta = sea_orm_migration::codegen::MigrationMetadata { + version: "0.1.0", + generated_at: "2026-01-01 00:00:00 UTC", + backend: "SQLite", + changes: &changes, + }; + let filepath = sea_orm_migration::fs::write_migration( + dir.path().to_str().unwrap(), + "m20260101_000001_create_schema", + &stmts, + &meta, + ) + .unwrap(); + + let content = fs::read_to_string(&filepath).unwrap(); + assert!(content.contains(r#"CREATE TABLE"#)); + assert!(content.contains("cake") || content.contains("fruit")); + + let lib = fs::read_to_string(dir.path().join("src/lib.rs")).unwrap(); + assert!(lib.contains("m20260101_000001_create_schema")); + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Safety & correctness tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_discover_warns_on_possible_column_rename() -> Result<(), DbErr> { + use sea_orm::schema::SuggestionKind; + let db = connect().await?; + + sea_orm::Schema::new(db.get_database_backend()) + .builder() + .register(common::entity_common::cake_v1::Entity) + .sync(&db) + .await?; + + let builder = CakeRenamedOnly.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, true).await?; + let result = sea_orm::interpret_changes(change_set, &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }); + + assert!( + result + .suggestions + .iter() + .any(|s| s.kind == SuggestionKind::PossibleRename), + "should emit PossibleRename suggestion, got suggestions: {:?}", + result.suggestions + ); + + let rename_suggestion = result + .suggestions + .iter() + .find(|s| s.kind == SuggestionKind::PossibleRename) + .unwrap(); + assert!( + rename_suggestion.message.contains("name") + && rename_suggestion.message.contains("title"), + "suggestion should mention old and new column names, got: {}", + rename_suggestion.message + ); + + let sql_all: String = result + .statements + .iter() + .map(|(_, s)| s.sql.to_uppercase()) + .collect::>() + .join(" "); + assert!( + sql_all.contains("RENAME COLUMN"), + "should produce RENAME COLUMN for auto-assumed rename; got: {sql_all}" + ); + + let has_drop_name = sql_all.contains("DROP COLUMN"); + let has_add_title = sql_all.contains("ADD COLUMN"); + assert!( + !has_drop_name && !has_add_title, + "should not produce DROP or ADD when rename is auto-assumed; got: {sql_all}" + ); + + assert!( + result.unresolved.is_empty(), + "single obvious rename should not be ambiguous" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_discover_no_rename_warning_when_types_differ() -> Result<(), DbErr> { + use sea_orm::schema::SuggestionKind; + let db = connect().await?; + + sea_orm::Schema::new(db.get_database_backend()) + .builder() + .register(common::entity_common::cake_v1::Entity) + .sync(&db) + .await?; + + let builder = + CakeTypeChangeOnly.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, true).await?; + let result = sea_orm::interpret_changes(change_set, &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }); + + assert!( + !result + .suggestions + .iter() + .any(|s| s.kind == SuggestionKind::PossibleRename), + "should not emit PossibleRename when types differ, got suggestions: {:?}", + result.suggestions + ); + + let sql_all: String = result + .statements + .iter() + .map(|(_, s)| s.sql.clone()) + .collect::>() + .join(" "); + assert!( + sql_all.contains("count"), + "should ADD COLUMN count, got: {sql_all}" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_ambiguous_rename_in_unresolved() -> Result<(), DbErr> { + let db = connect().await?; + + sea_orm::Schema::new(db.get_database_backend()) + .builder() + .register(common::entity_common::cake_v1::Entity) + .sync(&db) + .await?; + + let builder = CakeAmbiguousOnly.register(Schema::new(db.get_database_backend()).builder()); + let change_set = builder.discover(&db, true).await?; + let result = sea_orm::interpret_changes(change_set, &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }); + + assert!( + !result.unresolved.is_empty(), + "expected unresolved ambiguous renames, got none; warnings: {:?}, statements: {:?}", + result.warnings, + result.statements.iter().map(|(_, s)| &s.sql).collect::>() + ); + + let ambiguous = &result.unresolved[0]; + assert_eq!(ambiguous.removed, "name"); + assert!( + ambiguous.candidates.len() >= 2, + "should have multiple candidates, got: {:?}", + ambiguous.candidates + ); + + let has_rename = result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("RENAME COLUMN")); + assert!( + !has_rename, + "should not generate RENAME COLUMN for ambiguous rename" + ); + + assert!( + !result + .suggestions + .iter() + .any(|s| s.kind == sea_orm::schema::SuggestionKind::PossibleRename), + "ambiguous renames should not produce PossibleRename suggestions" + ); + + Ok(()) +} + +#[cfg(test)] +mod enum_warning_tests { + use sea_orm::schema::extract_enum_type_name; + + #[test] + fn test_extract_enum_type_name_postgres() { + let sql = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad', 'neutral')"#; + assert_eq!(extract_enum_type_name(sql), Some("mood".to_string())); + } + + #[test] + fn test_extract_enum_type_name_no_match() { + let sql = r#"CREATE TABLE "cake" ( "id" integer NOT NULL )"#; + assert_eq!(extract_enum_type_name(sql), None); + } + + #[test] + fn test_enum_same_name_different_variants_detected() { + let existing_sql = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad')"#; + let new_sql = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad', 'neutral')"#; + + let existing_name = extract_enum_type_name(existing_sql); + let new_name = extract_enum_type_name(new_sql); + + assert_eq!(existing_name, new_name); + assert_ne!(existing_sql, new_sql); + } +} + +#[cfg(test)] +mod safety_summary_tests { + use sea_orm::{DbBackend, Statement}; + use sea_orm_migration::summary::summarize; + + fn stmt(sql: &str) -> Statement { + Statement::from_string(DbBackend::Sqlite, sql.to_owned()) + } + + #[test] + fn test_summary_rename_column() { + assert_eq!( + summarize(&[stmt( + r#"ALTER TABLE "cake" RENAME COLUMN "name" TO "title""# + )]), + vec!["Renamed column on: cake"] + ); + } + + #[test] + fn test_summary_alter_type_add_value() { + assert_eq!( + summarize(&[stmt( + r#"ALTER TYPE "mood" ADD VALUE 'neutral'"# + )]), + vec!["Added enum variant"] + ); + } +} + +#[cfg(test)] +mod warning_type_tests { + use sea_orm::schema::{DiscoverSuggestion, DiscoverWarning, SuggestionKind, WarningKind}; + + #[test] + fn test_warning_kinds_are_eq() { + assert_eq!( + WarningKind::CheckConstraintDiff, + WarningKind::CheckConstraintDiff + ); + assert_ne!( + WarningKind::CheckConstraintDiff, + WarningKind::NotNullNoDefault + ); + } + + #[test] + fn test_suggestion_kinds_are_eq() { + assert_eq!(SuggestionKind::PossibleRename, SuggestionKind::PossibleRename); + assert_ne!(SuggestionKind::PossibleRename, SuggestionKind::EnumVariantChange); + assert_ne!( + SuggestionKind::EnumRename, + SuggestionKind::EnumVariantChange + ); + } + + #[test] + fn test_warning_debug_format() { + let w = DiscoverWarning { + kind: WarningKind::CheckConstraintDiff, + message: "CHECK constraint cannot be diffed".to_string(), + related_changes: vec![], + }; + let debug = format!("{w:?}"); + assert!(debug.contains("CheckConstraintDiff")); + } + + #[test] + fn test_suggestion_debug_format() { + let s = DiscoverSuggestion { + kind: SuggestionKind::PossibleRename, + message: "Column 'name' may have been renamed to 'title'".to_string(), + related_changes: vec![], + }; + let debug = format!("{s:?}"); + assert!(debug.contains("PossibleRename")); + assert!(debug.contains("name")); + } +} diff --git a/sea-orm-migration/tests/main.rs b/sea-orm-migration/tests/main.rs index 9d1a0240ab..5c0dee73c2 100644 --- a/sea-orm-migration/tests/main.rs +++ b/sea-orm-migration/tests/main.rs @@ -6,11 +6,6 @@ use sea_orm_migration::{MigratorTraitSelf, migrator::MigrationStatus, prelude::* #[tokio::test] async fn main() -> Result<(), DbErr> { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .with_test_writer() - .init(); - let url = &std::env::var("DATABASE_URL").expect("Environment variable 'DATABASE_URL' not set"); run_migration(url, default::Migrator, "sea_orm_migration", "public").await?; diff --git a/sea-orm-sync/Cargo.toml b/sea-orm-sync/Cargo.toml index fc40e706b2..cfb96ac211 100644 --- a/sea-orm-sync/Cargo.toml +++ b/sea-orm-sync/Cargo.toml @@ -199,3 +199,4 @@ with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-rusqlite?/with-uuid"] [patch.crates-io] # sea-query = { path = "../sea-query" } +sea-query = { git = "https://github.com/SeaQL/sea-query", branch = "master" } diff --git a/sea-orm-sync/tests/common/features/value_type.rs b/sea-orm-sync/tests/common/features/value_type.rs index a963f2c522..a58dbc283b 100644 --- a/sea-orm-sync/tests/common/features/value_type.rs +++ b/sea-orm-sync/tests/common/features/value_type.rs @@ -69,15 +69,9 @@ where } } -// Automatically disable vec impl #[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] pub struct StringVec(pub Vec); -// Explicitly disable vec impl -#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] -#[sea_orm(no_vec_impl)] -pub struct StringVecNoImpl(pub Vec); - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] #[sea_orm(value_type = "String")] pub enum Tag1 { diff --git a/sea-orm-sync/tests/derive_tests.rs b/sea-orm-sync/tests/derive_tests.rs index e33cae65c7..e482280052 100644 --- a/sea-orm-sync/tests/derive_tests.rs +++ b/sea-orm-sync/tests/derive_tests.rs @@ -69,69 +69,3 @@ struct FromQueryResultNested { #[sea_orm(nested)] _test: SimpleTest, } - -#[cfg(feature = "postgres-array")] -mod postgres_array { - use crate::FromQueryResult; - use sea_orm::DeriveValueType; - - #[derive(DeriveValueType)] - pub struct IngredientId(i32); - - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] - #[sea_orm(value_type = "String")] - pub struct NumericLabel { - pub value: i64, - } - - impl std::fmt::Display for NumericLabel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) - } - } - - impl std::str::FromStr for NumericLabel { - type Err = std::num::ParseIntError; - fn from_str(s: &str) -> Result { - Ok(Self { value: s.parse()? }) - } - } - - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] - #[sea_orm(value_type = "String")] - pub enum TextureKind { - Hard, - Soft, - } - - impl std::fmt::Display for TextureKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Hard => "hard", - Self::Soft => "soft", - } - ) - } - } - - impl std::str::FromStr for TextureKind { - type Err = sea_query::ValueTypeErr; - fn from_str(s: &str) -> Result { - Ok(match s { - "hard" => Self::Hard, - "soft" => Self::Soft, - _ => return Err(sea_query::ValueTypeErr), - }) - } - } - - #[derive(FromQueryResult)] - pub struct IngredientPathRow { - pub ingredient_path: Vec, - pub numeric_label_path: Vec, - pub texture_path: Vec, - } -} diff --git a/src/query/helper.rs b/src/query/helper.rs index 7aef179692..c3b0971722 100644 --- a/src/query/helper.rs +++ b/src/query/helper.rs @@ -909,7 +909,7 @@ pub(crate) fn join_tbl_on_condition( foreign_keys: Identity, ) -> Condition { let mut cond = Condition::all(); - for (owner_key, foreign_key) in owner_keys.into_iter().zip(foreign_keys.into_iter()) { + for (owner_key, foreign_key) in owner_keys.into_iter().zip(foreign_keys) { cond = cond .add(Expr::col((from_tbl.clone(), owner_key)).equals((to_tbl.clone(), foreign_key))); } diff --git a/src/schema/builder.rs b/src/schema/builder.rs index baa29c68f3..a47cba7934 100644 --- a/src/schema/builder.rs +++ b/src/schema/builder.rs @@ -1,14 +1,24 @@ use super::{Schema, TopologicalSort}; use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement}; use sea_query::{ - ForeignKeyCreateStatement, Index, IndexCreateStatement, IntoIden, TableAlterStatement, - TableCreateStatement, TableName, TableRef, extension::postgres::TypeCreateStatement, + IndexCreateStatement, TableCreateStatement, TableName, TableRef, + extension::postgres::TypeCreateStatement, +}; + +#[cfg(feature = "schema-sync")] +pub use super::discover::resolver::extract_enum_type_name; +#[cfg(feature = "schema-sync")] +pub use super::discover::{ + DiscoverSuggestion, DiscoverWarning, InterpretConfig, InterpretResult, RenameDecision, + SchemaChangeId, SuggestionKind, WarningKind, interpret::interpret as interpret_changes, }; /// A schema builder that can take a registry of Entities and synchronize it with database. pub struct SchemaBuilder { helper: Schema, entities: Vec, + #[cfg(feature = "schema-sync")] + excluded_tables: Vec, } /// Schema info for Entity. Can be used to re-create schema in database. @@ -48,6 +58,8 @@ impl SchemaBuilder { Self { helper: schema, entities: Default::default(), + #[cfg(feature = "schema-sync")] + excluded_tables: Default::default(), } } @@ -74,6 +86,17 @@ impl SchemaBuilder { self.entities.push(entity); } + /// Exclude tables from schema discovery. + /// + /// Excluded tables are never reported as orphans and are never diffed for column/FK changes. + /// Use this to protect system tables (e.g. the migration tracker) from being dropped. + #[cfg(feature = "schema-sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] + pub fn exclude(mut self, table: impl Into) -> Self { + self.excluded_tables.push(table.into()); + self + } + /// Synchronize the schema with database, will create missing tables, columns, unique keys, and foreign keys. /// This operation is addition only, will not drop any table / columns. #[cfg(feature = "schema-sync")] @@ -82,164 +105,59 @@ impl SchemaBuilder { where C: ConnectionTrait + sea_schema::Connection, { - let _existing = - match db.get_database_backend() { - #[cfg(feature = "sqlx-mysql")] - DbBackend::MySql => { - use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe}; - - let current_schema: String = db - .query_one( - sea_query::SelectStatement::new() - .expr(sea_schema::mysql::MySql::get_current_schema()), - ) - .await? - .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? - .try_get_by_index(0)?; - - // Collect all unique schemas that registered entities belong to - let mut target_schemas = std::collections::BTreeSet::new(); - for entity in &self.entities { - let schema = entity.schema_name.as_deref().unwrap_or(¤t_schema); - target_schemas.insert(schema.to_string()); - } - - let mut tables_by_schema = std::collections::HashMap::new(); - for schema_name in &target_schemas { - let schema_discovery = SchemaDiscovery::new_no_exec(schema_name); - let schema = schema_discovery.discover_with(db).await.map_err(|err| { - DbErr::Query(crate::RuntimeErr::SqlxError(err.into())) - })?; - - tables_by_schema.insert( - schema_name.clone(), - schema.tables.iter().map(|table| table.write()).collect(), - ); - } - - DiscoveredSchema { - current_schema, - tables_by_schema, - enums_by_schema: Default::default(), - } - } - #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres => { - use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe}; - - let current_schema: String = db - .query_one( - sea_query::SelectStatement::new() - .expr(sea_schema::postgres::Postgres::get_current_schema()), - ) - .await? - .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? - .try_get_by_index(0)?; - - // Collect all unique schemas that registered entities belong to - let mut target_schemas = std::collections::BTreeSet::new(); - for entity in &self.entities { - let schema = entity.schema_name.as_deref().unwrap_or(¤t_schema); - target_schemas.insert(schema.to_string()); - } - - let mut tables_by_schema = std::collections::HashMap::new(); - let mut enums_by_schema = std::collections::HashMap::new(); - for schema_name in &target_schemas { - let schema_discovery = SchemaDiscovery::new_no_exec(schema_name); - let schema = schema_discovery.discover_with(db).await.map_err(|err| { - DbErr::Query(crate::RuntimeErr::SqlxError(err.into())) - })?; - - tables_by_schema.insert( - schema_name.clone(), - schema.tables.iter().map(|table| table.write()).collect(), - ); - enums_by_schema.insert( - schema_name.clone(), - schema.enums.iter().map(|def| def.write()).collect(), - ); - } - - DiscoveredSchema { - current_schema, - tables_by_schema, - enums_by_schema, - } - } - #[cfg(feature = "sqlx-sqlite")] - DbBackend::Sqlite => { - use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; - let schema = SchemaDiscovery::discover_with(db) - .await - .map_err(|err| { - DbErr::Query(match err { - SqliteDiscoveryError::SqlxError(err) => { - crate::RuntimeErr::SqlxError(err.into()) - } - _ => crate::RuntimeErr::Internal(format!("{err:?}")), - }) - })? - .merge_indexes_into_table(); - let mut tables_by_schema = std::collections::HashMap::new(); - tables_by_schema.insert( - String::new(), - schema.tables.iter().map(|table| table.write()).collect(), - ); - DiscoveredSchema { - current_schema: String::new(), - tables_by_schema, - enums_by_schema: Default::default(), - } - } - #[cfg(feature = "rusqlite")] - DbBackend::Sqlite => { - use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; - let schema = SchemaDiscovery::discover_with(db) - .map_err(|err| { - DbErr::Query(match err { - SqliteDiscoveryError::RusqliteError(err) => { - crate::RuntimeErr::Rusqlite(err.into()) - } - _ => crate::RuntimeErr::Internal(format!("{err:?}")), - }) - })? - .merge_indexes_into_table(); - let mut tables_by_schema = std::collections::HashMap::new(); - tables_by_schema.insert( - String::new(), - schema.tables.iter().map(|table| table.write()).collect(), - ); - DiscoveredSchema { - current_schema: String::new(), - tables_by_schema, - enums_by_schema: Default::default(), - } - } - #[allow(unreachable_patterns)] - other => { - return Err(DbErr::BackendNotSupported { - db: other.as_str(), - ctx: "SchemaBuilder::sync", - }); - } - }; + let change_set = self.discover(db, true).await?; + for stmt in change_set.statements() { + db.execute_raw(stmt).await?; + } + Ok(()) + } - #[allow(unreachable_code)] - let mut created_enums: Vec = Default::default(); + /// Returns a [`ChangeSet`](super::discover::changes::ChangeSet) grouped by origin. + /// Use [`interpret`](super::discover::interpret) to turn it into SQL statements. + /// + /// * `db` - The database connection to use for fetching existing table schema. + /// * `allow_dangerous` - If `true`, changes will include drops (tables, columns, constraints). + /// + /// Panics if TableCreateStatement any table name is empty, will never happen. + #[cfg(feature = "schema-sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] + pub async fn discover( + &self, + db: &C, + allow_dangerous: bool, + ) -> Result + where + C: ConnectionTrait + sea_schema::Connection, + { + super::discover::discover(&self.entities, db, allow_dangerous, &self.excluded_tables).await + } - #[allow(unreachable_code)] - for table_name in self.sorted_tables() { + /// Returns the SQL DDL statements (CREATE TABLE, CREATE TYPE, CREATE INDEX) for all + /// registered entities, rendered for the builder's backend. + /// + /// Tables are ordered topologically (parents before children). Useful for previewing + /// the schema without connecting to a database. + pub fn schema_statements(&self) -> Vec { + let backend = self.helper.backend; + let mut stmts: Vec = Vec::new(); + let table_refs: Vec<&TableCreateStatement> = + self.entities.iter().map(|e| &e.table).collect(); + for table_name in sorted_tables(&table_refs, TableSortOrder::ParentsFirst) { if let Some(entity) = self .entities .iter() - .find(|entity| table_name == get_table_name(entity.table.get_table_name())) + .find(|e| table_name == get_table_name(e.table.get_table_name())) { - entity.sync(db, &_existing, &mut created_enums).await?; + for stmt in &entity.enums { + stmts.push(backend.build(stmt)); + } + stmts.push(backend.build(&entity.table)); + for stmt in &entity.indexes { + stmts.push(backend.build(stmt)); + } } } - - Ok(()) + stmts } /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys. @@ -247,7 +165,9 @@ impl SchemaBuilder { pub async fn apply(self, db: &C) -> Result<(), DbErr> { let mut created_enums: Vec = Default::default(); - for table_name in self.sorted_tables() { + let table_refs: Vec<&TableCreateStatement> = + self.entities.iter().map(|entity| &entity.table).collect(); + for table_name in sorted_tables(&table_refs, TableSortOrder::ParentsFirst) { if let Some(entity) = self .entities .iter() @@ -259,41 +179,6 @@ impl SchemaBuilder { Ok(()) } - - fn sorted_tables(&self) -> Vec { - let mut sorter = TopologicalSort::::new(); - - for entity in self.entities.iter() { - let table_name = get_table_name(entity.table.get_table_name()); - sorter.insert(table_name); - } - for entity in self.entities.iter() { - let self_table = get_table_name(entity.table.get_table_name()); - for fk in entity.table.get_foreign_key_create_stmts().iter() { - let fk = fk.get_foreign_key(); - let ref_table = get_table_name(fk.get_ref_table()); - if self_table != ref_table { - // self cycle is okay - sorter.add_dependency(ref_table, self_table.clone()); - } - } - } - let mut sorted = Vec::new(); - while let Some(i) = sorter.pop() { - sorted.push(i); - } - if sorted.len() != self.entities.len() { - // push leftover tables - for entity in self.entities.iter() { - let table_name = get_table_name(entity.table.get_table_name()); - if !sorted.contains(&table_name) { - sorted.push(table_name); - } - } - } - - sorted - } } struct DiscoveredSchema { @@ -349,6 +234,24 @@ impl EntitySchemaInfo { } } + /// Returns a reference to the table create statement. + #[cfg(feature = "schema-sync")] + pub(crate) fn table(&self) -> &TableCreateStatement { + &self.table + } + + /// Returns a reference to the enum type create statements. + #[cfg(feature = "schema-sync")] + pub(crate) fn enums(&self) -> &[TypeCreateStatement] { + &self.enums + } + + /// Returns a reference to the index create statements. + #[cfg(feature = "schema-sync")] + pub(crate) fn indexes(&self) -> &[IndexCreateStatement] { + &self.indexes + } + async fn apply( &self, db: &C, @@ -368,215 +271,6 @@ impl EntitySchemaInfo { Ok(()) } - // better to always compile this function - #[allow(dead_code)] - async fn sync( - &self, - db: &C, - existing: &DiscoveredSchema, - created_enums: &mut Vec, - ) -> Result<(), DbErr> { - let db_backend = db.get_database_backend(); - - // create enum before creating table - let existing_enums = existing.find_enums(self.schema_name.as_deref()); - for stmt in self.enums.iter() { - let mut has_enum = false; - let new_stmt = db_backend.build(stmt); - for existing_enum in existing_enums { - if db_backend.build(existing_enum) == new_stmt { - has_enum = true; - // TODO add enum variants - break; - } - } - if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) { - db.execute(stmt).await?; - created_enums.push(new_stmt); - } - } - let table_name = get_table_name(self.table.get_table_name()); - // Use schema-aware lookup: find existing table in the correct schema - let existing_table = existing.find_table(self.schema_name.as_deref(), &table_name); - if let Some(existing_table) = existing_table { - for column_def in self.table.get_columns() { - let mut column_exists = false; - for existing_column in existing_table.get_columns() { - if column_def.get_column_name() == existing_column.get_column_name() { - column_exists = true; - break; - } - } - if !column_exists { - let mut renamed_from = ""; - if let Some(comment) = &column_def.get_column_spec().comment { - if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") { - if let Some((prefix, _)) = suffix.split_once('"') { - renamed_from = prefix; - } - } - } - if renamed_from.is_empty() { - db.execute( - TableAlterStatement::new() - .table(self.table.get_table_name().expect("Checked above").clone()) - .add_column(column_def.to_owned()), - ) - .await?; - } else { - db.execute( - TableAlterStatement::new() - .table(self.table.get_table_name().expect("Checked above").clone()) - .rename_column( - renamed_from.to_owned(), - column_def.get_column_name(), - ), - ) - .await?; - } - } - } - if db.get_database_backend() != DbBackend::Sqlite { - for foreign_key in self.table.get_foreign_key_create_stmts().iter() { - let mut key_exists = false; - for existing_key in existing_table.get_foreign_key_create_stmts().iter() { - if compare_foreign_key(foreign_key, existing_key) { - key_exists = true; - break; - } - } - if !key_exists { - db.execute(foreign_key).await?; - } - } - } - } else { - db.execute(&self.table).await?; - } - for stmt in self.indexes.iter() { - let mut has_index = false; - if let Some(existing_table) = existing_table { - for existing_index in existing_table.get_indexes() { - if existing_index.get_index_spec().get_column_names() - == stmt.get_index_spec().get_column_names() - { - has_index = true; - break; - } - } - } - if !has_index { - // shall we do alter table add constraint for unique index? - let mut stmt = stmt.clone(); - stmt.if_not_exists(); - db.execute(&stmt).await?; - } - } - if let Some(existing_table) = existing_table { - // For columns with a column-level UNIQUE constraint (#[sea_orm(unique)]) that - // already exist in the table but do not yet have a unique index, create one. - for column_def in self.table.get_columns() { - if column_def.get_column_spec().unique { - let col_name = column_def.get_column_name(); - let col_exists = existing_table - .get_columns() - .iter() - .any(|c| c.get_column_name() == col_name); - if !col_exists { - // Column is being added in this sync pass; the ALTER TABLE ADD COLUMN - // will include the UNIQUE inline, so no separate index needed. - continue; - } - let already_unique = existing_table.get_indexes().iter().any(|idx| { - if !idx.is_unique_key() { - return false; - } - let cols = idx.get_index_spec().get_column_names(); - cols.len() == 1 && cols[0] == col_name - }); - if !already_unique { - let table_name = - self.table.get_table_name().expect("table must have a name"); - let tbl_str = table_name.sea_orm_table().to_string(); - let table_ref = table_name.clone(); - db.execute( - Index::create() - .name(format!("idx-{tbl_str}-{col_name}")) - .table(table_ref) - .col(col_name.into_iden()) - .unique() - .if_not_exists(), - ) - .await?; - } - } - } - } - if let Some(existing_table) = existing_table { - // find all unique keys from existing table - // if it no longer exist in new schema, drop it - for existing_index in existing_table.get_indexes() { - if existing_index.is_unique_key() { - let mut has_index = false; - for stmt in self.indexes.iter() { - if existing_index.get_index_spec().get_column_names() - == stmt.get_index_spec().get_column_names() - { - has_index = true; - break; - } - } - // Also check if the unique index corresponds to a column-level UNIQUE - // constraint (from #[sea_orm(unique)]). These are embedded in the CREATE - // TABLE column definition and not tracked in self.indexes, so we must not - // try to drop them during sync. - if !has_index { - let index_cols = existing_index.get_index_spec().get_column_names(); - if index_cols.len() == 1 { - for column_def in self.table.get_columns() { - if column_def.get_column_name() == index_cols[0] - && column_def.get_column_spec().unique - { - has_index = true; - break; - } - } - } - } - if !has_index { - if let Some(drop_existing) = existing_index - .get_index_spec() - .get_name() - .map(|s| s.to_owned()) - { - if db_backend == DbBackend::Postgres { - // On PostgreSQL, unique indexes created via column-level UNIQUE - // (e.g. ADD COLUMN ... UNIQUE) are backed by a named constraint. - // DROP INDEX fails on constraint-owned indexes; use - // ALTER TABLE ... DROP CONSTRAINT instead. - db.execute( - TableAlterStatement::new() - .table( - self.table - .get_table_name() - .expect("Checked above") - .clone(), - ) - .drop_constraint(drop_existing), - ) - .await?; - } else { - db.execute(sea_query::Index::drop().name(drop_existing)) - .await?; - } - } - } - } - } - } - Ok(()) - } - fn debug_print( &self, f: &mut std::fmt::Formatter<'_>, @@ -604,7 +298,9 @@ impl EntitySchemaInfo { } } -fn get_table_name(table_ref: Option<&TableRef>) -> TableName { +/// Panics if the table reference is not a table name +pub(crate) fn get_table_name(table_ref: Option<&TableRef>) -> TableName { + //TODO: either rewrite TableCreateStatement or move to something else that is not a builder with options match table_ref { Some(TableRef::Table(table_name, _)) => table_name.clone(), None => panic!("Expect TableCreateStatement is properly built"), @@ -612,12 +308,62 @@ fn get_table_name(table_ref: Option<&TableRef>) -> TableName { } } -fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool { - let a = a.get_foreign_key(); - let b = b.get_foreign_key(); +/// Controls which tables appear first in [`sorted_tables`] output. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum TableSortOrder { + /// Parent tables (no FK dependents) appear before children + ParentsFirst, + /// Child tables (FK holders) appear before parents + ChildrenFirst, +} + +/// Sort table names topologically by FK dependency +pub(crate) fn sorted_tables( + tables: &[&TableCreateStatement], + order: TableSortOrder, +) -> Vec { + let mut sorter = TopologicalSort::::new(); + + for tbl in tables { + sorter.insert(get_table_name(tbl.get_table_name())); + } + for tbl in tables { + let self_name = get_table_name(tbl.get_table_name()); + for fk in tbl.get_foreign_key_create_stmts() { + let ref_table = get_table_name(fk.get_foreign_key().get_ref_table()); + if self_name != ref_table { + match order { + TableSortOrder::ParentsFirst => { + sorter.add_dependency(ref_table.clone(), self_name.clone()); + } + TableSortOrder::ChildrenFirst => { + sorter.add_dependency(self_name.clone(), ref_table.clone()); + } + } + } + } + } + let mut sorted = Vec::new(); + loop { + // Collect all zero-predecessor nodes, sort by name for determinism, + // then drain them one level at a time. Without this sort, HashMap + // iteration order inside TopologicalSort::peek() is random per process, + // causing different orderings across subprocess invocations (e.g. diff + // vs generate in `entity sync`), which breaks the schema-hash check. + let mut level = sorter.pop_all(); + if level.is_empty() { + break; + } + level.sort_by(|a, b| a.1.to_string().cmp(&b.1.to_string())); + sorted.extend(level); + } - a.get_name() == b.get_name() - || (a.get_ref_table() == b.get_ref_table() - && a.get_columns() == b.get_columns() - && a.get_ref_columns() == b.get_ref_columns()) + // Append any leftovers (circular deps) + for tbl in tables { + let name = get_table_name(tbl.get_table_name()); + if !sorted.contains(&name) { + sorted.push(name); + } + } + sorted } diff --git a/src/schema/discover/changes.rs b/src/schema/discover/changes.rs new file mode 100644 index 0000000000..688c8784df --- /dev/null +++ b/src/schema/discover/changes.rs @@ -0,0 +1,237 @@ +use crate::Statement; +use sea_query::{ColumnType, TableName, TableRef}; + +/// Unique identifier for a recorded schema change. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ChangeId(pub usize); + +// ── Table-level changes ────────────────────────────────────────────────── + +/// A table-level change detected during schema discovery. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct TableChange { + pub id: ChangeId, + pub kind: TableChangeKind, +} + +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub enum TableChangeKind { + /// Table exists in entities but not in the database. + /// Carries the pre-built CREATE TABLE statement. + Create { table: String, stmt: Statement }, + /// Table exists in the database but not in entities. + Drop { table: TableName }, +} + +// ── Column-level changes ───────────────────────────────────────────────── + +/// A column-level change detected during schema discovery. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct ColumnChange { + pub id: ChangeId, + pub table: String, + pub kind: ColumnChangeKind, +} + +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub enum ColumnChangeKind { + /// Column exists in entity but not in the database. + Add { + column: String, + /// Position index in entity's column list. + index: usize, + column_type: Option, + is_not_null: bool, + has_default: bool, + /// Pre-built ALTER TABLE ADD COLUMN statement. + stmt: Statement, + }, + /// Column exists in the database but not in the entity. + Drop { + column: String, + /// Position index in DB table's column list. + index: usize, + column_type: Option, + /// Pre-built ALTER TABLE DROP COLUMN statement. + stmt: Statement, + }, + /// An explicit rename annotation was found on a column. + /// Carries the pre-built ALTER TABLE RENAME COLUMN statement. + ExplicitRename { + from: String, + to: String, + stmt: Statement, + }, + /// A column has a CHECK constraint that cannot be automatically diffed. + CheckConstraintPresent { column: String }, +} + +// ── Constraint-level changes ───────────────────────────────────────────── + +/// A constraint/index-level change detected during schema discovery. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct ConstraintChange { + pub id: ChangeId, + pub table: String, + pub kind: ConstraintChangeKind, +} + +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub enum ConstraintChangeKind { + /// Pre-built ADD FOREIGN KEY statement. + AddForeignKey { stmt: Statement }, + /// Pre-built ALTER TABLE DROP FOREIGN KEY statement. + DropForeignKey { name: String, stmt: Statement }, + /// Pre-built CREATE INDEX statement. + AddIndex { stmt: Statement }, + /// Pre-built CREATE UNIQUE INDEX statement. + AddUniqueConstraint { column: String, stmt: Statement }, + /// Pre-built DROP INDEX / DROP CONSTRAINT statement. + DropUniqueConstraint { name: String, stmt: Statement }, +} + +// ── Enum-level changes ─────────────────────────────────────────────────── + +/// An enum-level change detected during schema discovery. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct EnumChange { + pub id: ChangeId, + pub kind: EnumChangeKind, +} + +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub enum EnumChangeKind { + /// Enum type exists in entities but not in the database. + /// Carries the pre-built CREATE TYPE statement. + Create { stmt: Statement }, + /// Same enum name, different variants. + VariantChange { + name: String, + existing_sql: String, + new_sql: String, + }, + /// Same variants, different name — enum was renamed. + Rename { + existing_name: String, + new_name: String, + }, + /// Enum type exists in the database but not in any registered entity. + /// Carries the pre-built DROP TYPE statement. + Drop { name: String, stmt: Statement }, +} + +// ── Grouped change set ─────────────────────────────────────────────────── + +/// All recorded changes from Phase 1, grouped by origin. +/// +/// Each group contains changes classified by their source: tables, columns, +/// constraints, and enums. Changes that require SQL carry pre-built +/// [`Statement`]s so that Phase 2 interpretation does not need access to +/// the original entity definitions. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone, Default)] +pub struct ChangeSet { + /// Table-level changes: creates and drops of entire tables. + pub tables: Vec, + /// Column-level changes: adds, drops, explicit renames, and CHECK constraint flags. + pub columns: Vec, + /// Constraint/index-level changes: foreign keys, indexes, and unique constraints. + pub constraints: Vec, + /// Enum type changes: creates, variant diffs, and renames. + pub enums: Vec, + /// SQL strings of already-recorded enum CREATE statements, used for deduplication. + created_enum_sqls: Vec, + /// Internal counter for generating unique [`ChangeId`]s. + next_id: usize, +} + +impl ChangeSet { + fn next_id(&mut self) -> ChangeId { + let id = ChangeId(self.next_id); + self.next_id += 1; + id + } + + pub fn record_table(&mut self, kind: TableChangeKind) -> ChangeId { + let id = self.next_id(); + self.tables.push(TableChange { id, kind }); + id + } + + pub fn record_column(&mut self, table: String, kind: ColumnChangeKind) -> ChangeId { + let id = self.next_id(); + self.columns.push(ColumnChange { id, table, kind }); + id + } + + pub fn record_constraint(&mut self, table: String, kind: ConstraintChangeKind) -> ChangeId { + let id = self.next_id(); + self.constraints.push(ConstraintChange { id, table, kind }); + id + } + + /// Record a new enum CREATE, deduplicating by SQL string. + /// Returns `Some(ChangeId)` if recorded, `None` if already seen. + pub fn record_enum_create(&mut self, sql: &str, stmt: Statement) -> Option { + if self.created_enum_sqls.iter().any(|s| s == sql) { + return None; + } + self.created_enum_sqls.push(sql.to_owned()); + Some(self.record_enum(EnumChangeKind::Create { stmt })) + } + + pub fn record_enum(&mut self, kind: EnumChangeKind) -> ChangeId { + let id = self.next_id(); + self.enums.push(EnumChange { id, kind }); + id + } + + /// Collect all pre-built statements from every group, in recording order. + /// Useful for simple apply-all scenarios like `sync()`. + pub fn statements(self) -> Vec { + let mut stmts = Vec::new(); + + for ec in self.enums { + match ec.kind { + EnumChangeKind::Create { stmt } => stmts.push(stmt), + EnumChangeKind::Drop { .. } + | EnumChangeKind::VariantChange { .. } + | EnumChangeKind::Rename { .. } => {} // sync never drops or modifies enums + } + } + for tc in self.tables { + match tc.kind { + TableChangeKind::Create { stmt, .. } => stmts.push(stmt), + TableChangeKind::Drop { .. } => {} // sync never drops + } + } + for cc in self.columns { + match cc.kind { + ColumnChangeKind::Add { stmt, .. } + | ColumnChangeKind::ExplicitRename { stmt, .. } => stmts.push(stmt), + ColumnChangeKind::Drop { .. } | ColumnChangeKind::CheckConstraintPresent { .. } => { + } + } + } + for cc in self.constraints { + match cc.kind { + ConstraintChangeKind::AddForeignKey { stmt } + | ConstraintChangeKind::AddIndex { stmt } + | ConstraintChangeKind::AddUniqueConstraint { stmt, .. } => stmts.push(stmt), + ConstraintChangeKind::DropForeignKey { .. } + | ConstraintChangeKind::DropUniqueConstraint { .. } => {} + } + } + + stmts + } +} diff --git a/src/schema/discover/enum_.rs b/src/schema/discover/enum_.rs new file mode 100644 index 0000000000..39f9b9e486 --- /dev/null +++ b/src/schema/discover/enum_.rs @@ -0,0 +1,92 @@ +use super::changes::{ChangeSet, EnumChangeKind}; +use super::resolver::{self, extract_enum_type_name}; +use super::schema::DiscoveredSchema; +use crate::DbBackend; +use sea_query::extension::postgres::TypeCreateStatement; + +/// Phase 1: Record enum types in the database that have no matching entity (allow_dangerous only). +/// Records a DROP TYPE for each orphan enum. +pub(crate) fn record_orphan_enums( + all_entity_enums: &[&TypeCreateStatement], + db_backend: DbBackend, + existing: &[TypeCreateStatement], + changes: &mut ChangeSet, +) { + for existing_enum in existing { + let existing_stmt = db_backend.build(existing_enum); + let Some(existing_name) = extract_enum_type_name(&existing_stmt.sql) else { + continue; + }; + let in_entities = all_entity_enums.iter().any(|e| { + let s = db_backend.build(*e); + extract_enum_type_name(&s.sql).as_deref() == Some(existing_name.as_str()) + }); + if !in_entities { + let stmt = db_backend.build( + &sea_query::extension::postgres::Type::drop() + .name(sea_query::Alias::new(existing_name.as_str())) + .if_exists() + .to_owned(), + ); + changes.record_enum(EnumChangeKind::Drop { + name: existing_name, + stmt, + }); + } + } +} + +/// Phase 1: Record enum changes for a single entity's enum definitions against the existing schema. +pub(crate) fn record_enum_changes( + entity_enums: &[TypeCreateStatement], + db_backend: DbBackend, + existing: &[TypeCreateStatement], + changes: &mut ChangeSet, +) { + for stmt in entity_enums.iter() { + let new_stmt = db_backend.build(stmt); + let new_sql = &new_stmt.sql; + + let mut exact_match = false; + let mut change_detected = false; + + for existing_enum in existing { + let existing_stmt = db_backend.build(existing_enum); + if existing_stmt == new_stmt { + exact_match = true; + break; + } + if let Some(enum_change) = resolver::detect_enum_change(&existing_stmt.sql, new_sql) { + change_detected = true; + match enum_change { + resolver::EnumChange::VariantChange { + name, + existing_sql, + new_sql, + } => { + changes.record_enum(EnumChangeKind::VariantChange { + name, + existing_sql, + new_sql, + }); + } + resolver::EnumChange::NameChange { + existing_name, + new_name, + } => { + changes.record_enum(EnumChangeKind::Rename { + existing_name, + new_name, + }); + } + } + break; + } + } + + if !exact_match && !change_detected { + let sql = new_sql.clone(); + changes.record_enum_create(&sql, new_stmt); + } + } +} diff --git a/src/schema/discover/interpret.rs b/src/schema/discover/interpret.rs new file mode 100644 index 0000000000..d9806b4a45 --- /dev/null +++ b/src/schema/discover/interpret.rs @@ -0,0 +1,510 @@ +//! Phase 2: Interpret recorded schema changes into SQL statements, warnings, and suggestions. +//! +//! The main entry point is [`interpret`], which takes a [`ChangeSet`] from Phase 1 +//! and produces an [`InterpretResult`] containing SQL statements, warnings, +//! suggestions, and unresolved ambiguous renames. + +use std::collections::{HashMap, HashSet}; + +use sea_query::TableAlterStatement; + +use super::changes::{ + ChangeId, ChangeSet, ColumnChange, ColumnChangeKind, ConstraintChange, ConstraintChangeKind, + EnumChange, EnumChangeKind, TableChange, TableChangeKind, +}; +use super::resolver::{self, AddedColumn, RemovedColumn}; +use super::suggestion::{DiscoverSuggestion, SuggestionKind}; +use super::warning::{DiscoverWarning, WarningKind}; +use crate::{DbBackend, Statement}; + +/// Result of interpreting recorded schema changes (Phase 2). +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Default)] +pub struct InterpretResult { + /// SQL statements needed to bring the database in sync with entity definitions. + /// Each entry is paired with the [`ChangeId`] it was generated from. + pub statements: Vec<(ChangeId, Statement)>, + /// Always-on warnings about changes requiring manual attention (e.g. data migration). + pub warnings: Vec, + /// Heuristic-powered suggested fixes (renames, enum changes). + pub suggestions: Vec, + /// Ambiguous renames that need user input to resolve. + pub unresolved: Vec, +} + +impl InterpretResult { + /// Get just the SQL statements (without change IDs). + pub fn sql_statements(&self) -> Vec<&Statement> { + self.statements.iter().map(|(_, s)| s).collect() + } +} + +/// A decision made about an ambiguous rename. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub enum RenameDecision { + /// The user confirmed this is a rename. + Rename { + /// The old (removed) column name. + from: String, + /// The new (added) column name. + to: String, + }, + /// The user said this is not a rename — DROP + ADD. + DropAndAdd { + /// The removed column name. + removed: String, + /// The added column names that were candidates. + added: Vec, + }, +} + +impl InterpretResult { + /// Apply user decisions for ambiguous renames. + pub fn apply_rename_decisions(&mut self, decisions: &[RenameDecision], db_backend: DbBackend) { + for decision in decisions { + match decision { + RenameDecision::Rename { from, to } => { + if let Some(ambiguous) = self + .unresolved + .iter() + .find(|a| a.removed == *from && a.candidates.iter().any(|c| c.added == *to)) + { + let table_name = &ambiguous.table; + let id = ChangeId(usize::MAX); + self.statements.push(( + id, + db_backend.build( + TableAlterStatement::new() + .table(sea_query::Alias::new(table_name.as_str())) + .rename_column(from.clone(), to.clone()), + ), + )); + } + } + RenameDecision::DropAndAdd { removed, .. } => { + if let Some(ambiguous) = self.unresolved.iter().find(|a| a.removed == *removed) + { + let table_name = &ambiguous.table; + let id = ChangeId(usize::MAX); + self.statements.push(( + id, + db_backend.build( + TableAlterStatement::new() + .table(sea_query::Alias::new(table_name.as_str())) + .drop_column(sea_query::Alias::new(removed.as_str())), + ), + )); + } + } + } + } + self.unresolved.clear(); + } +} + +/// Configures how change interpretation is performed. +#[derive(Debug)] +pub struct InterpretConfig { + /// The database backend to use for building SQL statements (for renames resolved at interpret time). + pub db_backend: DbBackend, + /// Whether to auto-apply heuristic renames as SQL changes. + pub assumptions: bool, + /// Whether dangerous operations (drops) are allowed. + pub allow_dangerous: bool, +} + +/// Phase 2: Interpret recorded changes into SQL statements, warnings, and suggestions. +/// +/// Operates only on the [`ChangeSet`] from Phase 1. Changes carry pre-built +/// [`Statement`]s; interpretation decides which to emit and generates warnings/suggestions. +pub fn interpret(change_set: ChangeSet, config: &InterpretConfig) -> InterpretResult { + let mut statements: Vec<(ChangeId, Statement)> = Vec::new(); + let mut warnings: Vec = Vec::new(); + let mut suggestions: Vec = Vec::new(); + let mut unresolved: Vec = Vec::new(); + + // Ordered to satisfy FK / type constraints: + // 1. CREATE TYPE — enum types must exist before tables that reference them + // 2. CREATE TABLE — parents before children (ChangeSet records in sorted_tables order) + // 3. ADD COLUMN + // 4. ADD FK / ADD INDEX / ADD UNIQUE + // 5. DROP FK / DROP UNIQUE + // 6. DROP COLUMN + // 7. DROP TABLE — children before parents (ChangeSet records via sorted_table_drops) + // 8. DROP TYPE — after tables that referenced the type are gone + interpret_enum_creates(&change_set.enums, &mut statements); + interpret_table_creates(&change_set.tables, &mut statements); + interpret_column_adds( + &change_set.columns, + config, + &mut statements, + &mut warnings, + &mut suggestions, + &mut unresolved, + ); + interpret_constraint_adds(&change_set.constraints, &mut statements); + interpret_constraint_drops(&change_set.constraints, &mut statements); + interpret_column_drops(&change_set.columns, config, &mut statements); + interpret_table_drops(&change_set.tables, config, &mut statements); + interpret_enum_drops(&change_set.enums, config, &mut statements, &mut suggestions); + + InterpretResult { + statements, + warnings, + suggestions, + unresolved, + } +} + +/// Emit CREATE TABLE statements (parents before children via ChangeSet recording order). +fn interpret_table_creates(tables: &[TableChange], statements: &mut Vec<(ChangeId, Statement)>) { + for tc in tables { + if let TableChangeKind::Create { stmt, .. } = &tc.kind { + statements.push((tc.id, stmt.clone())); + } + } +} + +/// Emit DROP TABLE statements (children before parents via ChangeSet recording order). +fn interpret_table_drops( + tables: &[TableChange], + config: &InterpretConfig, + statements: &mut Vec<(ChangeId, Statement)>, +) { + for tc in tables { + if let TableChangeKind::Drop { table } = &tc.kind { + statements.push(( + tc.id, + config.db_backend.build( + sea_query::Table::drop() + .table(sea_query::Alias::new(table.1.to_string())) + .if_exists(), + ), + )); + } + } +} + +/// Emit ADD COLUMN and RENAME COLUMN statements. +/// Also populates warnings, suggestions, and unresolved renames. +/// Drop statements are collected separately by `interpret_column_drops`. +fn interpret_column_adds( + columns: &[ColumnChange], + config: &InterpretConfig, + statements: &mut Vec<(ChangeId, Statement)>, + warnings: &mut Vec, + suggestions: &mut Vec, + unresolved: &mut Vec, +) { + let mut drop_stmts: Vec<(ChangeId, Statement)> = Vec::new(); + interpret_columns_inner( + columns, + config, + statements, + &mut drop_stmts, + warnings, + suggestions, + unresolved, + ); + // drop_stmts are discarded here; they will be emitted by interpret_column_drops +} + +/// Emit DROP COLUMN statements (after FK drops, before table drops). +fn interpret_column_drops( + columns: &[ColumnChange], + config: &InterpretConfig, + statements: &mut Vec<(ChangeId, Statement)>, +) { + let mut add_stmts: Vec<(ChangeId, Statement)> = Vec::new(); + let mut drop_stmts: Vec<(ChangeId, Statement)> = Vec::new(); + let mut warnings = Vec::new(); + let mut suggestions = Vec::new(); + let mut unresolved = Vec::new(); + interpret_columns_inner( + columns, + config, + &mut add_stmts, + &mut drop_stmts, + &mut warnings, + &mut suggestions, + &mut unresolved, + ); + statements.extend(drop_stmts); +} + +/// Core column interpretation: runs rename detection and separates ADD/RENAME from DROP outputs. +fn interpret_columns_inner( + columns: &[ColumnChange], + config: &InterpretConfig, + add_stmts: &mut Vec<(ChangeId, Statement)>, + drop_stmts: &mut Vec<(ChangeId, Statement)>, + warnings: &mut Vec, + suggestions: &mut Vec, + unresolved: &mut Vec, +) { + let mut table_added: HashMap> = + Default::default(); + let mut table_removed: HashMap> = + Default::default(); + + for cc in columns { + match &cc.kind { + ColumnChangeKind::Add { + column, + index, + column_type, + is_not_null, + has_default, + stmt, + } => { + if *is_not_null && !has_default { + warnings.push(DiscoverWarning { + kind: WarningKind::NotNullNoDefault, + message: format!( + "Column '{}.{column}' is NOT NULL with no default value. \ + Existing rows will need data populated before or during this migration.", + cc.table, + ), + related_changes: vec![cc.id], + }); + } + table_added.entry(cc.table.clone()).or_default().push(( + cc.id, + AddedColumn { + index: *index, + name: column.clone(), + column_type: column_type.clone(), + }, + stmt.clone(), + )); + } + ColumnChangeKind::Drop { + column, + index, + column_type, + stmt, + } => { + table_removed.entry(cc.table.clone()).or_default().push(( + cc.id, + RemovedColumn { + index: *index, + name: column.clone(), + column_type: column_type.clone(), + }, + stmt.clone(), + )); + } + ColumnChangeKind::ExplicitRename { from, to, stmt } => { + if config.assumptions { + add_stmts.push((cc.id, stmt.clone())); + } else { + suggestions.push(DiscoverSuggestion { + kind: SuggestionKind::PossibleRename, + message: format!( + "Column '{}.{from}' has a `renamed_from` annotation to '{to}'. \ + Enable assumptions to auto-apply.", + cc.table, + ), + related_changes: vec![cc.id], + }); + } + } + ColumnChangeKind::CheckConstraintPresent { column } => { + warnings.push(DiscoverWarning { + kind: WarningKind::CheckConstraintDiff, + message: format!( + "Column '{}.{column}' has a CHECK constraint in entity definition. \ + CHECK constraints cannot be automatically diffed — verify manually.", + cc.table, + ), + related_changes: vec![cc.id], + }); + } + } + } + + // Rename detection per table + let all_tables: HashSet = table_added + .keys() + .chain(table_removed.keys()) + .cloned() + .collect(); + + for table in &all_tables { + let added = table_added.remove(table.as_str()).unwrap_or_default(); + let removed = table_removed.remove(table.as_str()).unwrap_or_default(); + + if !config.allow_dangerous || (added.is_empty() && removed.is_empty()) { + for (id, _, stmt) in &added { + add_stmts.push((*id, stmt.clone())); + } + continue; + } + + let added_ids: HashMap = added + .iter() + .map(|(id, c, _)| (c.name.clone(), *id)) + .collect(); + let removed_ids: HashMap = removed + .iter() + .map(|(id, c, _)| (c.name.clone(), *id)) + .collect(); + let added_stmts: HashMap = added + .iter() + .map(|(_, c, s)| (c.name.clone(), s.clone())) + .collect(); + let removed_stmts: HashMap = removed + .iter() + .map(|(_, c, s)| (c.name.clone(), s.clone())) + .collect(); + + let resolver_added: Vec = added.into_iter().map(|(_, c, _)| c).collect(); + let resolver_removed: Vec = removed.into_iter().map(|(_, c, _)| c).collect(); + + let resolution = resolver::resolve_renames(table, resolver_added, resolver_removed); + + // Assumed renames + for rename in &resolution.assumed { + let add_id = added_ids[&rename.added]; + let drop_id = removed_ids[&rename.removed]; + + if config.assumptions { + add_stmts.push(( + add_id, + config.db_backend.build( + TableAlterStatement::new() + .table(sea_query::Alias::new(table.as_str())) + .rename_column(rename.removed.clone(), rename.added.clone()), + ), + )); + suggestions.push(DiscoverSuggestion { + kind: SuggestionKind::PossibleRename, + message: format!( + "Column '{table}.{}' was auto-renamed to '{}' \ + (same type, position proximity {}). Use `--rename` to override.", + rename.removed, rename.added, rename.proximity, + ), + related_changes: vec![add_id, drop_id], + }); + } else { + suggestions.push(DiscoverSuggestion { + kind: SuggestionKind::PossibleRename, + message: format!( + "Column '{table}.{}' may have been renamed to '{}' \ + (same type, position proximity {}). Enable assumptions or use `--rename` to apply.", + rename.removed, rename.added, rename.proximity, + ), + related_changes: vec![add_id, drop_id], + }); + add_stmts.push((add_id, added_stmts[&rename.added].clone())); + drop_stmts.push((drop_id, removed_stmts[&rename.removed].clone())); + } + } + + unresolved.extend(resolution.ambiguous); + + for add in &resolution.remaining_added { + let id = added_ids[&add.name]; + add_stmts.push((id, added_stmts[&add.name].clone())); + } + + for rem in &resolution.remaining_removed { + let id = removed_ids[&rem.name]; + drop_stmts.push((id, removed_stmts[&rem.name].clone())); + } + } +} + +/// Emit ADD FOREIGN KEY, ADD INDEX, ADD UNIQUE CONSTRAINT statements. +fn interpret_constraint_adds( + constraints: &[ConstraintChange], + statements: &mut Vec<(ChangeId, Statement)>, +) { + for cc in constraints { + match &cc.kind { + ConstraintChangeKind::AddForeignKey { stmt } + | ConstraintChangeKind::AddIndex { stmt } + | ConstraintChangeKind::AddUniqueConstraint { stmt, .. } => { + statements.push((cc.id, stmt.clone())); + } + ConstraintChangeKind::DropForeignKey { .. } + | ConstraintChangeKind::DropUniqueConstraint { .. } => {} + } + } +} + +/// Emit DROP FOREIGN KEY and DROP UNIQUE CONSTRAINT statements (before column/table drops). +fn interpret_constraint_drops( + constraints: &[ConstraintChange], + statements: &mut Vec<(ChangeId, Statement)>, +) { + for cc in constraints { + match &cc.kind { + ConstraintChangeKind::DropForeignKey { stmt, .. } + | ConstraintChangeKind::DropUniqueConstraint { stmt, .. } => { + statements.push((cc.id, stmt.clone())); + } + ConstraintChangeKind::AddForeignKey { .. } + | ConstraintChangeKind::AddIndex { .. } + | ConstraintChangeKind::AddUniqueConstraint { .. } => {} + } + } +} + +/// Emit CREATE TYPE statements and variant-change/rename suggestions. +fn interpret_enum_creates(enums: &[EnumChange], statements: &mut Vec<(ChangeId, Statement)>) { + for ec in enums { + if let EnumChangeKind::Create { stmt } = &ec.kind { + statements.push((ec.id, stmt.clone())); + } + } +} + +/// Emit variant-change / rename suggestions, and DROP TYPE when allow_dangerous. +/// Must run after table drops so the enum is no longer referenced. +fn interpret_enum_drops( + enums: &[EnumChange], + config: &InterpretConfig, + statements: &mut Vec<(ChangeId, Statement)>, + suggestions: &mut Vec, +) { + for ec in enums { + match &ec.kind { + EnumChangeKind::VariantChange { name, .. } => { + if config.allow_dangerous { + suggestions.push(DiscoverSuggestion { + kind: SuggestionKind::EnumVariantChange, + message: format!( + "Enum type '{name}' has changed variants. Adding variants requires \ + `ALTER TYPE ... ADD VALUE`; removing variants requires type recreation. \ + This migration must be written manually.", + ), + related_changes: vec![ec.id], + }); + } + } + EnumChangeKind::Rename { + existing_name, + new_name, + } => { + if config.allow_dangerous { + suggestions.push(DiscoverSuggestion { + kind: SuggestionKind::EnumRename, + message: format!( + "Enum type '{existing_name}' appears to have been renamed to '{new_name}'. \ + This requires `ALTER TYPE ... RENAME TO`.", + ), + related_changes: vec![ec.id], + }); + } + } + EnumChangeKind::Drop { stmt, .. } => { + if config.allow_dangerous { + statements.push((ec.id, stmt.clone())); + } + } + EnumChangeKind::Create { .. } => {} + } + } +} diff --git a/src/schema/discover/mod.rs b/src/schema/discover/mod.rs new file mode 100644 index 0000000000..077f3818bb --- /dev/null +++ b/src/schema/discover/mod.rs @@ -0,0 +1,80 @@ +pub mod changes; +mod enum_; +pub mod interpret; +pub mod resolver; +pub(crate) mod schema; +pub mod suggestion; +mod table; +pub mod warning; + +use crate::schema::builder::{EntitySchemaInfo, TableSortOrder, get_table_name}; +use crate::{ConnectionTrait, DbErr, sorted_tables}; +use changes::ChangeSet; + +pub use changes::ChangeId as SchemaChangeId; +pub use interpret::{InterpretConfig, InterpretResult, RenameDecision}; +use sea_query::TableCreateStatement; +pub use suggestion::{DiscoverSuggestion, SuggestionKind}; +pub use warning::{DiscoverWarning, WarningKind}; + +//TODO: honestly, I think whole scheam module should be moved to a separate crate + +/// Record all schema changes by comparing entities against the database +pub(crate) async fn discover( + new_entities: &[EntitySchemaInfo], + db: &C, + allow_dangerous: bool, + excluded_tables: &[String], +) -> Result +where + C: ConnectionTrait + sea_schema::Connection, +{ + let existing = schema::discover_existing_schema(db).await?; + let db_backend = db.get_database_backend(); + + let mut change_set = ChangeSet::default(); + + let tabl_ref: Vec<&TableCreateStatement> = new_entities.iter().map(|e| e.table()).collect(); + for table_name in sorted_tables(&tabl_ref, TableSortOrder::ParentsFirst) { + let name_str = table_name.1.to_string(); + if excluded_tables.iter().any(|e| e == &name_str) { + continue; + } + + //PERF: just sort TableCreateStatements, instead of searching + if let Some(entity) = new_entities + .iter() + .find(|entity| table_name == get_table_name(entity.table().get_table_name())) + { + enum_::record_enum_changes( + entity.enums(), + db_backend, + &existing.enums, + &mut change_set, + ); + table::record_table_changes( + entity, + &existing.tables, + &mut change_set, + allow_dangerous, + db_backend, + ); + } else { + unreachable!() + } + } + + if allow_dangerous { + table::record_orphan_tables(new_entities, &existing, &mut change_set, excluded_tables); + let all_entity_enums: Vec<&sea_query::extension::postgres::TypeCreateStatement> = + new_entities.iter().flat_map(|e| e.enums().iter()).collect(); + enum_::record_orphan_enums( + &all_entity_enums, + db_backend, + &existing.enums, + &mut change_set, + ); + } + + Ok(change_set) +} diff --git a/src/schema/discover/resolver.rs b/src/schema/discover/resolver.rs new file mode 100644 index 0000000000..54a7cfa10e --- /dev/null +++ b/src/schema/discover/resolver.rs @@ -0,0 +1,618 @@ +//! Heuristic rename detection and enum change resolution for schema discovery. +//! +//! This module contains pure functions that take raw added/removed column lists +//! and produce structured rename decisions. No I/O or user interaction happens here. + +use sea_query::ColumnType; + +/// A column that exists in the entity but not in the database. +#[derive(Debug, Clone)] +pub struct AddedColumn { + /// Position index in the entity's column list. + pub index: usize, + /// Column name. + pub name: String, + /// Column type (if available from the entity definition). + pub column_type: Option, +} + +/// A column that exists in the database but not in the entity. +#[derive(Debug, Clone)] +pub struct RemovedColumn { + /// Position index in the database table's column list. + pub index: usize, + /// Column name. + pub name: String, + /// Column type (if available from schema discovery). + pub column_type: Option, +} + +/// A single rename candidate pairing a removed column with an added column. +#[derive(Debug, Clone)] +pub struct RenameCandidate { + /// The name of the removed (old) column. + pub removed: String, + /// The name of the added (new) column. + pub added: String, + /// Positional distance between the two columns. + pub proximity: usize, +} + +/// An ambiguous rename where multiple candidates exist for a removed column. +#[derive(Debug, Clone)] +pub struct AmbiguousRename { + /// The table this rename occurs in. + pub table: String, + /// The name of the removed column. + pub removed: String, + /// All possible added columns it could be renamed to. + pub candidates: Vec, +} + +/// The result of rename resolution. +#[derive(Debug, Clone, Default)] +pub struct RenameResolution { + /// Obvious renames (1:1 mapping, same type, close proximity) — auto-decided. + pub assumed: Vec, + /// Ambiguous renames (multiple candidates) — need user input. + pub ambiguous: Vec, + /// Genuinely new columns (no rename match). + pub remaining_added: Vec, + /// Genuinely removed columns (no rename match). + pub remaining_removed: Vec, +} + +/// The kind of enum change detected between existing and new definitions. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EnumChange { + /// Same enum name, different variants. + VariantChange { + /// The enum type name. + name: String, + /// The existing CREATE TYPE SQL. + existing_sql: String, + /// The new CREATE TYPE SQL. + new_sql: String, + }, + /// Different enum name with same variants (enum was renamed). + NameChange { + /// The existing enum type name. + existing_name: String, + /// The new enum type name. + new_name: String, + }, +} + +/// Check if two column types are compatible for rename detection. +/// Treats String variants (String, Text) as equivalent. +pub fn types_compatible(a: Option<&ColumnType>, b: Option<&ColumnType>) -> bool { + match (a, b) { + (Some(a), Some(b)) => { + if a == b { + return true; + } + // Treat all String/Text variants as compatible + matches!( + (a, b), + ( + ColumnType::String(_) | ColumnType::Text, + ColumnType::String(_) | ColumnType::Text, + ) + ) + } + _ => false, + } +} + +/// Resolve renames from lists of added and removed columns. +/// +/// For each removed column, find added columns with compatible types and proximity ≤ 2. +/// - If exactly one match for both sides (1:1, neither claimed elsewhere) → assumed rename. +/// - If multiple candidates → ambiguous rename. +/// - Unmatched columns go to remaining_added / remaining_removed. +pub fn resolve_renames( + table: &str, + added: Vec, + removed: Vec, +) -> RenameResolution { + let mut resolution = RenameResolution::default(); + + // For each removed column, collect all compatible added candidates within proximity + let mut removed_candidates: Vec<(usize, Vec<(usize, RenameCandidate)>)> = Vec::new(); + + for (ri, rem) in removed.iter().enumerate() { + let mut candidates = Vec::new(); + for (ai, add) in added.iter().enumerate() { + let proximity = (rem.index as isize - add.index as isize).unsigned_abs(); + if proximity <= 2 + && types_compatible(rem.column_type.as_ref(), add.column_type.as_ref()) + { + candidates.push(( + ai, + RenameCandidate { + removed: rem.name.clone(), + added: add.name.clone(), + proximity, + }, + )); + } + } + removed_candidates.push((ri, candidates)); + } + + // First pass: identify obvious 1:1 renames + let mut claimed_added: Vec = Vec::new(); + let mut claimed_removed: Vec = Vec::new(); + + // Sort by number of candidates (fewest first) to greedily resolve unambiguous ones + let mut sorted_by_candidates: Vec<_> = removed_candidates.iter().collect(); + sorted_by_candidates.sort_by_key(|(_, cands)| cands.len()); + + for (ri, candidates) in &sorted_by_candidates { + if claimed_removed.contains(ri) { + continue; + } + // Filter out already-claimed added columns + let available: Vec<_> = candidates + .iter() + .filter(|(ai, _)| !claimed_added.contains(ai)) + .collect(); + + if available.len() == 1 { + // Check the reverse: is this added column also only matched by one removed column? + let ai = available[0].0; + let reverse_count = removed_candidates + .iter() + .filter(|(other_ri, other_cands)| { + !claimed_removed.contains(other_ri) + && other_cands.iter().any(|(other_ai, _)| { + *other_ai == ai && !claimed_added.contains(other_ai) + }) + }) + .count(); + + if reverse_count == 1 { + // Unique 1:1 mapping → assumed rename + resolution.assumed.push(available[0].1.clone()); + claimed_added.push(ai); + claimed_removed.push(*ri); + } + } + } + + // Second pass: collect ambiguous renames from unclaimed removed columns with candidates + for (ri, candidates) in &removed_candidates { + if claimed_removed.contains(ri) { + continue; + } + let available: Vec<_> = candidates + .iter() + .filter(|(ai, _)| !claimed_added.contains(ai)) + .map(|(_, c)| c.clone()) + .collect(); + + if available.len() > 1 { + resolution.ambiguous.push(AmbiguousRename { + table: table.to_string(), + removed: removed[*ri].name.clone(), + candidates: available, + }); + claimed_removed.push(*ri); + // Don't claim the added columns — the user will decide + } + } + + // Remaining: unclaimed added and removed columns + for (ai, add) in added.iter().enumerate() { + if !claimed_added.contains(&ai) { + // Check if this added column is referenced in an ambiguous rename + let in_ambiguous = resolution + .ambiguous + .iter() + .any(|a| a.candidates.iter().any(|c| c.added == add.name)); + if !in_ambiguous { + resolution.remaining_added.push(add.clone()); + } + } + } + + for (ri, rem) in removed.iter().enumerate() { + if !claimed_removed.contains(&ri) { + resolution.remaining_removed.push(rem.clone()); + } + } + + resolution +} + +/// Detect enum changes between existing and new SQL definitions. +/// Compares two `CREATE TYPE ... AS ENUM (...)` SQL strings and returns +/// the kind of change detected, if any. +pub fn detect_enum_change(existing_sql: &str, new_sql: &str) -> Option { + let existing_name = extract_enum_type_name(existing_sql)?; + let new_name = extract_enum_type_name(new_sql)?; + + if existing_name == new_name && existing_sql != new_sql { + Some(EnumChange::VariantChange { + name: existing_name, + existing_sql: existing_sql.to_string(), + new_sql: new_sql.to_string(), + }) + } else if existing_name != new_name { + // Extract variants to check if they match + let existing_variants = extract_enum_variants(existing_sql); + let new_variants = extract_enum_variants(new_sql); + if existing_variants == new_variants && !existing_variants.is_empty() { + Some(EnumChange::NameChange { + existing_name, + new_name, + }) + } else { + None + } + } else { + None + } +} + +/// Extract the type name from a `CREATE TYPE "name" AS ENUM (...)` SQL string. +pub fn extract_enum_type_name(sql: &str) -> Option { + let upper = sql.to_uppercase(); + let crt_type_pos = upper.find("CREATE TYPE")?; + let as_enum_pos = upper.find("AS ENUM")?; + if as_enum_pos <= crt_type_pos { + return None; + } + + let between = sql[crt_type_pos + "CREATE TYPE".len()..as_enum_pos].trim(); + // Extract quoted or unquoted identifier + let name = if let Some(stripped) = between.strip_prefix('"') { + let end = stripped.find('"')?; + &stripped[..end] + } else { + between.split_whitespace().next()? + }; + Some(name.to_string()) +} + +/// Extract enum variant strings from a CREATE TYPE ... AS ENUM (...) SQL statement. +fn extract_enum_variants(sql: &str) -> Vec { + let upper = sql.to_uppercase(); + let Some(paren_start) = upper.find("AS ENUM") else { + return Vec::new(); + }; + let rest = &sql[paren_start..]; + let Some(open) = rest.find('(') else { + return Vec::new(); + }; + let Some(close) = rest.find(')') else { + return Vec::new(); + }; + let inner = &rest[open + 1..close]; + inner + .split(',') + .map(|s| s.trim().trim_matches('\'').trim_matches('"').to_string()) + .filter(|s| !s.is_empty()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn added(index: usize, name: &str, col_type: Option) -> AddedColumn { + AddedColumn { + index, + name: name.to_string(), + column_type: col_type, + } + } + + fn removed(index: usize, name: &str, col_type: Option) -> RemovedColumn { + RemovedColumn { + index, + name: name.to_string(), + column_type: col_type, + } + } + + #[test] + fn test_single_obvious_rename() { + let added_cols = vec![added( + 1, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + + assert_eq!(result.assumed.len(), 1); + assert_eq!(result.assumed[0].removed, "name"); + assert_eq!(result.assumed[0].added, "title"); + assert_eq!(result.assumed[0].proximity, 0); + assert!(result.ambiguous.is_empty()); + assert!(result.remaining_added.is_empty()); + assert!(result.remaining_removed.is_empty()); + } + + #[test] + fn test_no_rename_type_mismatch() { + let added_cols = vec![added(1, "count", Some(ColumnType::Integer))]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + + assert!(result.assumed.is_empty()); + assert!(result.ambiguous.is_empty()); + assert_eq!(result.remaining_added.len(), 1); + assert_eq!(result.remaining_removed.len(), 1); + } + + #[test] + fn test_no_rename_too_far() { + let added_cols = vec![added( + 5, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + + assert!(result.assumed.is_empty()); + assert!(result.ambiguous.is_empty()); + assert_eq!(result.remaining_added.len(), 1); + assert_eq!(result.remaining_removed.len(), 1); + } + + #[test] + fn test_ambiguous_multiple_candidates() { + // One removed column, two added columns with same type and close proximity + let added_cols = vec![ + added( + 1, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + ), + added( + 2, + "label", + Some(ColumnType::String(sea_query::StringLen::None)), + ), + ]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + + assert!(result.assumed.is_empty()); + assert_eq!(result.ambiguous.len(), 1); + assert_eq!(result.ambiguous[0].table, "cake"); + assert_eq!(result.ambiguous[0].removed, "name"); + assert_eq!(result.ambiguous[0].candidates.len(), 2); + } + + #[test] + fn test_multiple_independent_renames() { + // Two removed + two added, each pair is uniquely matched + let added_cols = vec![ + added( + 1, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + ), + added(3, "weight_kg", Some(ColumnType::Integer)), + ]; + let removed_cols = vec![ + removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + ), + removed(3, "weight", Some(ColumnType::Integer)), + ]; + + let result = resolve_renames("product", added_cols, removed_cols); + + assert_eq!(result.assumed.len(), 2); + assert!(result.ambiguous.is_empty()); + assert!(result.remaining_added.is_empty()); + assert!(result.remaining_removed.is_empty()); + + let names: Vec<_> = result.assumed.iter().map(|r| r.removed.as_str()).collect(); + assert!(names.contains(&"name")); + assert!(names.contains(&"weight")); + } + + #[test] + fn test_string_text_type_compatibility() { + let added_cols = vec![added(1, "description", Some(ColumnType::Text))]; + let removed_cols = vec![removed( + 1, + "desc", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("item", added_cols, removed_cols); + + assert_eq!(result.assumed.len(), 1); + assert_eq!(result.assumed[0].removed, "desc"); + assert_eq!(result.assumed[0].added, "description"); + } + + #[test] + fn test_proximity_at_boundary() { + // Proximity exactly 2 — should still match + let added_cols = vec![added( + 3, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + assert_eq!(result.assumed.len(), 1); + assert_eq!(result.assumed[0].proximity, 2); + + // Proximity 3 — should NOT match + let added_cols = vec![added( + 4, + "title", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + let removed_cols = vec![removed( + 1, + "name", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("cake", added_cols, removed_cols); + assert!(result.assumed.is_empty()); + } + + #[test] + fn test_no_columns_produces_empty_resolution() { + let result = resolve_renames("empty", vec![], vec![]); + assert!(result.assumed.is_empty()); + assert!(result.ambiguous.is_empty()); + assert!(result.remaining_added.is_empty()); + assert!(result.remaining_removed.is_empty()); + } + + #[test] + fn test_only_added_columns() { + let added_cols = vec![ + added( + 1, + "new_col", + Some(ColumnType::String(sea_query::StringLen::None)), + ), + added(2, "another", Some(ColumnType::Integer)), + ]; + + let result = resolve_renames("t", added_cols, vec![]); + assert!(result.assumed.is_empty()); + assert!(result.ambiguous.is_empty()); + assert_eq!(result.remaining_added.len(), 2); + assert!(result.remaining_removed.is_empty()); + } + + #[test] + fn test_only_removed_columns() { + let removed_cols = vec![removed( + 1, + "old_col", + Some(ColumnType::String(sea_query::StringLen::None)), + )]; + + let result = resolve_renames("t", vec![], removed_cols); + assert!(result.assumed.is_empty()); + assert!(result.ambiguous.is_empty()); + assert!(result.remaining_added.is_empty()); + assert_eq!(result.remaining_removed.len(), 1); + } + + #[test] + fn test_enum_variant_change() { + let existing = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad')"#; + let new = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad', 'neutral')"#; + + let change = detect_enum_change(existing, new); + assert!(change.is_some()); + match change.unwrap() { + EnumChange::VariantChange { name, .. } => { + assert_eq!(name, "mood"); + } + _ => panic!("expected VariantChange"), + } + } + + #[test] + fn test_enum_rename() { + let existing = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad')"#; + let new = r#"CREATE TYPE "feeling" AS ENUM ('happy', 'sad')"#; + + let change = detect_enum_change(existing, new); + assert!(change.is_some()); + match change.unwrap() { + EnumChange::NameChange { + existing_name, + new_name, + } => { + assert_eq!(existing_name, "mood"); + assert_eq!(new_name, "feeling"); + } + _ => panic!("expected NameChange"), + } + } + + #[test] + fn test_enum_no_change() { + let sql = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad')"#; + assert!(detect_enum_change(sql, sql).is_none()); + } + + #[test] + fn test_enum_completely_different() { + let existing = r#"CREATE TYPE "mood" AS ENUM ('happy', 'sad')"#; + let new = r#"CREATE TYPE "color" AS ENUM ('red', 'blue')"#; + + // Different name AND different variants — no match + assert!(detect_enum_change(existing, new).is_none()); + } + + #[test] + fn test_types_compatible_same() { + assert!(types_compatible( + Some(&ColumnType::Integer), + Some(&ColumnType::Integer) + )); + } + + #[test] + fn test_types_compatible_string_text() { + assert!(types_compatible( + Some(&ColumnType::String(sea_query::StringLen::None)), + Some(&ColumnType::Text) + )); + assert!(types_compatible( + Some(&ColumnType::Text), + Some(&ColumnType::String(sea_query::StringLen::N(255))) + )); + } + + #[test] + fn test_types_compatible_none() { + assert!(!types_compatible(None, Some(&ColumnType::Integer))); + assert!(!types_compatible(Some(&ColumnType::Integer), None)); + assert!(!types_compatible(None, None)); + } + + #[test] + fn test_types_incompatible() { + assert!(!types_compatible( + Some(&ColumnType::Integer), + Some(&ColumnType::String(sea_query::StringLen::None)) + )); + } +} diff --git a/src/schema/discover/schema.rs b/src/schema/discover/schema.rs new file mode 100644 index 0000000000..85cf2c699a --- /dev/null +++ b/src/schema/discover/schema.rs @@ -0,0 +1,109 @@ +#[allow(unused_imports)] +use crate::{ConnectionTrait, DbBackend, DbErr}; +use sea_query::{TableCreateStatement, extension::postgres::TypeCreateStatement}; + +/// Stores the discovered schema from the database, including tables and enums +#[derive(Default)] +pub(crate) struct DiscoveredSchema { + pub(crate) tables: Vec, + pub(crate) enums: Vec, +} + +pub(crate) async fn discover_existing_schema(db: &C) -> Result +where + C: ConnectionTrait + sea_schema::Connection, +{ + //TODO: discover ONLY existing schema + match db.get_database_backend() { + #[cfg(feature = "sqlx-mysql")] + DbBackend::MySql => { + use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe}; + + let current_schema: String = db + .query_one( + sea_query::SelectStatement::new() + .expr(sea_schema::mysql::MySql::get_current_schema()), + ) + .await? + .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? + .try_get_by_index(0)?; + let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); + + let schema = schema_discovery + .discover_with(db) + .await + .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; + + Ok(DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + }) + } + #[cfg(feature = "sqlx-postgres")] + DbBackend::Postgres => { + use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe}; + + let current_schema: String = db + .query_one( + sea_query::SelectStatement::new() + .expr(sea_schema::postgres::Postgres::get_current_schema()), + ) + .await? + .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))? + .try_get_by_index(0)?; + let schema_discovery = SchemaDiscovery::new_no_exec(¤t_schema); + + let schema = schema_discovery + .discover_with(db) + .await + .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?; + + Ok(DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: schema.enums.iter().map(|def| def.write()).collect(), + }) + } + #[cfg(feature = "sqlx-sqlite")] + DbBackend::Sqlite => { + use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; + let schema = SchemaDiscovery::discover_with(db) + .await + .map_err(|err| { + DbErr::Query(match err { + SqliteDiscoveryError::SqlxError(err) => { + crate::RuntimeErr::SqlxError(err.into()) + } + _ => crate::RuntimeErr::Internal(format!("{err:?}")), + }) + })? + .merge_indexes_into_table(); + Ok(DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + }) + } + #[cfg(feature = "rusqlite")] + DbBackend::Sqlite => { + use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery}; + let schema = SchemaDiscovery::discover_with(db) + .map_err(|err| { + DbErr::Query(match err { + SqliteDiscoveryError::RusqliteError(err) => { + crate::RuntimeErr::Rusqlite(err.into()) + } + _ => crate::RuntimeErr::Internal(format!("{err:?}")), + }) + })? + .merge_indexes_into_table(); + Ok(DiscoveredSchema { + tables: schema.tables.iter().map(|table| table.write()).collect(), + enums: vec![], + }) + } + #[allow(unreachable_patterns)] + other => Err(DbErr::BackendNotSupported { + db: other.as_str(), + ctx: "discover_existing_schema", + }), + } +} diff --git a/src/schema/discover/suggestion.rs b/src/schema/discover/suggestion.rs new file mode 100644 index 0000000000..73799cc45e --- /dev/null +++ b/src/schema/discover/suggestion.rs @@ -0,0 +1,35 @@ +//! Suggested fixes emitted during schema discovery. +//! +//! Suggestions are heuristic-powered proposals for renames and other changes +//! that the system can reasonably detect. They are only generated when the +//! `assumptions` flag is enabled. Suggestions reference the changes they +//! act upon via [`ChangeId`]. + +use super::changes::ChangeId; + +/// A suggested fix detected by heuristic analysis during schema discovery. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct DiscoverSuggestion { + /// The category of suggestion. + pub kind: SuggestionKind, + /// Human-readable description of the suggested change. + pub message: String, + /// IDs of the changes this suggestion relates to (e.g. the ADD + DROP that form a rename). + pub related_changes: Vec, +} + +/// Categories of schema discovery suggestions. +/// +/// Suggestions are only generated when `assumptions` is enabled. +/// They represent changes that the system can heuristically detect and auto-apply. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SuggestionKind { + /// A column was removed and another with the same type was added — obvious rename (auto-assumed). + PossibleRename, + /// An enum type exists in both DB and entities but with different variants. + EnumVariantChange, + /// An enum type appears to have been renamed (same variants, different name). + EnumRename, +} diff --git a/src/schema/discover/table.rs b/src/schema/discover/table.rs new file mode 100644 index 0000000000..077dd38abf --- /dev/null +++ b/src/schema/discover/table.rs @@ -0,0 +1,378 @@ +use super::changes::{ChangeSet, ColumnChangeKind, ConstraintChangeKind, TableChangeKind}; +use super::schema::DiscoveredSchema; +use crate::schema::builder::{EntitySchemaInfo, get_table_name}; +use crate::{DbBackend, TableSortOrder, sorted_tables}; +use sea_query::{ForeignKeyCreateStatement, Index, TableAlterStatement, TableCreateStatement}; + +/// Phase 1: Record table-level changes for a single entity against the existing schema. +pub(crate) fn record_table_changes( + entity: &EntitySchemaInfo, + existing: &[TableCreateStatement], + changes: &mut ChangeSet, + allow_dangerous: bool, + db_backend: DbBackend, +) { + let table_name = get_table_name(entity.table().get_table_name()); + let table_name_str = table_name.1.to_string(); + let existing_table = existing + .iter() + .find(|tbl| get_table_name(tbl.get_table_name()) == table_name); + + if let Some(existing_table) = existing_table { + record_column_changes( + entity, + existing_table, + &table_name_str, + changes, + allow_dangerous, + db_backend, + ); + record_foreign_key_changes( + entity, + existing_table, + &table_name_str, + changes, + allow_dangerous, + db_backend, + ); + record_index_changes(entity, existing_table, &table_name_str, changes, db_backend); + record_unique_constraint_changes( + entity, + existing_table, + &table_name_str, + changes, + db_backend, + ); + record_unique_constraint_drops( + entity, + existing_table, + &table_name_str, + changes, + db_backend, + ); + } else { + changes.record_table(TableChangeKind::Create { + table: table_name_str, + stmt: db_backend.build(entity.table()), + }); + } +} + +/// Phase 1: Record tables in the database that have no matching entity. +/// Drops are recorded in reverse-dependency order (children first) to avoid FK violations. +pub(crate) fn record_orphan_tables( + entities: &[EntitySchemaInfo], + existing: &DiscoveredSchema, + changes: &mut ChangeSet, + excluded_tables: &[String], +) { + let orphans: Vec<&TableCreateStatement> = existing + .tables + .iter() + .filter(|tbl| { + let name = get_table_name(tbl.get_table_name()); + let name_str = name.1.to_string(); + !excluded_tables.iter().any(|e| e == &name_str) + && !entities + .iter() + .any(|e| get_table_name(e.table().get_table_name()) == name) + }) + .collect(); + + for table_name in sorted_tables(&orphans, TableSortOrder::ChildrenFirst) { + changes.record_table(TableChangeKind::Drop { table: table_name }); + } +} + +fn get_entity_table_name(entity: &EntitySchemaInfo) -> sea_query::TableRef { + entity + .table() + .get_table_name() + .expect("table must have a name") + .clone() +} + +fn record_column_changes( + entity: &EntitySchemaInfo, + existing_table: &sea_query::TableCreateStatement, + table_name_str: &str, + changes: &mut ChangeSet, + allow_dangerous: bool, + db_backend: DbBackend, +) { + let entity_table_name = get_entity_table_name(entity); + + for (idx, column_def) in entity.table().get_columns().iter().enumerate() { + let col_name = column_def.get_column_name(); + let exists_in_db = existing_table + .get_columns() + .iter() + .any(|c| c.get_column_name() == col_name); + + if exists_in_db { + if column_def.get_column_spec().check.is_some() { + changes.record_column( + table_name_str.to_string(), + ColumnChangeKind::CheckConstraintPresent { + column: col_name.to_string(), + }, + ); + } + continue; + } + + // Check for explicit renamed_from annotation + let mut renamed_from = ""; + if let Some(comment) = &column_def.get_column_spec().comment { + if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") { + if let Some((prefix, _)) = suffix.split_once('"') { + renamed_from = prefix; + } + } + } + + if !renamed_from.is_empty() { + changes.record_column( + table_name_str.to_string(), + ColumnChangeKind::ExplicitRename { + from: renamed_from.to_string(), + to: col_name.to_string(), + stmt: db_backend.build( + TableAlterStatement::new() + .table(entity_table_name.clone()) + .rename_column(renamed_from.to_string(), col_name.to_string()), + ), + }, + ); + } else { + let spec = column_def.get_column_spec(); + let is_not_null = matches!(spec.nullable, Some(false)); + changes.record_column( + table_name_str.to_string(), + ColumnChangeKind::Add { + column: col_name.to_string(), + index: idx, + column_type: column_def.get_column_type().cloned(), + is_not_null, + has_default: spec.default.is_some(), + stmt: db_backend.build( + TableAlterStatement::new() + .table(entity_table_name.clone()) + .add_column(column_def.to_owned()), + ), + }, + ); + } + } + + // Removed columns (in DB but not in entity) + if allow_dangerous { + let entity_table_name = get_entity_table_name(entity); + for (idx, col) in existing_table.get_columns().iter().enumerate() { + let col_name = col.get_column_name(); + let in_entity = entity + .table() + .get_columns() + .iter() + .any(|ec| ec.get_column_name() == col_name); + if !in_entity { + changes.record_column( + table_name_str.to_string(), + ColumnChangeKind::Drop { + column: col_name.to_string(), + index: idx, + column_type: col.get_column_type().cloned(), + stmt: db_backend.build( + TableAlterStatement::new() + .table(entity_table_name.clone()) + .drop_column(sea_query::Alias::new(col_name)), + ), + }, + ); + } + } + } +} + +fn record_foreign_key_changes( + entity: &EntitySchemaInfo, + existing_table: &sea_query::TableCreateStatement, + table_name_str: &str, + changes: &mut ChangeSet, + allow_dangerous: bool, + db_backend: DbBackend, +) { + for foreign_key in entity.table().get_foreign_key_create_stmts().iter() { + let key_exists = existing_table + .get_foreign_key_create_stmts() + .iter() + .any(|existing_key| compare_foreign_key(foreign_key, existing_key)); + if !key_exists { + changes.record_constraint( + table_name_str.to_string(), + ConstraintChangeKind::AddForeignKey { + stmt: db_backend.build(foreign_key), + }, + ); + } + } + + if allow_dangerous { + let entity_table_name = get_entity_table_name(entity); + for existing_key in existing_table.get_foreign_key_create_stmts().iter() { + let in_entity = entity + .table() + .get_foreign_key_create_stmts() + .iter() + .any(|fk| compare_foreign_key(fk, existing_key)); + if !in_entity { + let fk = existing_key.get_foreign_key(); + if let Some(name) = fk.get_name() { + changes.record_constraint( + table_name_str.to_string(), + ConstraintChangeKind::DropForeignKey { + name: name.to_owned(), + stmt: db_backend.build( + TableAlterStatement::new() + .table(entity_table_name.clone()) + .drop_foreign_key(name.to_owned()), + ), + }, + ); + } + } + } + } +} + +fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool { + let a = a.get_foreign_key(); + let b = b.get_foreign_key(); + + a.get_name() == b.get_name() + || (a.get_ref_table() == b.get_ref_table() + && a.get_columns() == b.get_columns() + && a.get_ref_columns() == b.get_ref_columns()) +} + +fn record_index_changes( + entity: &EntitySchemaInfo, + existing_table: &sea_query::TableCreateStatement, + table_name_str: &str, + changes: &mut ChangeSet, + db_backend: DbBackend, +) { + for stmt in entity.indexes().iter() { + let has_index = existing_table.get_indexes().iter().any(|existing_index| { + existing_index.get_index_spec().get_column_names() + == stmt.get_index_spec().get_column_names() + }); + if !has_index { + let mut idx_stmt = stmt.clone(); + idx_stmt.if_not_exists(); + changes.record_constraint( + table_name_str.to_string(), + ConstraintChangeKind::AddIndex { + stmt: db_backend.build(&idx_stmt), + }, + ); + } + } +} + +fn record_unique_constraint_changes( + entity: &EntitySchemaInfo, + existing_table: &sea_query::TableCreateStatement, + table_name_str: &str, + changes: &mut ChangeSet, + db_backend: DbBackend, +) { + let entity_table_name = get_entity_table_name(entity); + + for column_def in entity.table().get_columns() { + if column_def.get_column_spec().unique { + let col_name = column_def.get_column_name(); + let col_exists = existing_table + .get_columns() + .iter() + .any(|c| c.get_column_name() == col_name); + if !col_exists { + continue; + } + let already_unique = existing_table.get_indexes().iter().any(|idx| { + if !idx.is_unique_key() { + return false; + } + let cols = idx.get_index_spec().get_column_names(); + cols.len() == 1 && cols[0] == col_name + }); + if !already_unique { + changes.record_constraint( + table_name_str.to_string(), + ConstraintChangeKind::AddUniqueConstraint { + column: col_name.to_string(), + stmt: db_backend.build( + Index::create() + .name(format!("idx-{table_name_str}-{col_name}")) + .table(entity_table_name.clone()) + .col(sea_query::Alias::new(col_name)) + .unique() + .if_not_exists(), + ), + }, + ); + } + } + } +} + +fn record_unique_constraint_drops( + entity: &EntitySchemaInfo, + existing_table: &sea_query::TableCreateStatement, + table_name_str: &str, + changes: &mut ChangeSet, + db_backend: DbBackend, +) { + let entity_table_name = get_entity_table_name(entity); + + for existing_index in existing_table.get_indexes() { + if !existing_index.is_unique_key() { + continue; + } + let mut has_index = entity.indexes().iter().any(|stmt| { + existing_index.get_index_spec().get_column_names() + == stmt.get_index_spec().get_column_names() + }); + if !has_index { + let index_cols = existing_index.get_index_spec().get_column_names(); + if index_cols.len() == 1 { + has_index = entity.table().get_columns().iter().any(|column_def| { + column_def.get_column_name() == index_cols[0] + && column_def.get_column_spec().unique + }); + } + } + if !has_index { + if let Some(name) = existing_index + .get_index_spec() + .get_name() + .map(|s| s.to_owned()) + { + let stmt = if db_backend == DbBackend::Postgres { + db_backend.build( + TableAlterStatement::new() + .table(entity_table_name.clone()) + .drop_constraint(name.clone()), + ) + } else { + db_backend.build(sea_query::Index::drop().name(name.clone())) + }; + + changes.record_constraint( + table_name_str.to_string(), + ConstraintChangeKind::DropUniqueConstraint { name, stmt }, + ); + } + } + } +} diff --git a/src/schema/discover/warning.rs b/src/schema/discover/warning.rs new file mode 100644 index 0000000000..c2ec22ce6e --- /dev/null +++ b/src/schema/discover/warning.rs @@ -0,0 +1,32 @@ +//! Warnings emitted during schema discovery. +//! +//! Warnings are always-on alerts about changes that cannot be handled automatically +//! and require manual intervention — typically data migration concerns. +//! Warnings can reference specific changes by their [`ChangeId`]. + +use super::changes::ChangeId; + +/// A warning emitted during schema discovery about a change requiring manual attention. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone)] +pub struct DiscoverWarning { + /// The category of warning. + pub kind: WarningKind, + /// Human-readable description of the concern. + pub message: String, + /// IDs of the changes this warning relates to, if any. + pub related_changes: Vec, +} + +/// Categories of schema discovery warnings. +/// +/// Warnings are always emitted regardless of the `assumptions` flag. +/// They represent situations that cannot be automatically resolved. +#[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WarningKind { + /// A CHECK constraint exists in entity definition but cannot be automatically diffed. + CheckConstraintDiff, + /// A column is being added with NOT NULL and no default — existing rows need data populated. + NotNullNoDefault, +} diff --git a/src/schema/entity.rs b/src/schema/entity.rs index 1273c8ff57..5b040dfd50 100644 --- a/src/schema/entity.rs +++ b/src/schema/entity.rs @@ -274,6 +274,280 @@ mod tests { use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*}; use pretty_assertions::assert_eq; + /// Postgres native enum (db_type = "Enum") — should produce CREATE TYPE on + /// Postgres, nothing on MySQL/SQLite. + #[test] + fn test_create_enum_native_postgres() { + let schema_pg = Schema::new(DbBackend::Postgres); + let enums = schema_pg.create_enum_from_entity(lunch_set::Entity); + assert_eq!( + enums.len(), + 1, + "Postgres should produce one CREATE TYPE for the Tea enum" + ); + let sql = DbBackend::Postgres.build(&enums[0]).to_string(); + assert!( + sql.contains("CREATE TYPE"), + "should be a CREATE TYPE statement: {sql}" + ); + assert!( + sql.contains("tea"), + "should reference the enum name 'tea': {sql}" + ); + + // MySQL/SQLite: no enum type statements + for backend in [DbBackend::MySql, DbBackend::Sqlite] { + let schema = Schema::new(backend); + let enums = schema.create_enum_from_entity(lunch_set::Entity); + assert!( + enums.is_empty(), + "{backend:?} should not produce enum type statements" + ); + } + } + + /// Postgres native enum column: Postgres references the custom type name, + /// MySQL uses inline ENUM('v1', 'v2'). + #[test] + fn test_native_enum_column_type_per_backend() { + let pg_sql = DbBackend::Postgres + .build(&Schema::new(DbBackend::Postgres).create_table_from_entity(lunch_set::Entity)) + .to_string(); + assert!( + pg_sql.contains("\"tea\""), + "Postgres table should reference custom type 'tea': {pg_sql}" + ); + + let mysql_sql = DbBackend::MySql + .build(&Schema::new(DbBackend::MySql).create_table_from_entity(lunch_set::Entity)) + .to_string(); + assert!( + mysql_sql.contains("ENUM("), + "MySQL table should use inline ENUM(...): {mysql_sql}" + ); + } + + /// String-based enum (db_type = "String(...)") must NOT produce any + /// CREATE TYPE statements — it's just a regular string column. + #[test] + fn test_create_enum_string_based_no_create_type() { + use crate as sea_orm; + use crate::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "String(StringLen::N(1))")] + pub enum Size { + #[sea_orm(string_value = "S")] + Small, + #[sea_orm(string_value = "L")] + Large, + } + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "shirt")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub size: Size, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + impl ActiveModelBehavior for ActiveModel {} + + for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] { + let schema = Schema::new(backend); + let enums = schema.create_enum_from_entity(Entity); + assert!( + enums.is_empty(), + "{backend:?}: String-based enum should not produce CREATE TYPE" + ); + + // Verify the column appears as a string type in the table DDL + let table_sql = backend + .build(&schema.create_table_from_entity(Entity)) + .to_string(); + assert!( + !table_sql.to_uppercase().contains("CREATE TYPE"), + "{backend:?}: table DDL should not contain CREATE TYPE: {table_sql}" + ); + } + } + + /// Integer-based enum (db_type = "Integer") must NOT produce any + /// CREATE TYPE statements — it's just a regular integer column. + #[test] + fn test_create_enum_integer_based_no_create_type() { + use crate as sea_orm; + use crate::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "i32", db_type = "Integer")] + pub enum Priority { + #[sea_orm(num_value = 0)] + Low, + #[sea_orm(num_value = 1)] + High, + } + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "task")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub priority: Priority, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + impl ActiveModelBehavior for ActiveModel {} + + for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] { + let schema = Schema::new(backend); + let enums = schema.create_enum_from_entity(Entity); + assert!( + enums.is_empty(), + "{backend:?}: Integer-based enum should not produce CREATE TYPE" + ); + } + } + + /// Entity with no enum columns at all — should produce nothing. + #[test] + fn test_create_enum_no_enum_columns() { + for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] { + let schema = Schema::new(backend); + let enums = schema.create_enum_from_entity(cake::Entity); + assert!( + enums.is_empty(), + "{backend:?}: entity without enum columns should produce no enum statements" + ); + } + } + + /// Entity with multiple Postgres enum columns produces one CREATE TYPE per enum. + #[test] + fn test_create_enum_multiple_enum_columns() { + use crate as sea_orm; + use crate::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "color")] + pub enum Color { + #[sea_orm(string_value = "red")] + Red, + #[sea_orm(string_value = "blue")] + Blue, + } + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "shape")] + pub enum Shape { + #[sea_orm(string_value = "circle")] + Circle, + #[sea_orm(string_value = "square")] + Square, + } + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "widget")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub color: Color, + pub shape: Shape, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + impl ActiveModelBehavior for ActiveModel {} + + let schema = Schema::new(DbBackend::Postgres); + let enums = schema.create_enum_from_entity(Entity); + assert_eq!( + enums.len(), + 2, + "should produce two CREATE TYPE statements for two enum columns" + ); + + let sqls: Vec = enums + .iter() + .map(|e| DbBackend::Postgres.build(e).to_string()) + .collect(); + assert!( + sqls.iter().any(|s| s.contains("color")), + "should have CREATE TYPE for 'color': {sqls:?}" + ); + assert!( + sqls.iter().any(|s| s.contains("shape")), + "should have CREATE TYPE for 'shape': {sqls:?}" + ); + + // MySQL: no CREATE TYPE + let mysql_enums = Schema::new(DbBackend::MySql).create_enum_from_entity(Entity); + assert!(mysql_enums.is_empty()); + } + + /// Mixed entity: one Postgres native enum, one string enum, one integer enum. + /// Only the native enum should produce CREATE TYPE on Postgres. + #[test] + fn test_create_enum_mixed_column_types() { + use crate as sea_orm; + use crate::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "mood")] + pub enum Mood { + #[sea_orm(string_value = "happy")] + Happy, + #[sea_orm(string_value = "sad")] + Sad, + } + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "String(StringLen::N(10))")] + pub enum Tag { + #[sea_orm(string_value = "work")] + Work, + #[sea_orm(string_value = "play")] + Play, + } + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "i32", db_type = "Integer")] + pub enum Level { + #[sea_orm(num_value = 1)] + One, + #[sea_orm(num_value = 2)] + Two, + } + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "entry")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub mood: Mood, + pub tag: Tag, + pub level: Level, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + impl ActiveModelBehavior for ActiveModel {} + + // Only the native Postgres enum should produce a CREATE TYPE + let schema = Schema::new(DbBackend::Postgres); + let enums = schema.create_enum_from_entity(Entity); + assert_eq!( + enums.len(), + 1, + "only the native Postgres enum (Mood) should produce CREATE TYPE" + ); + let sql = DbBackend::Postgres.build(&enums[0]).to_string(); + assert!(sql.contains("mood"), "should be the 'mood' enum: {sql}"); + } + #[test] fn test_create_table_from_entity_table_ref() { for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] { diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 8be74a0281..f881f0602a 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -1,11 +1,16 @@ use crate::DbBackend; mod builder; +#[cfg(feature = "schema-sync")] +pub(crate) mod discover; mod entity; #[cfg(feature = "serde_json")] mod json; mod topology; +#[cfg(feature = "schema-sync")] +pub use discover::resolver; + pub use builder::*; use topology::*; @@ -13,6 +18,7 @@ use topology::*; /// into different [`sea_query`](crate::sea_query) statements. #[derive(Debug)] pub struct Schema { + //TODO: this struct is a wast backend: DbBackend, } @@ -27,3 +33,39 @@ impl Schema { SchemaBuilder::new(self) } } + +// Sorts tables based on their foreign key dependencies +// pub(crate) fn sorted_tables(entities: &[builder::EntitySchemaInfo]) -> Vec { +// let mut sorter = TopologicalSort::::new(); + +// for entity in entities.iter() { +// let table_name = builder::get_table_name(entity.table().get_table_name()); +// sorter.insert(table_name); +// } +// for entity in entities.iter() { +// let self_table = builder::get_table_name(entity.table().get_table_name()); +// for fk in entity.table().get_foreign_key_create_stmts().iter() { +// let fk = fk.get_foreign_key(); +// let ref_table = builder::get_table_name(fk.get_ref_table()); +// if self_table != ref_table { +// // self cycle is okay +// sorter.add_dependency(ref_table, self_table.clone()); +// } +// } +// } +// let mut sorted = Vec::new(); +// while let Some(i) = sorter.pop() { +// sorted.push(i); +// } +// if sorted.len() != entities.len() { +// // push leftover tables +// for entity in entities.iter() { +// let table_name = builder::get_table_name(entity.table().get_table_name()); +// if !sorted.contains(&table_name) { +// sorted.push(table_name); +// } +// } +// } + +// sorted +// } diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 87c820bc56..3fdd6b2fae 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -635,7 +635,7 @@ pub async fn create_categories_table(db: &DbConn) -> Result { #[cfg(feature = "postgres-vector")] pub async fn create_embedding_table(db: &DbConn) -> Result { - db.execute(sea_orm::Statement::from_string( + db.execute_raw(sea_orm::Statement::from_string( db.get_database_backend(), "CREATE EXTENSION IF NOT EXISTS vector", )) diff --git a/tests/common/features/value_type.rs b/tests/common/features/value_type.rs index a963f2c522..a58dbc283b 100644 --- a/tests/common/features/value_type.rs +++ b/tests/common/features/value_type.rs @@ -69,15 +69,9 @@ where } } -// Automatically disable vec impl #[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] pub struct StringVec(pub Vec); -// Explicitly disable vec impl -#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] -#[sea_orm(no_vec_impl)] -pub struct StringVecNoImpl(pub Vec); - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] #[sea_orm(value_type = "String")] pub enum Tag1 { diff --git a/tests/common/fixtures.rs b/tests/common/fixtures.rs new file mode 100644 index 0000000000..b5189cb23b --- /dev/null +++ b/tests/common/fixtures.rs @@ -0,0 +1,404 @@ +// --------------------------------------------------------------------------- +// Enum fixtures +// --------------------------------------------------------------------------- + +pub mod enum_v1 { + use sea_orm::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "disc_status")] + pub enum Status { + #[sea_orm(string_value = "active")] + Active, + #[sea_orm(string_value = "inactive")] + Inactive, + } + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_enum_table")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub status: Status, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod enum_v2 { + use sea_orm::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "disc_status")] + pub enum Status { + #[sea_orm(string_value = "active")] + Active, + #[sea_orm(string_value = "inactive")] + Inactive, + #[sea_orm(string_value = "pending")] + Pending, + } + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_enum_table")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub status: Status, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod enum_renamed { + use sea_orm::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "disc_state")] + pub enum State { + #[sea_orm(string_value = "active")] + Active, + #[sea_orm(string_value = "inactive")] + Inactive, + } + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_enum_table")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub status: State, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Widget fixtures (column drop / no-drop) +// --------------------------------------------------------------------------- + +pub mod widget_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_widget")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub label: String, + pub weight: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod widget_v2 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_widget")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub label: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Combo fixtures (rename detection) +// --------------------------------------------------------------------------- + +pub mod combo_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_combo")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub old_field: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod combo_v2 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_combo")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub new_field: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Category / Article fixtures (FK drop) +// --------------------------------------------------------------------------- + +pub mod category_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_category")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod article_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_article")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub title: String, + #[sea_orm(belongs_to, from = "category_id", to = "id")] + pub category: HasOne, + pub category_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod article_v2 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_article")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub title: String, + pub category_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Tag fixture (orphan table) +// --------------------------------------------------------------------------- + +pub mod tag_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_tag")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub slug: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Column addition fixtures +// --------------------------------------------------------------------------- + +pub mod coltest_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_coltest")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod coltest_v2_nullable { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_coltest")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub bio: Option, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod coltest_v2_notnull { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_coltest")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub age: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod coltest_v2_notnull_default { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_coltest")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + #[sea_orm(default_value = 0)] + pub score: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod coltest_v2_multi { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "disc_coltest")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + pub bio: Option, + pub age: i32, + #[sea_orm(default_value = 0)] + pub score: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +// --------------------------------------------------------------------------- +// Parent / Child fixtures (FK ordering in drops) +// --------------------------------------------------------------------------- + +/// Simple parent table with no FK dependencies. +pub mod fk_parent_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_fk_parent")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// Grandparent table — the root of a three-level FK chain used in complex drop-sequence tests. +pub mod fk_grandparent_v1 { + use sea_orm::entity::prelude::*; + + #[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum)] + #[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "drop_seq_status")] + pub enum Status { + #[sea_orm(string_value = "active")] + Active, + #[sea_orm(string_value = "inactive")] + Inactive, + } + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "drop_seq_gp")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub status: Status, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// Middle table — has an FK to `drop_seq_gp` and is itself referenced by `drop_seq_child`. +pub mod fk_mid_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "drop_seq_mid")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(belongs_to, from = "gp_id", to = "id")] + pub grandparent: HasOne, + pub gp_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// Child table with an FK pointing to `sync_fk_parent`. +pub mod fk_child_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "sync_fk_child")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(belongs_to, from = "parent_id", to = "id")] + pub parent: HasOne, + pub parent_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} + +/// Leaf table — deepest in the three-level chain, references `drop_seq_mid`. +pub mod fk_leaf_v1 { + use sea_orm::entity::prelude::*; + + #[sea_orm::model] + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "drop_seq_leaf")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(belongs_to, from = "mid_id", to = "id")] + pub mid: HasOne, + pub mid_id: i32, + } + + impl ActiveModelBehavior for ActiveModel {} +} diff --git a/tests/common/helpers.rs b/tests/common/helpers.rs new file mode 100644 index 0000000000..fd6b33db39 --- /dev/null +++ b/tests/common/helpers.rs @@ -0,0 +1,149 @@ +use sea_orm::{DatabaseConnection, DbBackend, DbErr, query::*}; +use sea_orm::sea_query::{Alias, Condition, Expr, Query}; + +#[cfg(feature = "schema-sync")] +use sea_orm::{InterpretConfig, InterpretResult, schema::SchemaBuilder}; + +/// Runs `discover` + `interpret_changes` + executes every emitted statement, then +/// returns the [`InterpretResult`] so callers can make additional assertions on +/// warnings / suggestions. +/// +/// This is the "ground-truth" form of a sync round-trip: unlike tests that only +/// inspect the generated SQL, this helper actually applies the statements to the +/// live database so follow-up `column_exists`/`table_exists` checks are meaningful. +#[cfg(feature = "schema-sync")] +pub async fn discover_interpret_and_apply( + db: &DatabaseConnection, + builder: SchemaBuilder, + config: InterpretConfig, +) -> Result { + let dangerous = config.allow_dangerous; + let change_set = builder.discover(db, dangerous).await?; + let result = sea_orm::interpret_changes(change_set, &config); + for (_, stmt) in &result.statements { + db.execute_raw(stmt.clone()).await?; + } + Ok(result) +} + +pub async fn table_exists(db: &DatabaseConnection, table: &str) -> Result { + match db.get_database_backend() { + #[cfg(feature = "sqlx-postgres")] + DbBackend::Postgres => { + let row = db + .query_one( + Query::select() + .expr(Expr::cust("COUNT(*) > 0")) + .from((Alias::new("information_schema"), Alias::new("tables"))) + .cond_where( + Condition::all() + .add(Expr::cust("table_schema = CURRENT_SCHEMA()")) + .add(Expr::col("table_name").eq(table)), + ), + ) + .await?; + Ok(row + .map(|r| r.try_get_by_index::(0).unwrap_or(false)) + .unwrap_or(false)) + } + #[cfg(feature = "sqlx-mysql")] + DbBackend::MySql => { + let row = db + .query_one( + Query::select() + .expr(Expr::cust("COUNT(*) > 0")) + .from((Alias::new("information_schema"), Alias::new("tables"))) + .cond_where( + Condition::all() + .add(Expr::cust("table_schema = DATABASE()")) + .add(Expr::col("table_name").eq(table)), + ), + ) + .await?; + Ok(row + .map(|r| r.try_get_by_index::(0).unwrap_or(false)) + .unwrap_or(false)) + } + #[cfg(any(feature = "sqlx-sqlite", feature = "rusqlite"))] + DbBackend::Sqlite => { + let row = db + .query_one( + Query::select() + .expr(Expr::cust("COUNT(*) > 0")) + .from(Alias::new("sqlite_master")) + .cond_where( + Condition::all() + .add(Expr::col("type").eq("table")) + .add(Expr::col("name").eq(table)), + ), + ) + .await?; + Ok(row + .map(|r| r.try_get_by_index::(0).unwrap_or(false)) + .unwrap_or(false)) + } + _ => Ok(false), + } +} + +pub async fn column_exists( + db: &DatabaseConnection, + table: &str, + column: &str, +) -> Result { + match db.get_database_backend() { + #[cfg(feature = "sqlx-postgres")] + DbBackend::Postgres => { + let row = db + .query_one( + Query::select() + .expr(Expr::cust("COUNT(*) > 0")) + .from((Alias::new("information_schema"), Alias::new("columns"))) + .cond_where( + Condition::all() + .add(Expr::cust("table_schema = CURRENT_SCHEMA()")) + .add(Expr::col("table_name").eq(table)) + .add(Expr::col("column_name").eq(column)), + ), + ) + .await?; + Ok(row + .map(|r| r.try_get_by_index::(0).unwrap_or(false)) + .unwrap_or(false)) + } + #[cfg(feature = "sqlx-mysql")] + DbBackend::MySql => { + let row = db + .query_one( + Query::select() + .expr(Expr::cust("COUNT(*) > 0")) + .from((Alias::new("information_schema"), Alias::new("columns"))) + .cond_where( + Condition::all() + .add(Expr::cust("table_schema = DATABASE()")) + .add(Expr::col("table_name").eq(table)) + .add(Expr::col("column_name").eq(column)), + ), + ) + .await?; + Ok(row + .map(|r| r.try_get_by_index::(0).unwrap_or(false)) + .unwrap_or(false)) + } + #[cfg(any(feature = "sqlx-sqlite", feature = "rusqlite"))] + DbBackend::Sqlite => { + let rows = db + .query_all_raw(sea_orm::Statement::from_string( + DbBackend::Sqlite, + format!("PRAGMA table_info(\"{table}\")"), + )) + .await?; + Ok(rows.iter().any(|r| { + r.try_get_by_index::(1) + .map(|n| n == column) + .unwrap_or(false) + })) + } + _ => Ok(false), + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 756432ceed..72354e3f4a 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -3,6 +3,8 @@ pub mod bakery_dense; pub mod blogger; pub mod features; pub mod film_store; +pub mod fixtures; +pub mod helpers; #[cfg(not(feature = "sync"))] pub mod runtime; pub mod setup; diff --git a/tests/derive_tests.rs b/tests/derive_tests.rs index e33cae65c7..e482280052 100644 --- a/tests/derive_tests.rs +++ b/tests/derive_tests.rs @@ -69,69 +69,3 @@ struct FromQueryResultNested { #[sea_orm(nested)] _test: SimpleTest, } - -#[cfg(feature = "postgres-array")] -mod postgres_array { - use crate::FromQueryResult; - use sea_orm::DeriveValueType; - - #[derive(DeriveValueType)] - pub struct IngredientId(i32); - - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] - #[sea_orm(value_type = "String")] - pub struct NumericLabel { - pub value: i64, - } - - impl std::fmt::Display for NumericLabel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.value) - } - } - - impl std::str::FromStr for NumericLabel { - type Err = std::num::ParseIntError; - fn from_str(s: &str) -> Result { - Ok(Self { value: s.parse()? }) - } - } - - #[derive(Copy, Clone, Debug, PartialEq, Eq, DeriveValueType)] - #[sea_orm(value_type = "String")] - pub enum TextureKind { - Hard, - Soft, - } - - impl std::fmt::Display for TextureKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Hard => "hard", - Self::Soft => "soft", - } - ) - } - } - - impl std::str::FromStr for TextureKind { - type Err = sea_query::ValueTypeErr; - fn from_str(s: &str) -> Result { - Ok(match s { - "hard" => Self::Hard, - "soft" => Self::Soft, - _ => return Err(sea_query::ValueTypeErr), - }) - } - } - - #[derive(FromQueryResult)] - pub struct IngredientPathRow { - pub ingredient_path: Vec, - pub numeric_label_path: Vec, - pub texture_path: Vec, - } -} diff --git a/tests/schema_discover_tests.rs b/tests/schema_discover_tests.rs new file mode 100644 index 0000000000..7a0adee689 --- /dev/null +++ b/tests/schema_discover_tests.rs @@ -0,0 +1,1201 @@ +//! Tests for the schema discovery module (`src/schema/discover/`). +//! +//! These tests verify that `discover()` produces the correct statements and +//! warnings for enum, table, column, index, and foreign key changes. +#![allow(unused_imports, dead_code)] +pub mod common; + +use crate::common::TestContext; +use crate::common::fixtures::*; +use crate::common::helpers::{column_exists, table_exists}; +#[cfg(feature = "schema-sync")] +use crate::common::helpers::discover_interpret_and_apply; +use sea_orm::{DatabaseConnection, DbErr, entity::*, query::*}; + +// --------------------------------------------------------------------------- +// Enum discovery tests (Postgres-only) +// --------------------------------------------------------------------------- + +/// discover() on a brand-new entity must include a CREATE TYPE statement +/// for the Postgres enum before the CREATE TABLE. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_creates_enum_type() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_creates_enum_type").await; + let db = &ctx.db; + + let builder = db.get_schema_builder().register(enum_v1::Entity); + let change_set = builder.discover(db, false).await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + let has_create_type = result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("CREATE TYPE")); + assert!( + has_create_type, + "discover() must include CREATE TYPE for Postgres enum; got: {:?}", + result.statements + ); + + let has_create_table = result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("CREATE TABLE")); + assert!( + has_create_table, + "discover() must include CREATE TABLE; got: {:?}", + result.statements + ); + + let type_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("CREATE TYPE")) + .unwrap(); + let table_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("CREATE TABLE")) + .unwrap(); + assert!( + type_pos < table_pos, + "CREATE TYPE must precede CREATE TABLE; type at {type_pos}, table at {table_pos}" + ); + + ctx.delete().await; + Ok(()) +} + +/// After an entity is synced, discover() must detect the existing enum and +/// NOT produce a duplicate CREATE TYPE statement. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_skips_existing_enum() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_skips_existing_enum").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(enum_v1::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result + .statements + .iter() + .all(|(_, s)| !s.sql.to_uppercase().contains("CREATE TYPE")), + "discover() must NOT re-create an existing enum type; got: {:?}", + result.statements + ); + assert!( + result.statements.is_empty(), + "discover() should produce no changes when schema matches; got: {:?}", + result.statements + ); + + ctx.delete().await; + Ok(()) +} + +/// When an enum's variants change, dangerous discover must emit an EnumVariantChange warning. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_enum_variant_change_warning() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_enum_variant_change_warning").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(enum_v2::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }, + ); + + assert!( + result + .suggestions + .iter() + .any(|s| s.kind == sea_orm::schema::SuggestionKind::EnumVariantChange), + "expected EnumVariantChange suggestion when enum gains a variant; got: {:?}", + result.suggestions + ); + assert!( + result + .statements + .iter() + .all(|(_, s)| !s.sql.to_uppercase().contains("CREATE TYPE")), + "changed enum should produce a warning, not a CREATE TYPE; got: {:?}", + result.statements + ); + + ctx.delete().await; + Ok(()) +} + +/// When an enum type name changes (same variants), dangerous discover must emit a warning. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_enum_rename_warning() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_enum_rename_warning").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(enum_renamed::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }, + ); + + assert!( + result + .suggestions + .iter() + .any(|s| s.kind == sea_orm::schema::SuggestionKind::EnumRename), + "expected EnumRename suggestion when enum type is renamed; got: {:?}", + result.suggestions + ); + + + ctx.delete().await; + Ok(()) +} + +/// Safe discover must NOT produce enum warnings even when variants changed. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_safe_no_enum_warnings() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_safe_no_enum_warnings").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(enum_v2::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result + .suggestions + .iter() + .all(|s| s.kind != sea_orm::schema::SuggestionKind::EnumVariantChange), + "safe discover should not suggest enum changes; got: {:?}", + result.suggestions + ); + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Column drop / safe tests +// --------------------------------------------------------------------------- + +/// When `allow_dangerous = true`, discover() must include a DROP COLUMN for removed columns. +/// Changes must NOT be applied until the caller executes them. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_drop_column() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_drop_column").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(widget_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(widget_v2::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ); + + assert!( + result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("DROP COLUMN")), + "discover(dangerous=true) must include DROP COLUMN for `weight`; got: {:?}", + result.statements + ); + assert!( + column_exists(db, "sync_widget", "weight").await?, + "discover() must not apply changes; `weight` column should still exist" + ); + + ctx.delete().await; + Ok(()) +} + +/// When `allow_dangerous = false`, discover() must NEVER produce any DROP statements. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +async fn test_discover_safe_no_drops() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_safe_no_drops").await; + let db = &ctx.db; + + #[cfg(feature = "schema-sync")] + { + db.get_schema_builder() + .register(widget_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(widget_v2::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result + .statements + .iter() + .all(|(_, s)| !s.sql.to_uppercase().contains("DROP")), + "discover(dangerous=false) must not include any DROP statements; got: {:?}", + result.statements + ); + } + + ctx.delete().await; + Ok(()) +} + +/// Applying dangerous changes actually drops the column. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_sync_dangerous_drops_column() -> Result<(), DbErr> { + let ctx = TestContext::new("test_sync_dangerous_drops_column").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(widget_v1::Entity) + .sync(db) + .await?; + assert!(column_exists(db, "sync_widget", "weight").await?); + + discover_interpret_and_apply( + db, + db.get_schema_builder().register(widget_v2::Entity), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + assert!(!column_exists(db, "sync_widget", "weight").await?); + assert!(column_exists(db, "sync_widget", "label").await?); + + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Table drop tests +// --------------------------------------------------------------------------- + +/// discover(dangerous=true) must include DROP TABLE for orphaned tables. +#[sea_orm_macros::test] +#[cfg(feature = "schema-sync")] +async fn test_discover_drop_table() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_drop_table").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(tag_v1::Entity) + .register(widget_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(widget_v1::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ); + + assert!( + result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("DROP TABLE")), + "discover(dangerous=true) must include DROP TABLE for `sync_tag`; got: {:?}", + result.statements + ); + assert!( + table_exists(db, "sync_tag").await?, + "discover() must not apply changes; `sync_tag` should still exist" + ); + + + ctx.delete().await; + Ok(()) +} + +/// Applying dangerous changes actually drops the orphaned table. +#[sea_orm_macros::test] +#[cfg(feature = "schema-sync")] +async fn test_sync_dangerous_drops_table() -> Result<(), DbErr> { + let ctx = TestContext::new("test_sync_dangerous_drops_table").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(tag_v1::Entity) + .register(widget_v1::Entity) + .sync(db) + .await?; + assert!(table_exists(db, "sync_tag").await?); + + discover_interpret_and_apply( + db, + db.get_schema_builder().register(widget_v1::Entity), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + assert!(!table_exists(db, "sync_tag").await?); + assert!(table_exists(db, "sync_widget").await?); + + + ctx.delete().await; + Ok(()) +} + +/// When both a parent and child table are orphaned, the child must appear before +/// the parent in the DROP TABLE statements (to avoid FK constraint violations). +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_sync_dangerous_drops_orphan_table_fk_order() -> Result<(), DbErr> { + let ctx = TestContext::new("test_sync_dangerous_drops_orphan_table_fk_order").await; + let db = &ctx.db; + + + // Create both tables; fk_child has an FK to fk_parent. + db.get_schema_builder() + .register(fk_parent_v1::Entity) + .register(fk_child_v1::Entity) + .sync(db) + .await?; + + assert!(table_exists(db, "sync_fk_parent").await?); + assert!(table_exists(db, "sync_fk_child").await?); + + // Discover with no registered entities → both tables are orphans; apply in one shot. + let result = discover_interpret_and_apply( + db, + db.get_schema_builder(), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + let child_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("SYNC_FK_CHILD")) + .expect("DROP TABLE sync_fk_child must be in statements"); + let parent_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("SYNC_FK_PARENT")) + .expect("DROP TABLE sync_fk_parent must be in statements"); + + assert!( + child_pos < parent_pos, + "child table must be dropped before parent to avoid FK violation; \ + child at {child_pos}, parent at {parent_pos}" + ); + + assert!(!table_exists(db, "sync_fk_child").await?); + assert!(!table_exists(db, "sync_fk_parent").await?); + + + ctx.delete().await; + Ok(()) +} + + +/// discover(dangerous=true) must include DROP FOREIGN KEY / CONSTRAINT when a FK is removed. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_drop_foreign_key() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_drop_foreign_key").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(category_v1::Entity) + .register(article_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(category_v1::Entity) + .register(article_v2::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ); + + let has_drop_fk = result.statements.iter().any(|(_, s)| { + let sql = s.sql.to_uppercase(); + sql.contains("DROP FOREIGN KEY") || sql.contains("DROP CONSTRAINT") + }); + assert!( + has_drop_fk, + "discover(dangerous=true) must include DROP FOREIGN KEY / CONSTRAINT; got: {:?}", + result.statements + ); + + + ctx.delete().await; + Ok(()) +} + +/// Applying dangerous changes actually removes the FK. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_sync_dangerous_drops_foreign_key() -> Result<(), DbErr> { + let ctx = TestContext::new("test_sync_dangerous_drops_foreign_key").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(category_v1::Entity) + .register(article_v1::Entity) + .sync(db) + .await?; + + discover_interpret_and_apply( + db, + db.get_schema_builder() + .register(category_v1::Entity) + .register(article_v2::Entity), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + assert!(table_exists(db, "sync_article").await?); + + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// No-change test +// --------------------------------------------------------------------------- + +/// When the schema already matches, discover() must return an empty change set. +#[sea_orm_macros::test] +#[cfg(feature = "schema-sync")] +async fn test_discover_no_changes_when_synced() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_no_changes_when_synced").await; + let db = &ctx.db; + + + db.get_schema_builder() + .register(widget_v1::Entity) + .sync(db) + .await?; + + for dangerous in [false, true] { + let change_set = db + .get_schema_builder() + .register(widget_v1::Entity) + .discover(db, dangerous) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: dangerous, + }, + ); + assert!( + result.statements.is_empty(), + "discover(dangerous={dangerous}) must return no changes when schema is up-to-date; got: {:?}", + result.statements + ); + } + + + ctx.delete().await; + Ok(()) +} + + +/// Dangerous sync with assumptions=true auto-assumes obvious rename and generates RENAME COLUMN. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_sync_dangerous_add_and_drop_column() -> Result<(), DbErr> { + let ctx = TestContext::new("test_sync_dangerous_add_and_drop_column").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(combo_v1::Entity) + .sync(db) + .await?; + + assert!(column_exists(db, "sync_combo", "old_field").await?); + assert!(!column_exists(db, "sync_combo", "new_field").await?); + + let change_set = db + .get_schema_builder() + .register(combo_v2::Entity) + .discover(db, true) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: true, + allow_dangerous: true, + }, + ); + + assert!( + result + .suggestions + .iter() + .any(|s| s.kind == sea_orm::schema::SuggestionKind::PossibleRename), + "expected PossibleRename suggestion; got: {:?}", + result.suggestions + ); + assert!( + result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("RENAME COLUMN")), + "auto-assumed rename should produce RENAME COLUMN; got: {:?}", + result.statements + ); + assert!( + result + .statements + .iter() + .all(|(_, s)| !s.sql.to_uppercase().contains("ADD COLUMN") + && !s.sql.to_uppercase().contains("DROP COLUMN")), + "rename-detected pair should not produce ADD/DROP; got: {:?}", + result.statements + ); + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Column addition tests +// --------------------------------------------------------------------------- + +/// Adding a nullable column should produce an ADD COLUMN and no warnings. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_add_nullable_column() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_add_nullable_column").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(coltest_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(coltest_v2_nullable::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result.statements.iter().any(|(_, s)| { + let sql = s.sql.to_uppercase(); + sql.contains("ADD COLUMN") && sql.contains("BIO") + }), + "discover() must include ADD COLUMN for `bio`; got: {:?}", + result.statements + ); + assert!( + result + .warnings + .iter() + .all(|w| w.kind != sea_orm::schema::WarningKind::NotNullNoDefault), + "nullable column must not produce NotNullNoDefault warning; got: {:?}", + result.warnings + ); + + ctx.delete().await; + Ok(()) +} + +/// Adding a NOT NULL column without a default should produce a NotNullNoDefault warning. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_add_notnull_column_warns() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_add_notnull_column_warns").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(coltest_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(coltest_v2_notnull::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result.statements.iter().any(|(_, s)| { + let sql = s.sql.to_uppercase(); + sql.contains("ADD COLUMN") && sql.contains("AGE") + }), + "discover() must include ADD COLUMN for `age`; got: {:?}", + result.statements + ); + assert!( + result.warnings.iter().any(|w| { + w.kind == sea_orm::schema::WarningKind::NotNullNoDefault + && w.message.contains("age") + }), + "NOT NULL column without default must produce NotNullNoDefault warning; got: {:?}", + result.warnings + ); + + ctx.delete().await; + Ok(()) +} + +/// Adding a NOT NULL column WITH a default should NOT produce a NotNullNoDefault warning. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_add_notnull_with_default_no_warn() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_add_notnull_with_default_no_warn").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(coltest_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(coltest_v2_notnull_default::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result.statements.iter().any(|(_, s)| { + let sql = s.sql.to_uppercase(); + sql.contains("ADD COLUMN") && sql.contains("SCORE") + }), + "discover() must include ADD COLUMN for `score`; got: {:?}", + result.statements + ); + assert!( + result + .warnings + .iter() + .all(|w| w.kind != sea_orm::schema::WarningKind::NotNullNoDefault), + "NOT NULL column with default should not warn; got: {:?}", + result.warnings + ); + + ctx.delete().await; + Ok(()) +} + +/// Adding multiple columns produces ADD COLUMN for each and warns only for NOT NULL without default. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_add_multiple_columns() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_add_multiple_columns").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(coltest_v1::Entity) + .sync(db) + .await?; + + let change_set = db + .get_schema_builder() + .register(coltest_v2_multi::Entity) + .discover(db, false) + .await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + for col in ["BIO", "AGE", "SCORE"] { + assert!( + result.statements.iter().any(|(_, s)| { + let sql = s.sql.to_uppercase(); + sql.contains("ADD COLUMN") && sql.contains(col) + }), + "discover() must include ADD COLUMN for `{col}`; got: {:?}", + result.statements + ); + } + + let not_null_warnings: Vec<_> = result + .warnings + .iter() + .filter(|w| w.kind == sea_orm::schema::WarningKind::NotNullNoDefault) + .collect(); + assert_eq!( + not_null_warnings.len(), + 1, + "expected exactly one NotNullNoDefault warning (for `age`); got: {not_null_warnings:?}" + ); + assert!( + not_null_warnings[0].message.contains("age"), + "the warning should reference `age`; got: {:?}", + not_null_warnings[0].message + ); + + ctx.delete().await; + Ok(()) +} + +/// After discovering an ADD COLUMN, applying the statements creates the column. +#[sea_orm_macros::test] +#[cfg(not(any(feature = "sqlx-sqlite", feature = "rusqlite")))] +#[cfg(feature = "schema-sync")] +async fn test_discover_add_column_applies() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_add_column_applies").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(coltest_v1::Entity) + .sync(db) + .await?; + + assert!(!column_exists(db, "disc_coltest", "bio").await?); + + discover_interpret_and_apply( + db, + db.get_schema_builder().register(coltest_v2_nullable::Entity), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ) + .await?; + + assert!(column_exists(db, "disc_coltest", "bio").await?); + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Enum drop tests (Postgres-only) +// --------------------------------------------------------------------------- + +/// When an enum type exists in the DB but has no matching entity, discover(dangerous=true) +/// must include a DROP TYPE statement. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_drops_orphan_enum_type() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_drops_orphan_enum_type").await; + let db = &ctx.db; + + // Sync the enum entity to create the type in the DB. + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + // Discover with NO entities registered — the enum is now orphaned. + let change_set = db.get_schema_builder().discover(db, true).await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ); + + let has_drop_type = result + .statements + .iter() + .any(|(_, s)| s.sql.to_uppercase().contains("DROP TYPE")); + assert!( + has_drop_type, + "discover(dangerous=true) must include DROP TYPE for orphaned enum; got: {:?}", + result.statements + ); + + ctx.delete().await; + Ok(()) +} + +/// DROP TABLE must appear before DROP TYPE in the statement list (so the table +/// referencing the enum is gone before the type is dropped). +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_drop_table_before_enum_type() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_drop_table_before_enum_type").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let result = discover_interpret_and_apply( + db, + db.get_schema_builder(), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + let table_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("DROP TABLE")) + .expect("DROP TABLE must be present"); + let type_pos = result + .statements + .iter() + .position(|(_, s)| s.sql.to_uppercase().contains("DROP TYPE")) + .expect("DROP TYPE must be present"); + + assert!( + table_pos < type_pos, + "DROP TABLE must precede DROP TYPE; table at {table_pos}, type at {type_pos}" + ); + + assert!(!table_exists(db, "disc_enum_table").await?); + + ctx.delete().await; + Ok(()) +} + +/// Safe discover (allow_dangerous=false) must NOT produce DROP TYPE even for orphaned enums. +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_discover_safe_no_drop_enum_type() -> Result<(), DbErr> { + let ctx = TestContext::new("test_discover_safe_no_drop_enum_type").await; + let db = &ctx.db; + + db.get_schema_builder() + .register(enum_v1::Entity) + .sync(db) + .await?; + + let change_set = db.get_schema_builder().discover(db, false).await?; + let result = sea_orm::interpret_changes( + change_set, + &sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: false, + }, + ); + + assert!( + result + .statements + .iter() + .all(|(_, s)| !s.sql.to_uppercase().contains("DROP TYPE")), + "safe discover must not produce DROP TYPE; got: {:?}", + result.statements + ); + + ctx.delete().await; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Complex drop-sequence test +// --------------------------------------------------------------------------- + +/// Full drop-sequence correctness: three-level FK chain (grandparent → mid → leaf) +/// where the grandparent has an enum column, all orphaned at once. +/// +/// Verifies the complete required ordering in a single `discover + execute` pass: +/// 1. DROP CONSTRAINT (FK drops) before DROP TABLE for the same table +/// 2. Leaf before mid before grandparent (child-first table drops) +/// 3. DROP TYPE after all DROP TABLE statements +/// 4. The statements actually execute without FK/type-dependency errors +#[sea_orm_macros::test] +#[cfg(feature = "sqlx-postgres")] +#[cfg(feature = "schema-sync")] +async fn test_complex_drop_sequence() -> Result<(), DbErr> { + let ctx = TestContext::new("test_complex_drop_sequence").await; + let db = &ctx.db; + + // Build the three-level chain: grandparent (has enum) → mid → leaf. + db.get_schema_builder() + .register(fk_grandparent_v1::Entity) + .register(fk_mid_v1::Entity) + .register(fk_leaf_v1::Entity) + .sync(db) + .await?; + + assert!(table_exists(db, "drop_seq_gp").await?); + assert!(table_exists(db, "drop_seq_mid").await?); + assert!(table_exists(db, "drop_seq_leaf").await?); + + // Orphan all three by registering nothing; apply the statements in one shot. + let result = discover_interpret_and_apply( + db, + db.get_schema_builder(), + sea_orm::InterpretConfig { + db_backend: db.get_database_backend(), + assumptions: false, + allow_dangerous: true, + }, + ) + .await?; + + let stmts: Vec<_> = result + .statements + .iter() + .map(|(_, s)| s.sql.to_uppercase()) + .collect(); + + // ── 1. All expected statement kinds are present ────────────────────── + assert!( + stmts.iter().any(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_LEAF")), + "DROP TABLE drop_seq_leaf must be present; got:\n{stmts:#?}" + ); + assert!( + stmts.iter().any(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_MID")), + "DROP TABLE drop_seq_mid must be present; got:\n{stmts:#?}" + ); + assert!( + stmts.iter().any(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_GP")), + "DROP TABLE drop_seq_gp must be present; got:\n{stmts:#?}" + ); + assert!( + stmts.iter().any(|s| s.contains("DROP TYPE")), + "DROP TYPE must be present for the orphaned enum; got:\n{stmts:#?}" + ); + + // ── 2. Child-first table-drop order across all three levels ────────── + let leaf_pos = stmts + .iter() + .position(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_LEAF")) + .unwrap(); + let mid_pos = stmts + .iter() + .position(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_MID")) + .unwrap(); + let gp_pos = stmts + .iter() + .position(|s| s.contains("DROP TABLE") && s.contains("DROP_SEQ_GP")) + .unwrap(); + assert!( + leaf_pos < mid_pos, + "leaf must be dropped before mid; leaf at {leaf_pos}, mid at {mid_pos}" + ); + assert!( + mid_pos < gp_pos, + "mid must be dropped before grandparent; mid at {mid_pos}, gp at {gp_pos}" + ); + + // ── 3. Any FK DROP CONSTRAINT comes before its own table's DROP TABLE ─ + for (fk_table, fk_table_upper) in [ + ("drop_seq_mid", "DROP_SEQ_MID"), + ("drop_seq_leaf", "DROP_SEQ_LEAF"), + ] { + let drop_table_pos = stmts + .iter() + .position(|s| s.contains("DROP TABLE") && s.contains(fk_table_upper)) + .unwrap(); + if let Some(drop_constraint_pos) = stmts + .iter() + .position(|s| s.contains("DROP CONSTRAINT") && s.contains(fk_table_upper)) + { + assert!( + drop_constraint_pos < drop_table_pos, + "DROP CONSTRAINT on {fk_table} must precede DROP TABLE {fk_table}; \ + constraint at {drop_constraint_pos}, table at {drop_table_pos}" + ); + } + } + + // ── 4. DROP TYPE is after all DROP TABLE statements ─────────────────── + let last_drop_table_pos = stmts + .iter() + .rposition(|s| s.contains("DROP TABLE")) + .unwrap(); + let drop_type_pos = stmts + .iter() + .position(|s| s.contains("DROP TYPE")) + .unwrap(); + assert!( + drop_type_pos > last_drop_table_pos, + "DROP TYPE must come after all DROP TABLE statements; \ + last DROP TABLE at {last_drop_table_pos}, DROP TYPE at {drop_type_pos}" + ); + + // ── 5. Verify tables are gone (statements were applied by discover_interpret_and_apply) ── + assert!(!table_exists(db, "drop_seq_leaf").await?); + assert!(!table_exists(db, "drop_seq_mid").await?); + assert!(!table_exists(db, "drop_seq_gp").await?); + + ctx.delete().await; + Ok(()) +} diff --git a/tests/schema_sync_tests.rs b/tests/schema_sync_tests.rs index 5c0281fd9f..81f8dd7f62 100644 --- a/tests/schema_sync_tests.rs +++ b/tests/schema_sync_tests.rs @@ -1,5 +1,4 @@ #![allow(unused_imports, dead_code)] - pub mod common; use crate::common::TestContext;