Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,10 @@ fn cast_struct_to_struct(
ColumnarValue::from(from_field),
to.data_type(),
cast_options,
)
.unwrap();
cast_result.to_array(array_length).unwrap()
)?;
cast_result.to_array(array_length)
})
.collect();
.collect::<DataFusionResult<Vec<_>>>()?;

Ok(Arc::new(StructArray::new(
to_fields.clone(),
Expand Down Expand Up @@ -961,6 +960,38 @@ mod tests {
}
}

#[test]
fn test_cast_nested_struct_to_struct_ansi_overflow_returns_error() {
    // Source column: struct<nested: struct<long_value: Int64>>.
    // The middle value (128) is above the Int8 maximum of 127.
    let source_inner_fields = Fields::from(vec![Field::new("long_value", DataType::Int64, true)]);
    let long_column: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(128), None]));
    let inner_struct: ArrayRef = Arc::new(StructArray::new(
        source_inner_fields.clone(),
        vec![long_column],
        None,
    ));
    let source_outer_fields = Fields::from(vec![Field::new(
        "nested",
        DataType::Struct(source_inner_fields),
        true,
    )]);
    let input: ArrayRef = Arc::new(StructArray::new(
        source_outer_fields,
        vec![inner_struct],
        None,
    ));

    // Target type: struct<renamed_nested: struct<byte_value: Int8>>.
    let target_inner_fields = Fields::from(vec![Field::new("byte_value", DataType::Int8, true)]);
    let target_type = DataType::Struct(Fields::from(vec![Field::new(
        "renamed_nested",
        DataType::Struct(target_inner_fields),
        true,
    )]));
    let ansi_options = SparkCastOptions::new(EvalMode::Ansi, "UTC", false);

    let outcome = spark_cast(ColumnarValue::Array(input), &target_type, &ansi_options);

    // With EvalMode::Ansi the cast must surface an error rather than panic or succeed.
    assert!(outcome.is_err());
}

#[test]
fn test_cast_struct_to_struct_drop_column() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- Config: spark.sql.ansi.enabled=true

-- Fixture table for the complex-type cast queries below. Every numeric leaf is
-- stored as bigint (a top-level struct field, a field of a nested struct, and
-- array elements) so the queries can cast each of them down to tinyint.
statement
CREATE TABLE test_cast_complex_ansi(
id int,
struct_value struct<
long_value:bigint,
string_value:string,
nested_value:struct<inner_long:bigint>>,
array_value array<bigint>
) USING parquet

-- Fixture rows, one scenario per id:
--   id 1: every bigint value is within the tinyint range (array also holds a null element)
--   id 2: struct_value.long_value = 128, above the tinyint maximum of 127
--   id 3: struct_value.nested_value.inner_long = -129, below the tinyint minimum of -128
--   id 4: array_value contains 128, above the tinyint maximum of 127
--   id 5: struct_value and array_value are both null
statement
INSERT INTO test_cast_complex_ansi VALUES
(
1,
named_struct(
'long_value', cast(1 as bigint),
'string_value', 'fits',
'nested_value', named_struct('inner_long', cast(10 as bigint))),
array(cast(1 as bigint), cast(127 as bigint), cast(null as bigint))
),
(
2,
named_struct(
'long_value', cast(128 as bigint),
'string_value', 'too-large',
'nested_value', named_struct('inner_long', cast(10 as bigint))),
array(cast(1 as bigint))
),
(
3,
named_struct(
'long_value', cast(2 as bigint),
'string_value', 'nested-too-small',
'nested_value', named_struct('inner_long', cast(-129 as bigint))),
array(cast(2 as bigint))
),
(
4,
named_struct(
'long_value', cast(3 as bigint),
'string_value', 'array-too-large',
'nested_value', named_struct('inner_long', cast(4 as bigint))),
array(cast(128 as bigint))
),
(
5,
cast(null as struct<
long_value:bigint,
string_value:string,
nested_value:struct<inner_long:bigint>>),
cast(null as array<bigint>)
)

-- valid complex casts should run natively under ANSI mode
-- (id 1 holds only values that fit in tinyint; id 5 holds a null struct and a
-- null array, so neither selected row contains an out-of-range value)
query
SELECT
cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>),
cast(array_value as array<tinyint>),
id
FROM test_cast_complex_ansi
WHERE id IN (1, 5)
ORDER BY id

-- overflow in a struct field should propagate as a cast error
-- (id 2 stores long_value = 128, which is outside the tinyint range)
query expect_error(CAST_OVERFLOW)
SELECT cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>)
FROM test_cast_complex_ansi
WHERE id = 2

-- overflow in a nested struct field should propagate as a cast error
-- (id 3 stores nested_value.inner_long = -129, which is outside the tinyint range,
-- while its top-level long_value = 2 fits)
query expect_error(CAST_OVERFLOW)
SELECT cast(struct_value as
struct<byte_value:tinyint,text:string,nested_value:struct<inner_byte:tinyint>>)
FROM test_cast_complex_ansi
WHERE id = 3

-- overflow in an array element should propagate as a cast error
-- (id 4 stores the single array element 128, which is outside the tinyint range)
query expect_error(CAST_OVERFLOW)
SELECT cast(array_value as array<tinyint>)
FROM test_cast_complex_ansi
WHERE id = 4
82 changes: 77 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package org.apache.comet
import java.io.File

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.util.Random

import org.apache.hadoop.fs.Path
Expand Down Expand Up @@ -1465,6 +1466,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

val nestedType =
StructType(Seq(StructField("long_value", LongType), StructField("bool_value", BooleanType)))
val structType = StructType(
Seq(
StructField("int_value", IntegerType),
StructField("string_value", StringType),
StructField("nested_value", nestedType)))
val schema = StructType(Seq(StructField("a", structType)))
val rows = Seq(
Row(Row(1, "one", Row(10L, true))),
Row(Row(null, "missing-int", Row(-2L, false))),
Row(Row(3, null, null)),
Row(null))

castTest(spark.createDataFrame(rows.asJava, schema), StringType)
}

test("cast StructType to StructType") {
Expand All @@ -1479,6 +1496,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

val fromNestedType = StructType(Seq(StructField("inner_int", IntegerType)))
val fromType = StructType(
Seq(
StructField("long_value", LongType),
StructField("string_value", StringType),
StructField("nested_value", fromNestedType)))
val toNestedType = StructType(Seq(StructField("renamed_inner_long", LongType)))
val toType = StructType(
Seq(
StructField("renamed_byte", ByteType),
StructField("renamed_string", StringType),
StructField("renamed_nested", toNestedType)))
val schema = StructType(Seq(StructField("a", fromType)))
val rows = Seq(
Row(Row(1L, "one", Row(10))),
Row(Row(127L, null, Row(-20))),
Row(Row(null, "missing-long", null)),
Row(null))

castTest(spark.createDataFrame(rows.asJava, schema), toType)

val overflowFromType = StructType(
Seq(StructField("long_value", LongType), StructField("string_value", StringType)))
val overflowToType = StructType(
Seq(StructField("renamed_byte", ByteType), StructField("renamed_string", StringType)))
val overflowSchema = StructType(Seq(StructField("a", overflowFromType)))
val overflowRows = Seq(
Row(Row(1L, "fits")),
Row(Row(128L, "too-large")),
Row(Row(-129L, "too-small")),
Row(Row(null, "missing-long")),
Row(null))

castTest(
spark.createDataFrame(overflowRows.asJava, overflowSchema),
overflowToType,
expectAnsiFailure = true)
}

test("cast StructType to StructType with different names") {
Expand Down Expand Up @@ -1564,8 +1619,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast ArrayType to StringType - float double binary edge cases") {
import scala.jdk.CollectionConverters._

def bytes(values: Int*): Array[Byte] = values.map(_.toByte).toArray

def arrayInput(elementType: DataType, values: Seq[Any]): DataFrame = {
Expand Down Expand Up @@ -1630,6 +1683,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
DataTypes.TimestampNTZType,
BinaryType)
testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _))

val schema = StructType(Seq(StructField("a", ArrayType(LongType))))
val rows = Seq(
Row(Seq[Any](1L, 127L, null)),
Row(Seq[Any](128L)),
Row(Seq[Any](-129L, 0L)),
Row(Seq.empty[Any]),
Row(null))

castTest(
spark.createDataFrame(rows.asJava, schema),
ArrayType(ByteType),
expectAnsiFailure = true)
}

test("cast MapType to MapType") {
Expand All @@ -1639,7 +1705,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// the planner routes Map→Map casts into it. The map column must be read
// natively for the cast to be exercised by Comet, which only happens
// under the V1 Parquet scan, so we pin USE_V1_SOURCE_LIST=parquet.
import scala.collection.JavaConverters._
val schema =
StructType(Seq(StructField("a", MapType(IntegerType, IntegerType), nullable = true)))
val rows = Range(0, 100).map { i =>
Expand Down Expand Up @@ -1837,7 +1902,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
def buildRows(values: Seq[Any]): Seq[Row] = {
Range(0, rowNum).map { i =>
Expand Down Expand Up @@ -1899,7 +1963,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateNestedArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(ArrayType(elementType)), true)))
val innerArrays = generateArrays(rowNum, elementType)
.collect()
Expand Down Expand Up @@ -2214,6 +2277,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
hasIncompatibleType: Boolean = false,
testAnsi: Boolean = true,
testTry: Boolean = true,
expectAnsiFailure: Boolean = false,
useDataFrameDiff: Boolean = false): Unit = {

withTempPath { dir =>
Expand Down Expand Up @@ -2261,11 +2325,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
.select(col("__row_id"), col("a"), col("a").cast(toType).as("converted"))
.orderBy(col("__row_id"))
.drop("__row_id")
if (expectAnsiFailure) {
assert(!hasIncompatibleType, "Expected ANSI failures must use Comet native execution")
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
}
val res = if (useDataFrameDiff) {
assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
} else {
checkSparkAnswerMaybeThrows(df)
}
if (expectAnsiFailure) {
assert(res._1.isDefined, "Expected Spark ANSI cast to fail")
assert(res._2.isDefined, "Expected Comet ANSI cast to fail")
}
res match {
case (None, None) =>
// neither system threw an exception
Expand Down
Loading