From 5de6ac1fefec4f65f76974ce1734da60fa131f4f Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Tue, 7 Apr 2026 23:22:16 +0800 Subject: [PATCH 1/2] fix: PostgreSQL dialect can not support tinyint type --- datafusion/sql/src/unparser/dialect.rs | 25 +++++++++++++++++++++++ datafusion/sql/src/unparser/expr.rs | 2 +- datafusion/sql/tests/cases/plan_to_sql.rs | 11 ++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 9e0683bdd7b20..ee31190b68b98 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -100,6 +100,12 @@ pub trait Dialect: Send + Sync { ast::DataType::BigInt(None) } + /// The SQL type to use for Arrow Int8 unparsing + /// Most dialects use TinyInt, but PostgreSQL prefers SmallInt + fn int8_cast_dtype(&self) -> ast::DataType { + ast::DataType::TinyInt(None) + } + /// The SQL type to use for Arrow Int32 unparsing /// Most dialects use Integer, but some, like MySQL, require SIGNED fn int32_cast_dtype(&self) -> ast::DataType { @@ -345,6 +351,10 @@ impl Dialect for PostgreSqlDialect { ast::DataType::DoublePrecision } + fn int8_cast_dtype(&self) -> ast::DataType { + ast::DataType::SmallInt(None) + } + fn scalar_function_to_sql_overrides( &self, unparser: &Unparser, @@ -664,6 +674,7 @@ pub struct CustomDialect { large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, character_length_style: CharacterLengthStyle, + int8_cast_dtype: ast::DataType, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -689,6 +700,7 @@ impl Default for CustomDialect { large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, character_length_style: CharacterLengthStyle::CharacterLength, + int8_cast_dtype: ast::DataType::TinyInt(None), int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -748,6 +760,10 @@ impl Dialect for CustomDialect { self.int64_cast_dtype.clone() } + fn int8_cast_dtype(&self) -> ast::DataType { + self.int8_cast_dtype.clone() + } + fn int32_cast_dtype(&self) -> ast::DataType { self.int32_cast_dtype.clone() } @@ -839,6 +855,7 @@ pub struct CustomDialectBuilder { large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, character_length_style: CharacterLengthStyle, + int8_cast_dtype: ast::DataType, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -870,6 +887,7 @@ impl CustomDialectBuilder { large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, character_length_style: CharacterLengthStyle::CharacterLength, + int8_cast_dtype: ast::DataType::TinyInt(None), int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -898,6 +916,7 @@ impl CustomDialectBuilder { large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, character_length_style: self.character_length_style, + int8_cast_dtype: self.int8_cast_dtype, int64_cast_dtype: self.int64_cast_dtype, int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, @@ -952,6 +971,12 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific SQL type for Int8 casting: TinyInt, SmallInt, etc. + pub fn with_int8_cast_dtype(mut self, int8_cast_dtype: ast::DataType) -> Self { + self.int8_cast_dtype = int8_cast_dtype; + self + } + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { self.float64_ast_dtype = float64_ast_dtype; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3601febe744c9..9c6d9d97a5eff 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1734,7 +1734,7 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Boolean => Ok(ast::DataType::Bool), - DataType::Int8 => Ok(ast::DataType::TinyInt(None)), + DataType::Int8 => Ok(self.dialect.int8_cast_dtype()), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), DataType::Int32 => Ok(self.dialect.int32_cast_dtype()), DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 0dad48b168976..130e46161b996 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1937,6 +1937,17 @@ fn test_without_offset() { ) } +#[test] +fn test_cast_to_tinyint() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select cast(3 as tinyint)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserPostgreSqlDialect {}, + expected: @"SELECT CAST(3 AS SMALLINT)", + ); + Ok(()) +} + #[test] fn test_with_offset0() { let statement = generate_round_trip_statement(MySqlDialect {}, "select 1 offset 0"); From 7e979820337a543ee9d40d8c816bdc756d18a0fe Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Wed, 8 Apr 2026 21:46:32 +0800 Subject: [PATCH 2/2] add case --- datafusion/sql/tests/cases/plan_to_sql.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 130e46161b996..d4f7b32a8de09 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1948,6 +1948,17 @@ fn test_cast_to_tinyint() -> Result<(), DataFusionError> { Ok(()) } +#[test] +fn test_cast_to_tinyint_default_dialect() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select cast(3 as tinyint)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @"SELECT CAST(3 AS TINYINT)", + ); + Ok(()) +} + #[test] fn test_with_offset0() { let statement = generate_round_trip_statement(MySqlDialect {}, "select 1 offset 0");