diff --git a/docs/guide/modeling/model.md b/docs/guide/modeling/model.md index f73f0b52e..668982acc 100644 --- a/docs/guide/modeling/model.md +++ b/docs/guide/modeling/model.md @@ -84,6 +84,7 @@ properties: {} | `columns` | Yes | List of columns to expose (see [Column Fields](#column-fields)) | | `primary_key` | No | Column name that uniquely identifies a row; required for relationships | | `cached` | No | Whether query results for this model should be cached; `false` by default | +| `dialect` | No | SQL dialect of the model's `ref_sql` (e.g. `bigquery`, `postgres`). Overrides the project-level `data_source` for this model. Requires `schema_version: 3`. See [Dialect Override](./wren_project.md#dialect-override). | | `properties` | No | Arbitrary key-value metadata (description, tags, etc.) | ## Data Source: Two Ways to Point at Data diff --git a/docs/guide/modeling/view.md b/docs/guide/modeling/view.md index 9b156d8dc..780eaed18 100644 --- a/docs/guide/modeling/view.md +++ b/docs/guide/modeling/view.md @@ -53,6 +53,7 @@ statement: > |-------|----------|-------------| | `name` | Yes | Unique identifier used in SQL queries | | `statement` | Yes | A complete SQL SELECT statement; may reference other models or views | +| `dialect` | No | SQL dialect of the view's `statement` (e.g. `bigquery`, `postgres`). Currently metadata only — the engine always parses view statements with its generic SQL parser. Requires `schema_version: 3`. See [Dialect Override](./wren_project.md#dialect-override). | | `properties` | No | Arbitrary key-value metadata (use `properties.description` for a human-readable description) | ## Model vs View diff --git a/docs/guide/modeling/wren_project.md b/docs/guide/modeling/wren_project.md index c96fc17c0..7da8aa56d 100644 --- a/docs/guide/modeling/wren_project.md +++ b/docs/guide/modeling/wren_project.md @@ -85,7 +85,7 @@ default_project: ~/projects/sales ### `wren_project.yml` ```yaml -schema_version: 2 +schema_version: 3 name: my_project version: "1.0" catalog: wren @@ -95,7 +95,7 @@ data_source: postgres | Field | Description | |-------|-------------| -| `schema_version` | Directory layout version. `2` = folder-per-entity (current). Owned by the CLI — do not bump manually. | +| `schema_version` | Directory layout version. `2` = folder-per-entity, `3` = adds `dialect` field support (current). Owned by the CLI — do not bump manually. | | `name` | Project name | | `version` | User's own project version (free-form, no effect on parsing) | | `catalog` | **Wren Engine namespace** — NOT your database catalog. Identifies this MDL project within the engine. Default: `wren`. | @@ -146,6 +146,19 @@ cached: false properties: {} ``` +**`dialect`** — optional field declaring which SQL dialect the model's `ref_sql` is written in. When omitted, the project-level `data_source` is used. This lets a single project contain models whose SQL targets different databases: + +```yaml +name: revenue +ref_sql: "SELECT * FROM `project.dataset.table`" +dialect: bigquery +columns: + - name: amount + type: DECIMAL +``` + +Requires `schema_version: 3`. See [Dialect Override](#dialect-override) for details. + **ref_sql** — defines the model via a SQL query. SQL can be inline in `metadata.yml` or in a separate `ref_sql.sql` file (the `.sql` file takes precedence if both exist): ```yaml @@ -186,6 +199,16 @@ properties: description: "Top customers by lifetime value" ``` +Like models, views support an optional **`dialect`** field (requires `schema_version: 3`): + +```yaml +name: monthly_summary +statement: "SELECT date_trunc('month', created_at) FROM orders" +dialect: postgres +``` + +When set, the dialect is stored as metadata for downstream consumers. It does not currently affect how the engine parses the view's statement — view statements are always normalized into a logical plan via DataFusion's generic SQL parser. See [Dialect Override](#dialect-override) for details. + ### `relationships.yml` ```yaml @@ -226,6 +249,7 @@ wren context init → scaffold project in current directory (edit models/, relationships.yml, instructions.md) wren context validate → check YAML structure (no DB needed) wren context build → compile to target/mdl.json +wren context upgrade → upgrade project to latest schema_version wren profile add my-pg ... → save connection to ~/.wren/profiles.yml wren memory index → index schema + instructions into .wren/memory/ wren --sql "SELECT 1" → verify connection @@ -280,6 +304,65 @@ wren context init --from-mdl mdl.json --path my_project --force ``` > **When to use this:** You have an existing `mdl.json` that was authored by hand or generated by an older workflow (e.g. the MCP server's `mdl_save_project` tool), and you want to adopt the YAML project format for version control and CLI-driven workflows. +> +> The import is `layoutVersion`-aware: manifests with `layoutVersion: 2` produce a `schema_version: 3` project with `dialect` fields preserved. Manifests without `layoutVersion` (or `layoutVersion: 1`) produce a `schema_version: 2` project. + +--- + +## Upgrading an Existing Project + +When new features are added to the project format (e.g. the `dialect` field in schema_version 3), use `wren context upgrade` to bring your project up to date: + +```bash +wren context upgrade --path my_project +``` + +This upgrades to the latest `schema_version`. The command handles all intermediate steps automatically — for example, upgrading from v1 to v3 applies v1→v2 (restructure flat files into directories) then v2→v3 (enable dialect support). + +### What each upgrade does + +| Upgrade | File changes | +|---------|-------------| +| v1 → v2 | `models/*.yml` flat files → `models//metadata.yml` directories; `ref_sql` extracted to `ref_sql.sql`; `views.yml` → `views//metadata.yml` directories; old files deleted | +| v2 → v3 | No file layout changes — only bumps `schema_version` in `wren_project.yml` to enable `dialect` field support | + +### Options + +| Flag | Description | +|------|-------------| +| `--to N` | Upgrade to a specific schema_version instead of the latest | +| `--dry-run` | Preview what files would be created, deleted, or modified — without writing anything | + +### Preview before upgrading + +```bash +wren context upgrade --path my_project --dry-run +``` + +```text +Dry run — no files will be changed. + +Would create: + models/orders/metadata.yml + models/orders/ref_sql.sql + views/summary/metadata.yml + +Would delete: + models/orders.yml + views.yml + +Would modify: + wren_project.yml (schema_version 1 -> 3) +``` + +### After upgrading + +```bash +wren context validate --path my_project +wren context build --path my_project +``` + +> **When to use this:** Your project was created with an older CLI version and you want to use new features (like per-model `dialect`). If your project is already at the latest schema_version, the command exits with a "nothing to do" message. --- @@ -297,8 +380,60 @@ The `build` step converts all YAML keys from snake_case to camelCase: | `primary_key` | `primaryKey` | | `join_type` | `joinType` | | `data_source` | `dataSource` | +| `layout_version` | `layoutVersion` | +| `refresh_time` | `refreshTime` | +| `base_object` | `baseObject` | + +Generic rule: split on `_`, capitalize each word after the first, join. All other fields (`name`, `type`, `catalog`, `schema`, `table`, `condition`, `models`, `columns`, `cached`, `dialect`, `properties`) are identical in both formats. + +The `layoutVersion` field is stamped automatically by `wren context build` based on the project's `schema_version`. You do not set it manually in YAML. + +--- + +## Dialect Override + +Models and views support an optional `dialect` field that declares which SQL dialect their embedded SQL is written in. This requires `schema_version: 3`. + +### Semantics + +- **`dialect` omitted (or `null`)** — falls back to the project-level `data_source`. This is the default and matches the behavior of all existing projects. +- **`dialect` set** — the embedded SQL is written in the specified dialect, which may differ from the project's `data_source`. + +### Model dialect + +When a model has `dialect: bigquery` but the project's `data_source` is `postgres`, the engine knows the model's `ref_sql` contains BigQuery-flavored SQL (e.g. backtick-quoted identifiers, BigQuery functions). The engine uses this to select the correct SQL parser for the ref_sql. + +```yaml +# models/revenue/metadata.yml +name: revenue +ref_sql: "SELECT * FROM `my-project.dataset.table`" +dialect: bigquery +columns: + - name: amount + type: DECIMAL +``` + +### View dialect + +For views, the `dialect` field is currently **metadata only**. The engine normalizes view statements into a logical plan using DataFusion's generic SQL parser regardless of the dialect setting. The field is still valuable because: + +- It documents the author's intent (which dialect the SQL was written in). +- Downstream consumers (ibis-server, MCP clients) can use it for dialect-aware processing. +- When dialect-aware view parsing is added in the future, the field will already be in place. + +### Valid dialect values + +`athena`, `bigquery`, `canner`, `clickhouse`, `databricks`, `datafusion`, `doris`, `duckdb`, `gcs_file`, `local_file`, `minio_file`, `mssql`, `mysql`, `oracle`, `postgres`, `redshift`, `s3_file`, `snowflake`, `spark`, `trino` + +### Version requirements + +The `dialect` field requires `schema_version: 3` in `wren_project.yml`. Using `dialect` in a `schema_version: 2` project produces a validation warning. The `schema_version` also controls the `layoutVersion` stamped in the compiled `target/mdl.json`: -Generic rule: split on `_`, capitalize each word after the first, join. All other fields (`name`, `type`, `catalog`, `schema`, `table`, `condition`, `models`, `columns`, `cached`, `properties`) are identical in both formats. +| `schema_version` | `layoutVersion` | Capabilities | +|-------------------|-----------------|--------------| +| 1 | 1 | Legacy flat-file project format | +| 2 | 1 | Folder-per-entity project format | +| 3 | 2 | `dialect` field on models and views | --- diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 5efcecbf1..b9220f2a9 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -38,6 +38,8 @@ pub fn manifest(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStr #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] #[serde(rename_all = "camelCase")] pub struct Manifest { + #[serde(default = "default_layout_version")] + pub layout_version: u32, pub catalog: String, pub schema: String, #[serde(default)] @@ -51,6 +53,10 @@ pub fn manifest(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStr #[serde(default)] pub data_source: Option, } + + fn default_layout_version() -> u32 { + 1 + } }; proc_macro::TokenStream::from(expanded) } @@ -154,6 +160,8 @@ pub fn model(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream pub refresh_time: Option, #[serde(default)] pub row_level_access_controls: Vec>, + #[serde(default)] + pub dialect: Option, } }; proc_macro::TokenStream::from(expanded) @@ -363,9 +371,12 @@ pub fn view(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream let expanded = quote! { #python_binding #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] + #[serde(rename_all = "camelCase")] pub struct View { pub name: String, pub statement: String, + #[serde(default)] + pub dialect: Option, } }; proc_macro::TokenStream::from(expanded) diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index 28b7f54cb..d01e181bb 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -46,6 +46,7 @@ impl ManifestBuilder { pub fn new() -> Self { Self { manifest: Manifest { + layout_version: 1, catalog: "wrenai".to_string(), schema: "public".to_string(), models: vec![], @@ -57,6 +58,11 @@ impl ManifestBuilder { } } + pub fn layout_version(mut self, version: u32) -> Self { + self.manifest.layout_version = version; + self + } + pub fn catalog(mut self, catalog: &str) -> Self { self.manifest.catalog = catalog.to_string(); self @@ -114,6 +120,7 @@ impl ModelBuilder { cached: false, refresh_time: None, row_level_access_controls: vec![], + dialect: None, }, } } @@ -168,6 +175,11 @@ impl ModelBuilder { self } + pub fn dialect(mut self, dialect: DataSource) -> Self { + self.model.dialect = Some(dialect); + self + } + pub fn build(self) -> Arc { Arc::new(self.model) } @@ -406,6 +418,7 @@ impl ViewBuilder { view: View { name: name.to_string(), statement: "".to_string(), + dialect: None, }, } } @@ -415,6 +428,11 @@ impl ViewBuilder { self } + pub fn dialect(mut self, dialect: DataSource) -> Self { + self.view.dialect = Some(dialect); + self + } + pub fn build(self) -> Arc { Arc::new(self.view) } @@ -848,4 +866,129 @@ mod test { assert_eq!(actual.normalized_name(), actual.name.to_lowercase()); assert_eq!(actual, expected) } + + #[test] + fn test_manifest_layout_version_default() { + let json = r#"{"catalog":"wren","schema":"public"}"#; + let manifest: Manifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.layout_version, 1); + } + + #[test] + fn test_manifest_layout_version_explicit() { + let json = r#"{"layoutVersion":2,"catalog":"wren","schema":"public"}"#; + let manifest: Manifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.layout_version, 2); + } + + #[test] + fn test_manifest_layout_version_roundtrip() { + let expected = ManifestBuilder::new().layout_version(2).build(); + let json_str = serde_json::to_string(&expected).unwrap(); + assert!(json_str.contains(r#""layoutVersion":2"#)); + let actual: Manifest = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual.layout_version, 2); + assert_eq!(actual, expected); + } + + #[test] + fn test_manifest_version_validation_ok() { + use crate::mdl::manifest::MAX_SUPPORTED_LAYOUT_VERSION; + let manifest = ManifestBuilder::new() + .layout_version(MAX_SUPPORTED_LAYOUT_VERSION) + .build(); + assert!(manifest.validate_layout_version().is_ok()); + } + + #[test] + fn test_manifest_version_validation_rejected() { + let manifest = ManifestBuilder::new().layout_version(99).build(); + let err = manifest.validate_layout_version().unwrap_err(); + assert!(err.to_string().contains("99")); + assert!(err.to_string().contains("only supports up to")); + } + + #[test] + fn test_model_dialect_none_default() { + let json = r#"{"name":"test","columns":[]}"#; + let model: Arc = serde_json::from_str(json).unwrap(); + assert!(model.dialect.is_none()); + } + + #[test] + fn test_model_dialect_roundtrip() { + let expected = ModelBuilder::new("test") + .table_reference("test") + .column(ColumnBuilder::new("id", "integer").build()) + .dialect(DataSource::BigQuery) + .build(); + + let json_str = serde_json::to_string(&expected).unwrap(); + assert!(json_str.contains(r#""dialect":"BIGQUERY""#)); + let actual: Arc = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual.dialect, Some(DataSource::BigQuery)); + assert_eq!(actual, expected); + } + + #[test] + fn test_model_dialect_case_insensitive() { + let json = r#"{"name":"test","columns":[],"dialect":"bigquery"}"#; + let model: Arc = serde_json::from_str(json).unwrap(); + assert_eq!(model.dialect, Some(DataSource::BigQuery)); + } + + #[test] + fn test_view_dialect_none_default() { + let json = r#"{"name":"test","statement":"SELECT 1"}"#; + let view: Arc = serde_json::from_str(json).unwrap(); + assert!(view.dialect.is_none()); + } + + #[test] + fn test_view_dialect_roundtrip() { + let expected = ViewBuilder::new("test") + .statement("SELECT * FROM test") + .dialect(DataSource::Postgres) + .build(); + + let json_str = serde_json::to_string(&expected).unwrap(); + assert!(json_str.contains(r#""dialect":"POSTGRES""#)); + let actual: Arc = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual.dialect, Some(DataSource::Postgres)); + assert_eq!(actual, expected); + } + + #[test] + fn test_manifest_with_dialect_models_and_views() { + let model = ModelBuilder::new("revenue") + .ref_sql("SELECT * FROM `project.dataset.table`") + .dialect(DataSource::BigQuery) + .column(ColumnBuilder::new("amount", "decimal").build()) + .build(); + + let view = ViewBuilder::new("summary") + .statement("SELECT date_trunc('month', created_at) FROM orders") + .dialect(DataSource::Postgres) + .build(); + + let expected = ManifestBuilder::new() + .layout_version(2) + .model(model) + .view(view) + .data_source(DataSource::Postgres) + .build(); + + let json_str = serde_json::to_string(&expected).unwrap(); + let actual: Manifest = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual, expected); + assert_eq!(actual.layout_version, 2); + assert_eq!(actual.models[0].dialect, Some(DataSource::BigQuery)); + assert_eq!(actual.views[0].dialect, Some(DataSource::Postgres)); + } + + #[test] + fn test_manifest_builder_default_layout_version() { + let manifest = ManifestBuilder::new().build(); + assert_eq!(manifest.layout_version, 1); + } } diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index fed827a7c..ace633cd6 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -101,6 +101,39 @@ mod manifest_impl { pub use crate::mdl::manifest::manifest_impl::*; +pub const MAX_SUPPORTED_LAYOUT_VERSION: u32 = 2; + +impl Manifest { + pub fn validate_layout_version(&self) -> Result<(), LayoutVersionError> { + if self.layout_version > MAX_SUPPORTED_LAYOUT_VERSION { + Err(LayoutVersionError { + manifest_version: self.layout_version, + max_supported: MAX_SUPPORTED_LAYOUT_VERSION, + }) + } else { + Ok(()) + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LayoutVersionError { + pub manifest_version: u32, + pub max_supported: u32, +} + +impl Display for LayoutVersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "This manifest requires layout version {}, but this engine only supports up to {}", + self.manifest_version, self.max_supported + ) + } +} + +impl Error for LayoutVersionError {} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParsedDataSourceError { pub message: String, diff --git a/wren-core-base/src/mdl/migration.rs b/wren-core-base/src/mdl/migration.rs new file mode 100644 index 000000000..45a324122 --- /dev/null +++ b/wren-core-base/src/mdl/migration.rs @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::mdl::manifest::MAX_SUPPORTED_LAYOUT_VERSION; +use serde_json::Value; +use std::fmt; + +#[derive(Debug)] +pub enum MigrationError { + Json(serde_json::Error), + UnsupportedTargetVersion { target: u32, max: u32 }, +} + +impl fmt::Display for MigrationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MigrationError::Json(e) => write!(f, "JSON error during migration: {e}"), + MigrationError::UnsupportedTargetVersion { target, max } => write!( + f, + "Cannot migrate to layout version {target}: maximum supported version is {max}" + ), + } + } +} + +impl std::error::Error for MigrationError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + MigrationError::Json(e) => Some(e), + MigrationError::UnsupportedTargetVersion { .. } => None, + } + } +} + +impl From for MigrationError { + fn from(e: serde_json::Error) -> Self { + MigrationError::Json(e) + } +} + +/// Migrate a manifest JSON string to the specified target layout version. +/// +/// Applies migration steps sequentially (1→2, 2→3, ...). +/// Returns the input unchanged if already at or above the target version. +pub fn migrate_manifest( + manifest_json: &str, + target_version: u32, +) -> Result { + if target_version > MAX_SUPPORTED_LAYOUT_VERSION { + return Err(MigrationError::UnsupportedTargetVersion { + target: target_version, + max: MAX_SUPPORTED_LAYOUT_VERSION, + }); + } + + let mut value: Value = serde_json::from_str(manifest_json)?; + let current = value + .get("layoutVersion") + .and_then(|v| v.as_u64()) + .unwrap_or(1) as u32; + + if current >= target_version { + return Ok(manifest_json.to_string()); + } + + for version in current..target_version { + match version { + 1 => migrate_v1_to_v2(&mut value), + _ => { + return Err(MigrationError::UnsupportedTargetVersion { + target: target_version, + max: MAX_SUPPORTED_LAYOUT_VERSION, + }); + } + } + } + + value["layoutVersion"] = serde_json::json!(target_version); + Ok(serde_json::to_string(&value)?) +} + +/// v1→v2: No data transformation needed. +/// The `dialect` field on Model and View is optional and defaults to null. +fn migrate_v1_to_v2(_value: &mut Value) { + // No-op: `dialect` is Option with serde(default), + // so existing manifests deserialize correctly without changes. +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_migrate_v1_to_v2() { + let v1_json = r#"{"catalog":"wren","schema":"public","models":[]}"#; + let result = migrate_manifest(v1_json, 2).unwrap(); + let value: Value = serde_json::from_str(&result).unwrap(); + assert_eq!(value["layoutVersion"], 2); + } + + #[test] + fn test_migrate_already_at_target() { + let v2_json = r#"{"layoutVersion":2,"catalog":"wren","schema":"public","models":[]}"#; + let result = migrate_manifest(v2_json, 2).unwrap(); + assert_eq!(result, v2_json); + } + + #[test] + fn test_migrate_above_target() { + let v2_json = r#"{"layoutVersion":2,"catalog":"wren","schema":"public","models":[]}"#; + let result = migrate_manifest(v2_json, 1).unwrap(); + assert_eq!(result, v2_json); + } + + #[test] + fn test_migrate_unsupported_target() { + let v1_json = r#"{"catalog":"wren","schema":"public","models":[]}"#; + let result = migrate_manifest(v1_json, 99); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("99")); + } + + #[test] + fn test_migrate_idempotent() { + let v1_json = r#"{"catalog":"wren","schema":"public","models":[]}"#; + let first = migrate_manifest(v1_json, 2).unwrap(); + let second = migrate_manifest(&first, 2).unwrap(); + assert_eq!(first, second); + } + + #[test] + fn test_migrate_preserves_existing_fields() { + let v1_json = r#"{"catalog":"test","schema":"myschema","models":[{"name":"m1","columns":[],"tableReference":null}],"dataSource":"BIGQUERY"}"#; + let result = migrate_manifest(v1_json, 2).unwrap(); + let value: Value = serde_json::from_str(&result).unwrap(); + assert_eq!(value["catalog"], "test"); + assert_eq!(value["schema"], "myschema"); + assert_eq!(value["dataSource"], "BIGQUERY"); + assert_eq!(value["models"][0]["name"], "m1"); + assert_eq!(value["layoutVersion"], 2); + } + + #[test] + fn test_migrate_invalid_json() { + let result = migrate_manifest("not json", 2); + assert!(result.is_err()); + } +} diff --git a/wren-core-base/src/mdl/mod.rs b/wren-core-base/src/mdl/mod.rs index e25ccb216..505697871 100644 --- a/wren-core-base/src/mdl/mod.rs +++ b/wren-core-base/src/mdl/mod.rs @@ -20,6 +20,7 @@ pub mod builder; pub mod cls; pub mod manifest; +pub mod migration; mod py_method; mod utils; diff --git a/wren-core-base/src/mdl/py_method.rs b/wren-core-base/src/mdl/py_method.rs index 38d91c367..5d436122d 100644 --- a/wren-core-base/src/mdl/py_method.rs +++ b/wren-core-base/src/mdl/py_method.rs @@ -26,6 +26,11 @@ mod manifest_python_impl { #[pymethods] impl Manifest { + #[getter] + fn layout_version(&self) -> PyResult { + Ok(self.layout_version) + } + #[getter] fn catalog(&self) -> PyResult { Ok(self.catalog.clone()) diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index 1ab4d961d..647149719 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -80,6 +80,7 @@ fn extract_manifest( .collect::>(); let used_relationships = extract_relationships(mdl, &used_models); Ok(Manifest { + layout_version: mdl.manifest.layout_version, catalog: mdl.catalog().to_string(), schema: mdl.schema().to_string(), models: used_models, diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index c9425d0be..0722f2df0 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -24,5 +24,6 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?; m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?; m.add_function(wrap_pyfunction!(manifest::is_backward_compatible, m)?)?; + m.add_function(wrap_pyfunction!(manifest::migrate_manifest_json, m)?)?; Ok(()) } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 286862ddb..aaf57a570 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -2,6 +2,7 @@ use crate::errors::CoreError; use base64::prelude::BASE64_STANDARD; use base64::Engine; use pyo3::pyfunction; +use wren_core_base::mdl::migration; pub use wren_core_base::mdl::*; @@ -22,6 +23,16 @@ pub fn to_manifest(mdl_base64: &str) -> Result { Ok(manifest) } +/// Migrate a manifest JSON string to the specified target layout version. +#[pyfunction] +pub fn migrate_manifest_json( + manifest_json: &str, + target_version: u32, +) -> Result { + migration::migrate_manifest(manifest_json, target_version) + .map_err(|e| CoreError::new(&e.to_string())) +} + /// Check if the MDL can be used by the v2 wren core. If there are any access controls rules, /// the MDL should be used by the v3 wren core only. #[pyfunction] @@ -50,6 +61,7 @@ mod tests { #[test] fn test_manifest_to_json_base64() { let py_manifest = Manifest { + layout_version: 1, catalog: "catalog".to_string(), schema: "schema".to_string(), models: vec![ @@ -63,6 +75,7 @@ mod tests { cached: false, refresh_time: None, row_level_access_controls: vec![], + dialect: None, }), Arc::from(Model { name: "model_2".to_string(), @@ -74,6 +87,7 @@ mod tests { cached: false, refresh_time: None, row_level_access_controls: vec![], + dialect: None, }), ], relationships: vec![], diff --git a/wren/src/wren/context.py b/wren/src/wren/context.py index fb31ec393..d5df83b9f 100644 --- a/wren/src/wren/context.py +++ b/wren/src/wren/context.py @@ -96,6 +96,14 @@ def _convert_keys(obj: Any) -> Any: "primaryKey": "primary_key", "joinType": "join_type", "dataSource": "data_source", + "layoutVersion": "layout_version", + "refreshTime": "refresh_time", + "baseObject": "base_object", + "rowLevelAccessControls": "row_level_access_controls", + "columnLevelAccessControl": "column_level_access_control", + "requiredProperties": "required_properties", + "defaultExpr": "default_expr", + "isHidden": "is_hidden", } @@ -141,7 +149,13 @@ def convert_mdl_to_project(mdl_json: dict) -> list[ProjectFile]: files: list[ProjectFile] = [] # ── wren_project.yml ────────────────────────────────────── - project_config: dict[str, Any] = {"schema_version": 2} + # Map layoutVersion back to schema_version + layout_version = mdl_json.get("layoutVersion", 1) + _LAYOUT_TO_SCHEMA = {1: 2, 2: 3} + schema_version = _LAYOUT_TO_SCHEMA.get( + layout_version, 3 if layout_version >= 2 else 2 + ) + project_config: dict[str, Any] = {"schema_version": schema_version} if "name" in mdl_json: project_config["name"] = mdl_json["name"] elif "projectName" in mdl_json: @@ -367,7 +381,34 @@ def load_project_config(project_path: Path) -> dict: return yaml.safe_load(config_file.read_text()) or {} -_SUPPORTED_SCHEMA_VERSIONS = {1, 2} +_SUPPORTED_SCHEMA_VERSIONS = {1, 2, 3} + +# schema_version → layoutVersion mapping for the engine +_LAYOUT_VERSION_MAP = {1: 1, 2: 1, 3: 2} + +# Valid dialect values (matches Rust DataSource enum) +_VALID_DIALECTS = { + "athena", + "bigquery", + "canner", + "clickhouse", + "databricks", + "datafusion", + "doris", + "duckdb", + "gcs_file", + "local_file", + "minio_file", + "mssql", + "mysql", + "oracle", + "postgres", + "redshift", + "s3_file", + "snowflake", + "spark", + "trino", +} def get_schema_version(project_path: Path) -> int: @@ -547,18 +588,28 @@ def build_manifest(project_path: Path) -> dict: for v in views: v.pop("_source_dir", None) - return { + manifest: dict = { "catalog": project_config.get("catalog", "wren"), "schema": project_config.get("schema", "public"), "models": models, "relationships": relationships, "views": views, } + data_source = project_config.get("data_source") + if data_source: + manifest["data_source"] = data_source + return manifest def build_json(project_path: Path) -> dict: - """Build the final camelCase JSON manifest for the engine.""" - return _convert_keys(build_manifest(project_path)) + """Build the final camelCase JSON manifest for the engine. + + Stamps layoutVersion based on schema_version mapping. + """ + manifest = _convert_keys(build_manifest(project_path)) + sv = get_schema_version(project_path) + manifest["layoutVersion"] = _LAYOUT_VERSION_MAP.get(sv, 1) + return manifest def save_target(manifest_json: dict, project_path: Path) -> Path: @@ -600,6 +651,7 @@ def validate_project(project_path: Path) -> list[ValidationError]: 8. table_reference (if used) has at least a table field """ errors: list[ValidationError] = [] + sv = 1 # default; may be overridden below # Check project config config = load_project_config(project_path) @@ -759,6 +811,26 @@ def validate_project(project_path: Path) -> list[ValidationError]: ) ) + # Validate dialect (if present) + model_dialect = model.get("dialect") + if model_dialect is not None: + if sv < 3: + errors.append( + ValidationError( + "warning", + f"{src_path} > {name}", + f"'dialect' field requires schema_version >= 3 (current: {sv})", + ) + ) + if model_dialect.lower() not in _VALID_DIALECTS: + errors.append( + ValidationError( + "error", + f"{src_path} > {name}", + f"unknown dialect '{model_dialect}'", + ) + ) + # Check views for i, view in enumerate(views): src_dir = view.get("_source_dir", f"views[{i}]") @@ -785,6 +857,26 @@ def validate_project(project_path: Path) -> list[ValidationError]: ) ) + # Validate dialect (if present) + view_dialect = view.get("dialect") + if view_dialect is not None: + if sv < 3: + errors.append( + ValidationError( + "warning", + f"views/{src_dir}", + f"'dialect' field requires schema_version >= 3 (current: {sv})", + ) + ) + if view_dialect.lower() not in _VALID_DIALECTS: + errors.append( + ValidationError( + "error", + f"views/{src_dir}", + f"unknown dialect '{view_dialect}'", + ) + ) + # Check relationships all_entity_names = model_names | view_names for i, rel in enumerate(relationships): @@ -824,6 +916,185 @@ def validate_project(project_path: Path) -> list[ValidationError]: return errors +# ── Upgrade ────────────────────────────────────────────────────────────────── + +_LATEST_SCHEMA_VERSION = max(_SUPPORTED_SCHEMA_VERSIONS) + + +@dataclass +class UpgradeResult: + """Result of a project schema upgrade.""" + + from_version: int + to_version: int + files_created: list[str] + files_deleted: list[str] + files_modified: list[str] + + +class UpgradeError(Exception): + """Raised when a project upgrade cannot proceed.""" + + +def plan_upgrade( + project_path: Path, + target_version: int | None = None, +) -> UpgradeResult: + """Compute what an upgrade would do, without touching disk. + + Raises UpgradeError if the upgrade is invalid (e.g. downgrade, unsupported version). + Returns an UpgradeResult with empty lists if already at target (no-op). + """ + current = get_schema_version(project_path) + target = target_version if target_version is not None else _LATEST_SCHEMA_VERSION + + if target not in _SUPPORTED_SCHEMA_VERSIONS: + raise UpgradeError(f"Unsupported target schema_version {target}") + if target < current: + raise UpgradeError( + f"Cannot downgrade from schema_version {current} to {target}" + ) + if target == current: + return UpgradeResult( + from_version=current, + to_version=target, + files_created=[], + files_deleted=[], + files_modified=[], + ) + + files_created: list[str] = [] + files_deleted: list[str] = [] + + # Apply steps sequentially + for version in range(current, target): + if version == 1: + created, deleted = _plan_v1_to_v2(project_path) + files_created.extend(created) + files_deleted.extend(deleted) + # v2→v3: no file layout changes needed + + return UpgradeResult( + from_version=current, + to_version=target, + files_created=files_created, + files_deleted=files_deleted, + files_modified=[_PROJECT_FILE], + ) + + +def _plan_v1_to_v2(project_path: Path) -> tuple[list[str], list[str]]: + """Plan the v1→v2 file restructuring. Returns (files_created, files_deleted).""" + created: list[str] = [] + deleted: list[str] = [] + + # Models: flat files → directories + models = _load_models_v1(project_path) + for model in models: + source_dir = model.pop("_source_dir", None) + name = model.get("name", source_dir or "unknown") + dir_path = f"models/{name}" + + ref_sql = model.get("ref_sql") + if ref_sql: + created.append(f"{dir_path}/ref_sql.sql") + + created.append(f"{dir_path}/metadata.yml") + + if source_dir: + deleted.append(f"models/{source_dir}.yml") + + # Views: single file → directories + views = _load_views_v1(project_path) + for view in views: + name = view.get("name") + if not name: + continue + dir_path = f"views/{name}" + + statement = view.get("statement") + if statement and "\n" in statement.strip(): + created.append(f"{dir_path}/sql.yml") + + created.append(f"{dir_path}/metadata.yml") + + views_file = project_path / "views.yml" + if views_file.exists(): + deleted.append("views.yml") + + return created, deleted + + +def apply_upgrade(project_path: Path, result: UpgradeResult) -> None: + """Write upgrade changes to disk.""" + if not result.files_created and not result.files_deleted: + # No-op (e.g. v2→v3, only wren_project.yml changes) + pass + else: + _apply_v1_to_v2(project_path) + + # Update wren_project.yml + config = load_project_config(project_path) + config["schema_version"] = result.to_version + config_file = project_path / _PROJECT_FILE + config_file.write_text(yaml.dump(config, default_flow_style=False, sort_keys=False)) + + +def _apply_v1_to_v2(project_path: Path) -> None: + """Execute the v1→v2 restructuring: write new files, delete old ones.""" + # Write new model directories + models = _load_models_v1(project_path) + for model in models: + source_dir = model.pop("_source_dir", None) + name = model.get("name", source_dir or "unknown") + model_dir = project_path / "models" / name + model_dir.mkdir(parents=True, exist_ok=True) + + ref_sql = model.pop("ref_sql", None) + if ref_sql: + (model_dir / "ref_sql.sql").write_text(ref_sql.strip() + "\n") + + (model_dir / "metadata.yml").write_text( + yaml.dump(model, default_flow_style=False, sort_keys=False) + ) + + # Delete old flat file + if source_dir: + old_file = project_path / "models" / f"{source_dir}.yml" + if old_file.exists(): + old_file.unlink() + + # Write new view directories + views = _load_views_v1(project_path) + for view in views: + name = view.get("name") + if not name: + continue + view_dir = project_path / "views" / name + view_dir.mkdir(parents=True, exist_ok=True) + + statement = view.pop("statement", None) + if statement and "\n" in statement.strip(): + (view_dir / "sql.yml").write_text( + yaml.dump( + {"statement": statement}, + default_flow_style=False, + sort_keys=False, + ) + ) + elif statement: + view["statement"] = statement + + (view_dir / "metadata.yml").write_text( + yaml.dump(view, default_flow_style=False, sort_keys=False) + ) + + # Delete old views.yml + views_file = project_path / "views.yml" + if views_file.exists(): + views_file.unlink() + + # ── Semantic validation (view dry-plan + description completeness) ───────── _VALID_LEVELS = frozenset({"error", "warning", "strict"}) diff --git a/wren/src/wren/context_cli.py b/wren/src/wren/context_cli.py index d81703a05..7f93b75b4 100644 --- a/wren/src/wren/context_cli.py +++ b/wren/src/wren/context_cli.py @@ -40,7 +40,7 @@ def init( Without --from-mdl: scaffolds an empty project structure. With --from-mdl: imports an existing MDL JSON and produces a complete - v2 YAML project, ready for `wren context validate/build`. + YAML project, ready for `wren context validate/build`. """ project_path = Path(path).expanduser() if path else Path.cwd() @@ -96,7 +96,7 @@ def init( # wren_project.yml project_yml = ( - "schema_version: 2\n" + "schema_version: 3\n" "name: my_project\n" 'version: "1.0"\n' "\n" @@ -436,3 +436,86 @@ def instructions( content = load_instructions(project_path) if content: typer.echo(content) + + +@context_app.command() +def upgrade( + path: ProjectPathOpt = None, + to: Annotated[ + Optional[int], + typer.Option("--to", help="Target schema_version (default: latest)."), + ] = None, + dry_run: Annotated[ + bool, + typer.Option("--dry-run", help="Preview changes without writing."), + ] = False, +) -> None: + """Upgrade project schema_version to enable new features.""" + from wren.context import ( # noqa: PLC0415 + UpgradeError, + apply_upgrade, + discover_project_path, + get_schema_version, + plan_upgrade, + ) + + try: + project_path = discover_project_path(path) + except SystemExit as e: + typer.echo(str(e), err=True) + raise typer.Exit(1) + + current = get_schema_version(project_path) + + try: + result = plan_upgrade(project_path, target_version=to) + except UpgradeError as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) + + if ( + not result.files_created + and not result.files_deleted + and not result.files_modified + ): + typer.echo(f"Already at schema_version {current}. Nothing to do.") + return + + if result.from_version == result.to_version: + typer.echo(f"Already at schema_version {current}. Nothing to do.") + return + + if dry_run: + typer.echo("Dry run — no files will be changed.\n") + if result.files_created: + typer.echo("Would create:") + for f in result.files_created: + typer.echo(f" {f}") + if result.files_deleted: + typer.echo("Would delete:") + for f in result.files_deleted: + typer.echo(f" {f}") + if result.files_modified: + typer.echo("Would modify:") + for f in result.files_modified: + typer.echo( + f" {f} (schema_version {result.from_version} -> {result.to_version})" + ) + return + + typer.echo( + f"Upgrading project from schema_version {result.from_version} -> {result.to_version}..." + ) + + apply_upgrade(project_path, result) + + for f in result.files_created: + typer.echo(f" + {f}") + for f in result.files_deleted: + typer.echo(f" - {f}") + for f in result.files_modified: + typer.echo( + f" * {f} (schema_version {result.from_version} -> {result.to_version})" + ) + + typer.echo("\nUpgrade complete. Run `wren context validate` to check the result.") diff --git a/wren/tests/unit/test_context.py b/wren/tests/unit/test_context.py index 044bba3ce..126b5c429 100644 --- a/wren/tests/unit/test_context.py +++ b/wren/tests/unit/test_context.py @@ -8,16 +8,20 @@ import pytest from wren.context import ( + UpgradeError, _convert_keys, _snake_to_camel, + apply_upgrade, build_json, build_manifest, + convert_mdl_to_project, discover_project_path, get_schema_version, load_instructions, load_models, load_relationships, load_views, + plan_upgrade, require_schema_version, save_target, validate_project, @@ -311,6 +315,34 @@ def test_build_json_camel_case(tmp_path): assert "_instructions" not in result +def test_build_manifest_includes_data_source(tmp_path): + """build_manifest must include data_source from project config.""" + _minimal_v2_project(tmp_path) + manifest = build_manifest(tmp_path) + assert manifest["data_source"] == "postgres" + + +def test_build_json_includes_data_source(tmp_path): + """build_json must include dataSource (camelCase) from project config.""" + _minimal_v2_project(tmp_path) + result = build_json(tmp_path) + assert result["dataSource"] == "postgres" + + +def test_build_manifest_omits_data_source_when_unset(tmp_path): + """If project config lacks data_source, the field is omitted.""" + (tmp_path / "wren_project.yml").write_text( + "schema_version: 2\nname: test\ncatalog: wren\nschema: public\n" + ) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\ntable_reference:\n table: orders\ncolumns: []\n" + ) + manifest = build_manifest(tmp_path) + assert "data_source" not in manifest + + def test_build_json_round_trip(tmp_path): _minimal_v2_project(tmp_path) result = build_json(tmp_path) @@ -540,6 +572,356 @@ def test_discover_via_config(tmp_path, monkeypatch): assert result == project_dir +# ── Schema version 3 / dialect / layoutVersion ────────────────────────────── + + +def _make_v3_project(tmp_path: Path) -> Path: + """Write a minimal v3 project with dialect support.""" + (tmp_path / "wren_project.yml").write_text( + "schema_version: 3\nname: test\ndata_source: postgres\ncatalog: wren\nschema: public\n" + ) + return tmp_path + + +def test_get_schema_version_v3(tmp_path): + _make_v3_project(tmp_path) + assert get_schema_version(tmp_path) == 3 + + +def test_require_schema_version_v3(tmp_path): + _make_v3_project(tmp_path) + assert require_schema_version(tmp_path) == 3 + + +def test_build_json_layout_version_v2_project(tmp_path): + """schema_version 2 → layoutVersion 1.""" + _minimal_v2_project(tmp_path) + result = build_json(tmp_path) + assert result["layoutVersion"] == 1 + + +def test_build_json_layout_version_v3_project(tmp_path): + """schema_version 3 → layoutVersion 2.""" + _make_v3_project(tmp_path) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: id\n type: INTEGER\n" + ) + result = build_json(tmp_path) + assert result["layoutVersion"] == 2 + + +def test_build_json_model_dialect_preserved(tmp_path): + """Model dialect field flows through to JSON output.""" + _make_v3_project(tmp_path) + d = tmp_path / "models" / "revenue" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: revenue\n" + "table_reference:\n table: revenue\n" + "dialect: bigquery\n" + "columns:\n - name: amount\n type: decimal\n" + ) + result = build_json(tmp_path) + assert result["models"][0]["dialect"] == "bigquery" + + +def test_build_json_view_dialect_preserved(tmp_path): + """View dialect field flows through to JSON output.""" + _make_v3_project(tmp_path) + d = tmp_path / "views" / "summary" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: summary\n" + "statement: SELECT 1\n" + "dialect: postgres\n" + ) + result = build_json(tmp_path) + assert result["views"][0]["dialect"] == "postgres" + + +def test_v3_models_load_same_as_v2(tmp_path): + """schema_version 3 uses the same directory layout as v2.""" + _make_v3_project(tmp_path) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text("name: orders\ntable_reference:\n table: orders\n") + models = load_models(tmp_path) + assert len(models) == 1 + assert models[0]["name"] == "orders" + + +def test_convert_mdl_preserves_dialect(tmp_path): + """convert_mdl_to_project preserves dialect on models and views.""" + mdl = { + "layoutVersion": 2, + "catalog": "wren", + "schema": "public", + "dataSource": "POSTGRES", + "models": [ + { + "name": "revenue", + "tableReference": {"table": "revenue"}, + "dialect": "bigquery", + "columns": [{"name": "amount", "type": "decimal"}], + } + ], + "views": [ + { + "name": "summary", + "statement": "SELECT 1", + "dialect": "postgres", + } + ], + } + files = convert_mdl_to_project(mdl) + file_map = {f.relative_path: f.content for f in files} + + # Check schema_version derived from layoutVersion 2 + import yaml + + project = yaml.safe_load(file_map["wren_project.yml"]) + assert project["schema_version"] == 3 + + # Check model dialect preserved + model_meta = yaml.safe_load(file_map["models/revenue/metadata.yml"]) + assert model_meta["dialect"] == "bigquery" + + # Check view dialect preserved + view_meta = yaml.safe_load(file_map["views/summary/metadata.yml"]) + assert view_meta["dialect"] == "postgres" + + +def test_convert_mdl_v1_layout_version(tmp_path): + """layoutVersion 1 (or missing) → schema_version 2.""" + mdl = { + "catalog": "wren", + "schema": "public", + "models": [], + } + files = convert_mdl_to_project(mdl) + import yaml + + file_map = {f.relative_path: f.content for f in files} + project = yaml.safe_load(file_map["wren_project.yml"]) + assert project["schema_version"] == 2 + + +def test_validate_dialect_unknown_value(tmp_path): + """Unknown dialect value is an error.""" + _make_v3_project(tmp_path) + d = tmp_path / "models" / "bad" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: bad\n" + "table_reference:\n table: bad\n" + "dialect: nosuchdb\n" + "columns:\n - name: id\n type: INTEGER\n" + ) + errors = validate_project(tmp_path) + assert any("unknown dialect" in e.message for e in errors) + + +def test_validate_dialect_valid_value(tmp_path): + """Valid dialect does not produce errors.""" + _make_v3_project(tmp_path) + d = tmp_path / "models" / "ok" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: ok\n" + "table_reference:\n table: ok\n" + "dialect: bigquery\n" + "columns:\n - name: id\n type: INTEGER\n" + "primary_key: id\n" + ) + (tmp_path / "relationships.yml").write_text("relationships: []\n") + errors = validate_project(tmp_path) + assert errors == [] + + +def test_validate_dialect_warning_in_v2(tmp_path): + """dialect on a schema_version 2 project produces a warning.""" + _make_v2_project(tmp_path) + d = tmp_path / "models" / "mixed" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: mixed\n" + "table_reference:\n table: mixed\n" + "dialect: bigquery\n" + "columns:\n - name: id\n type: INTEGER\n" + ) + errors = validate_project(tmp_path) + warnings = [e for e in errors if e.level == "warning"] + assert any("schema_version >= 3" in w.message for w in warnings) + + +def test_validate_view_dialect_unknown(tmp_path): + """Unknown dialect on a view is an error.""" + _make_v3_project(tmp_path) + d = tmp_path / "views" / "badview" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: badview\nstatement: SELECT 1\ndialect: nosuchdb\n" + ) + errors = validate_project(tmp_path) + assert any("unknown dialect" in e.message for e in errors) + + +# ── Upgrade ────────────────────────────────────────────────────────────────── + + +def _make_v1_project(tmp_path: Path) -> Path: + """Create a minimal v1 project with flat model files and views.yml.""" + (tmp_path / "wren_project.yml").write_text( + "schema_version: 1\nname: test\ndata_source: postgres\ncatalog: wren\nschema: public\n" + ) + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "orders.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: id\n type: INTEGER\n" + "primary_key: id\n" + ) + (models_dir / "revenue.yml").write_text( + "name: revenue\n" + "ref_sql: SELECT SUM(amount) FROM orders\n" + "columns:\n - name: total\n type: DECIMAL\n" + ) + (tmp_path / "views.yml").write_text( + "views:\n" + " - name: summary\n" + " statement: SELECT 1\n" + " - name: monthly\n" + ' statement: "SELECT\\n date_trunc(month, d)\\n FROM t"\n' + ) + (tmp_path / "relationships.yml").write_text("relationships: []\n") + (tmp_path / "instructions.md").write_text("## Rule 1\nAlways use UTC.\n") + return tmp_path + + +def test_plan_upgrade_v1_to_v2(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=2) + assert result.from_version == 1 + assert result.to_version == 2 + assert any("models/orders/metadata.yml" in f for f in result.files_created) + assert any("models/revenue/ref_sql.sql" in f for f in result.files_created) + assert any("models/orders.yml" in f for f in result.files_deleted) + assert any("views.yml" in f for f in result.files_deleted) + + +def test_plan_upgrade_v1_to_v3(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + assert result.from_version == 1 + assert result.to_version == 3 + assert len(result.files_created) > 0 + + +def test_plan_upgrade_v2_to_v3(tmp_path): + _make_v2_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + assert result.from_version == 2 + assert result.to_version == 3 + assert result.files_created == [] + assert result.files_deleted == [] + assert _PROJECT_FILE in result.files_modified + + +def test_plan_upgrade_already_at_target(tmp_path): + _make_v3_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + assert result.from_version == 3 + assert result.to_version == 3 + assert result.files_created == [] + assert result.files_deleted == [] + assert result.files_modified == [] + + +def test_plan_upgrade_above_target(tmp_path): + _make_v3_project(tmp_path) + # Use fresh import to avoid stale class reference after importlib.reload in earlier tests + from wren.context import UpgradeError as _UE # noqa: PLC0415 + + with pytest.raises(_UE, match="Cannot downgrade"): + plan_upgrade(tmp_path, target_version=1) + + +def test_plan_upgrade_default_to_latest(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path) + assert result.to_version == 3 + + +def test_apply_upgrade_v1_to_v2(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=2) + apply_upgrade(tmp_path, result) + + # New structure exists + assert (tmp_path / "models" / "orders" / "metadata.yml").exists() + assert (tmp_path / "models" / "revenue" / "metadata.yml").exists() + assert (tmp_path / "models" / "revenue" / "ref_sql.sql").exists() + assert (tmp_path / "views" / "summary" / "metadata.yml").exists() + + # Old files deleted + assert not (tmp_path / "models" / "orders.yml").exists() + assert not (tmp_path / "models" / "revenue.yml").exists() + assert not (tmp_path / "views.yml").exists() + + # schema_version updated + assert get_schema_version(tmp_path) == 2 + + # Content preserved + models = load_models(tmp_path) + assert len(models) == 2 + names = {m["name"] for m in models} + assert names == {"orders", "revenue"} + revenue = next(m for m in models if m["name"] == "revenue") + assert "SELECT SUM(amount)" in revenue["ref_sql"] + + +def test_apply_upgrade_v2_to_v3(tmp_path): + _make_v2_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + apply_upgrade(tmp_path, result) + assert get_schema_version(tmp_path) == 3 + + +def test_apply_upgrade_v1_to_v3(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + apply_upgrade(tmp_path, result) + assert get_schema_version(tmp_path) == 3 + assert (tmp_path / "models" / "orders" / "metadata.yml").exists() + assert not (tmp_path / "models" / "orders.yml").exists() + + +def test_upgrade_preserves_relationships(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + apply_upgrade(tmp_path, result) + assert (tmp_path / "relationships.yml").exists() + rels = load_relationships(tmp_path) + assert rels == [] + + +def test_upgrade_preserves_instructions(tmp_path): + _make_v1_project(tmp_path) + result = plan_upgrade(tmp_path, target_version=3) + apply_upgrade(tmp_path, result) + assert (tmp_path / "instructions.md").exists() + content = load_instructions(tmp_path) + assert "Rule 1" in content + + +_PROJECT_FILE = "wren_project.yml" + + # ── Semantic validation tests (view dry-plan + description checks) ───────── import base64 diff --git a/wren/tests/unit/test_context_cli.py b/wren/tests/unit/test_context_cli.py index f8c77b8ec..2a6dbc36c 100644 --- a/wren/tests/unit/test_context_cli.py +++ b/wren/tests/unit/test_context_cli.py @@ -215,3 +215,69 @@ def test_instructions_discovers_project(tmp_path): result = runner.invoke(app, ["context", "instructions", "--path", str(tmp_path)]) assert result.exit_code == 0 assert "custom rule here" in result.output + + +# ── wren context upgrade ───────────────────────────────────────────────── + + +def _make_v1_project(tmp_path: Path) -> Path: + (tmp_path / "wren_project.yml").write_text( + "schema_version: 1\nname: test\ndata_source: postgres\ncatalog: wren\nschema: public\n" + ) + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "orders.yml").write_text( + "name: orders\ntable_reference:\n table: orders\n" + "columns:\n - name: id\n type: INTEGER\nprimary_key: id\n" + ) + (tmp_path / "relationships.yml").write_text("relationships: []\n") + return tmp_path + + +def test_upgrade_cli_v2_to_v3(tmp_path): + _make_valid_project(tmp_path) + result = runner.invoke(app, ["context", "upgrade", "--path", str(tmp_path)]) + assert result.exit_code == 0, result.output + assert "Upgrade complete" in result.output + import yaml + + config = yaml.safe_load((tmp_path / "wren_project.yml").read_text()) + assert config["schema_version"] == 3 + + +def test_upgrade_cli_dry_run(tmp_path): + _make_v1_project(tmp_path) + result = runner.invoke( + app, ["context", "upgrade", "--path", str(tmp_path), "--dry-run"] + ) + assert result.exit_code == 0, result.output + assert "Dry run" in result.output + assert "Would create" in result.output + # Verify no files were actually changed + assert not (tmp_path / "models" / "orders" / "metadata.yml").exists() + import yaml + + config = yaml.safe_load((tmp_path / "wren_project.yml").read_text()) + assert config["schema_version"] == 1 + + +def test_upgrade_cli_already_current(tmp_path): + (tmp_path / "wren_project.yml").write_text( + "schema_version: 3\nname: test\ndata_source: postgres\n" + ) + result = runner.invoke(app, ["context", "upgrade", "--path", str(tmp_path)]) + assert result.exit_code == 0 + assert "Already at" in result.output + + +def test_upgrade_cli_explicit_to_version(tmp_path): + _make_v1_project(tmp_path) + result = runner.invoke( + app, ["context", "upgrade", "--path", str(tmp_path), "--to", "2"] + ) + assert result.exit_code == 0, result.output + assert "Upgrade complete" in result.output + import yaml + + config = yaml.safe_load((tmp_path / "wren_project.yml").read_text()) + assert config["schema_version"] == 2