Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/iceberg_spark_test_reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}
restore-keys: |
${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-
Expand All @@ -99,7 +98,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}

- name: Upload native library
Expand Down
5 changes: 0 additions & 5 deletions .github/workflows/pr_build_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}
restore-keys: |
${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-
Expand Down Expand Up @@ -225,7 +224,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}

# Run Rust tests (runs in parallel with build-native, uses debug builds)
Expand All @@ -250,8 +248,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
# Note: Java version intentionally excluded - Rust target is JDK-independent
key: ${{ runner.os }}-cargo-debug-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}
restore-keys: |
${{ runner.os }}-cargo-debug-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-
Expand All @@ -266,7 +262,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-debug-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}

linux-test:
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/pr_build_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-v2-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}
restore-keys: |
${{ runner.os }}-cargo-ci-v2-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-
Expand Down Expand Up @@ -92,7 +91,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-v2-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}

macos-aarch64-test:
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/spark_sql_test_reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
Comment thread
manuzhang marked this conversation as resolved.
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}
restore-keys: |
${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-
Expand All @@ -100,7 +99,6 @@ jobs:
path: |
~/.cargo/registry
~/.cargo/git
native/target
key: ${{ runner.os }}-cargo-ci-${{ hashFiles('native/**/Cargo.lock', 'native/**/Cargo.toml') }}-${{ hashFiles('native/**/*.rs') }}

- name: Upload native library
Expand Down
39 changes: 35 additions & 4 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,10 @@ fn cast_struct_to_struct(
ColumnarValue::from(from_field),
to.data_type(),
cast_options,
)
.unwrap();
cast_result.to_array(array_length).unwrap()
)?;
cast_result.to_array(array_length)
})
.collect();
.collect::<DataFusionResult<Vec<_>>>()?;

Ok(Arc::new(StructArray::new(
to_fields.clone(),
Expand Down Expand Up @@ -961,6 +960,38 @@ mod tests {
}
}

#[test]
fn test_cast_nested_struct_to_struct_ansi_overflow_returns_error() {
let inner_values: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(128), None]));
let from_nested_fields =
Fields::from(vec![Field::new("long_value", DataType::Int64, true)]);
let nested: ArrayRef = Arc::new(StructArray::new(
from_nested_fields.clone(),
vec![inner_values],
None,
));
let from_fields = Fields::from(vec![Field::new(
"nested",
DataType::Struct(from_nested_fields),
true,
)]);
let outer: ArrayRef = Arc::new(StructArray::new(from_fields, vec![nested], None));

let to_nested_fields = Fields::from(vec![Field::new("byte_value", DataType::Int8, true)]);
let to_fields = Fields::from(vec![Field::new(
"renamed_nested",
DataType::Struct(to_nested_fields),
true,
)]);
let result = spark_cast(
ColumnarValue::Array(outer),
&DataType::Struct(to_fields),
&SparkCastOptions::new(EvalMode::Ansi, "UTC", false),
);

assert!(result.is_err());
}

#[test]
fn test_cast_struct_to_struct_drop_column() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- Config: spark.sql.ansi.enabled=true

statement
CREATE TABLE test_cast_complex_ansi(
id int,
struct_value struct<
long_value:bigint,
string_value:string,
nested_value:struct<inner_long:bigint>>,
array_value array<bigint>
) USING parquet

statement
INSERT INTO test_cast_complex_ansi VALUES
(
1,
named_struct(
'long_value', cast(1 as bigint),
'string_value', 'fits',
'nested_value', named_struct('inner_long', cast(10 as bigint))),
array(cast(1 as bigint), cast(127 as bigint), cast(null as bigint))
),
(
2,
named_struct(
'long_value', cast(128 as bigint),
'string_value', 'too-large',
'nested_value', named_struct('inner_long', cast(10 as bigint))),
array(cast(1 as bigint))
),
(
3,
named_struct(
'long_value', cast(2 as bigint),
'string_value', 'nested-too-small',
'nested_value', named_struct('inner_long', cast(-129 as bigint))),
array(cast(2 as bigint))
),
(
4,
named_struct(
'long_value', cast(3 as bigint),
'string_value', 'array-too-large',
'nested_value', named_struct('inner_long', cast(4 as bigint))),
array(cast(128 as bigint))
),
(
5,
cast(null as struct<
long_value:bigint,
string_value:string,
nested_value:struct<inner_long:bigint>>),
cast(null as array<bigint>)
)

-- valid complex casts should run natively under ANSI mode
query
SELECT
cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>),
cast(array_value as array<tinyint>),
id
FROM test_cast_complex_ansi
WHERE id IN (1, 5)
ORDER BY id

-- overflow in a struct field should propagate as a cast error
query expect_error(CAST_OVERFLOW)
SELECT cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>)
FROM test_cast_complex_ansi
WHERE id = 2

-- overflow in a nested struct field should propagate as a cast error
query expect_error(CAST_OVERFLOW)
SELECT cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>)
FROM test_cast_complex_ansi
WHERE id = 3

-- overflow in an array element should propagate as a cast error
query expect_error(CAST_OVERFLOW)
SELECT cast(array_value as array<tinyint>)
FROM test_cast_complex_ansi
WHERE id = 4
82 changes: 77 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package org.apache.comet
import java.io.File

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.util.Random

import org.apache.hadoop.fs.Path
Expand Down Expand Up @@ -1465,6 +1466,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

val nestedType =
StructType(Seq(StructField("long_value", LongType), StructField("bool_value", BooleanType)))
val structType = StructType(
Seq(
StructField("int_value", IntegerType),
StructField("string_value", StringType),
StructField("nested_value", nestedType)))
val schema = StructType(Seq(StructField("a", structType)))
val rows = Seq(
Row(Row(1, "one", Row(10L, true))),
Row(Row(null, "missing-int", Row(-2L, false))),
Row(Row(3, null, null)),
Row(null))

castTest(spark.createDataFrame(rows.asJava, schema), StringType)
}

test("cast StructType to StructType") {
Expand All @@ -1479,6 +1496,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

val fromNestedType = StructType(Seq(StructField("inner_int", IntegerType)))
val fromType = StructType(
Seq(
StructField("long_value", LongType),
StructField("string_value", StringType),
StructField("nested_value", fromNestedType)))
val toNestedType = StructType(Seq(StructField("renamed_inner_long", LongType)))
val toType = StructType(
Seq(
StructField("renamed_byte", ByteType),
StructField("renamed_string", StringType),
StructField("renamed_nested", toNestedType)))
val schema = StructType(Seq(StructField("a", fromType)))
val rows = Seq(
Row(Row(1L, "one", Row(10))),
Row(Row(127L, null, Row(-20))),
Row(Row(null, "missing-long", null)),
Row(null))

castTest(spark.createDataFrame(rows.asJava, schema), toType)

val overflowFromType = StructType(
Seq(StructField("long_value", LongType), StructField("string_value", StringType)))
val overflowToType = StructType(
Seq(StructField("renamed_byte", ByteType), StructField("renamed_string", StringType)))
val overflowSchema = StructType(Seq(StructField("a", overflowFromType)))
val overflowRows = Seq(
Row(Row(1L, "fits")),
Row(Row(128L, "too-large")),
Row(Row(-129L, "too-small")),
Row(Row(null, "missing-long")),
Row(null))

castTest(
spark.createDataFrame(overflowRows.asJava, overflowSchema),
overflowToType,
expectAnsiFailure = true)
}

test("cast StructType to StructType with different names") {
Expand Down Expand Up @@ -1564,8 +1619,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast ArrayType to StringType - float double binary edge cases") {
import scala.jdk.CollectionConverters._

def bytes(values: Int*): Array[Byte] = values.map(_.toByte).toArray

def arrayInput(elementType: DataType, values: Seq[Any]): DataFrame = {
Expand Down Expand Up @@ -1630,6 +1683,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
DataTypes.TimestampNTZType,
BinaryType)
testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _))

val schema = StructType(Seq(StructField("a", ArrayType(LongType))))
val rows = Seq(
Row(Seq[Any](1L, 127L, null)),
Row(Seq[Any](128L)),
Row(Seq[Any](-129L, 0L)),
Row(Seq.empty[Any]),
Row(null))

castTest(
spark.createDataFrame(rows.asJava, schema),
ArrayType(ByteType),
expectAnsiFailure = true)
}

test("cast MapType to MapType") {
Expand All @@ -1639,7 +1705,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// the planner routes Map→Map casts into it. The map column must be read
// natively for the cast to be exercised by Comet, which only happens
// under the V1 Parquet scan, so we pin USE_V1_SOURCE_LIST=parquet.
import scala.collection.JavaConverters._
val schema =
StructType(Seq(StructField("a", MapType(IntegerType, IntegerType), nullable = true)))
val rows = Range(0, 100).map { i =>
Expand Down Expand Up @@ -1837,7 +1902,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
def buildRows(values: Seq[Any]): Seq[Row] = {
Range(0, rowNum).map { i =>
Expand Down Expand Up @@ -1899,7 +1963,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateNestedArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(ArrayType(elementType)), true)))
val innerArrays = generateArrays(rowNum, elementType)
.collect()
Expand Down Expand Up @@ -2214,6 +2277,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
hasIncompatibleType: Boolean = false,
testAnsi: Boolean = true,
testTry: Boolean = true,
expectAnsiFailure: Boolean = false,
useDataFrameDiff: Boolean = false): Unit = {

withTempPath { dir =>
Expand Down Expand Up @@ -2261,11 +2325,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
.select(col("__row_id"), col("a"), col("a").cast(toType).as("converted"))
.orderBy(col("__row_id"))
.drop("__row_id")
if (expectAnsiFailure) {
assert(!hasIncompatibleType, "Expected ANSI failures must use Comet native execution")
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
}
val res = if (useDataFrameDiff) {
assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
} else {
checkSparkAnswerMaybeThrows(df)
}
if (expectAnsiFailure) {
assert(res._1.isDefined, "Expected Spark ANSI cast to fail")
assert(res._2.isDefined, "Expected Comet ANSI cast to fail")
}
res match {
case (None, None) =>
// neither system threw an exception
Expand Down
Loading