Patch summary (free text ahead of the first "diff --git" header; not part of any hunk):
  * date_diff.rs: compute the Date32 day difference with wrapping_sub and add unit tests,
    including an overflow case.
  * datetime.scala: update the CometUnixTimestamp unsupported-reasons string and a
    description string in CometDateFormat.

diff --git a/native/spark-expr/src/datetime_funcs/date_diff.rs b/native/spark-expr/src/datetime_funcs/date_diff.rs
index ca148c103a..0fddbc9331 100644
--- a/native/spark-expr/src/datetime_funcs/date_diff.rs
+++ b/native/spark-expr/src/datetime_funcs/date_diff.rs
@@ -100,9 +100,12 @@ impl ScalarUDFImpl for SparkDateDiff {
                 )
             })?;
 
-        // Date32 stores days since epoch, so difference is just subtraction
-        let result: Int32Array =
-            binary(end_date_array, start_date_array, |end, start| end - start)?;
+        // Date32 stores days since epoch, so difference is just subtraction. Use wrapping_sub
+        // to match Spark, whose JVM int subtraction wraps on overflow; a plain `i32 -` would
+        // panic in debug builds on extreme inputs.
+        let result: Int32Array = binary(end_date_array, start_date_array, |end, start| {
+            end.wrapping_sub(start)
+        })?;
 
         Ok(ColumnarValue::Array(Arc::new(result)))
     }
@@ -111,3 +114,53 @@ impl ScalarUDFImpl for SparkDateDiff {
         &self.aliases
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::Field;
+    use datafusion::config::ConfigOptions;
+
+    fn date_diff(end: i32, start: i32) -> i32 {
+        let udf = SparkDateDiff::new();
+        let return_field = Arc::new(Field::new("date_diff", DataType::Int32, true));
+        let args = ScalarFunctionArgs {
+            args: vec![
+                ColumnarValue::Array(Arc::new(Date32Array::from(vec![Some(end)]))),
+                ColumnarValue::Array(Arc::new(Date32Array::from(vec![Some(start)]))),
+            ],
+            number_rows: 1,
+            return_field,
+            config_options: Arc::new(ConfigOptions::default()),
+            arg_fields: vec![],
+        };
+        match udf.invoke_with_args(args).unwrap() {
+            ColumnarValue::Array(array) => array
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .value(0),
+            _ => panic!("expected array result"),
+        }
+    }
+
+    #[test]
+    fn test_date_diff_basic() {
+        // 2020-01-02 (18263) minus 2020-01-01 (18262) = 1 day
+        assert_eq!(date_diff(18263, 18262), 1);
+        assert_eq!(date_diff(18262, 18263), -1);
+    }
+
+    #[test]
+    fn test_date_diff_wraps_on_overflow() {
+        // Extreme inputs overflow i32; Spark's JVM int subtraction wraps rather than panicking.
+        assert_eq!(
+            date_diff(i32::MAX, i32::MIN),
+            i32::MAX.wrapping_sub(i32::MIN)
+        );
+        assert_eq!(
+            date_diff(i32::MIN, i32::MAX),
+            i32::MIN.wrapping_sub(i32::MAX)
+        );
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 2ce75ccc0d..a2dad7fe4b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -292,9 +292,7 @@ object CometSecond extends CometExpressionSerde[Second] with CodegenDispatchFall
 object CometUnixTimestamp extends CometExpressionSerde[UnixTimestamp] {
 
   override def getUnsupportedReasons(): Seq[String] = Seq(
-    "Only `TimestampType` and `DateType` inputs are supported." +
-      " `TimestampNTZType` is not supported because Comet incorrectly applies timezone" +
-      " conversion to TimestampNTZ values.")
+    "Only `DateType`, `TimestampType`, and `TimestampNTZType` inputs are supported.")
 
   private def isSupportedInputType(expr: UnixTimestamp): Boolean = {
     expr.children.head.dataType match {
@@ -695,9 +693,8 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
     "Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " +
       "sessions. Other format strings (including non-literal formats), as well as non-UTC " +
       "sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " +
-      "codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " +
-      "codegen dispatcher is disabled (default) the operator falls back to Spark in those " +
-      "cases.")
+      "codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true` (the default). " +
+      "When the codegen dispatcher is disabled the operator falls back to Spark in those cases.")
 
   override def convert(
       expr: DateFormatClass,