diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java index 9cfb41f9b..8d840b4cb 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScan.java @@ -366,17 +366,21 @@ private List pruneByZonemapStats(List allSplits) { */ @Override public Partitioning outputPartitioning() { - if (partitionInfo != null) { - // Use partition info fragment count — available before - // planInputPartitions() is called. This allows - // V2ScanPartitioningAndOrdering to see the partitioning - // early enough for SPJ. - int partCount = - numPartitions >= 0 ? numPartitions : partitionInfo.getFragmentPartitionValues().size(); - Expression[] keys = new Expression[] {FieldReference.apply(partitionInfo.getColumnName())}; - return new KeyGroupedPartitioning(keys, partCount); + if (partitionInfo == null + || partitionInfo.isSoftCapped() + || !LanceScanBuilder.readReportingEnabledConf()) { + return new UnknownPartitioning(numPartitions >= 0 ? numPartitions : 0); } - return new UnknownPartitioning(numPartitions >= 0 ? numPartitions : 0); + List colNames = partitionInfo.getColumnNames(); + if (colNames.size() > 1 && !SparkVersionUtil.supportsMultiKeySpj()) { + return new UnknownPartitioning(numPartitions >= 0 ? numPartitions : 0); + } + int partCount = numPartitions >= 0 ? numPartitions : partitionInfo.size(); + Expression[] keys = new Expression[colNames.size()]; + for (int i = 0; i < colNames.size(); i++) { + keys[i] = FieldReference.apply(colNames.get(i)); + } + return new KeyGroupedPartitioning(keys, partCount); } @Override diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java index 1178bfa36..ac43e7d66 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/LanceScanBuilder.java @@ -92,6 +92,11 @@ public class LanceScanBuilder private final java.util.Map tableProperties; + static final String CONF_REPORTING_ENABLED = "spark.lance.partition.reporting.enabled"; + static final String CONF_REPORTING_MAX_PARTITIONS = + "spark.lance.partition.reporting.maxPartitions"; + static final int DEFAULT_REPORTING_MAX_PARTITIONS = 10_000; + public LanceScanBuilder( StructType schema, LanceSparkReadOptions readOptions, @@ -135,100 +140,98 @@ public Scan build() { return localScan; } - // Get statistics from manifest summary before closing dataset - ManifestSummary summary = getOrOpenDataset().getVersion().getManifestSummary(); - - // Collect all columns that need zonemap stats: filter columns + partition column (if declared). - Set columnsToLoad = extractReferencedColumns(pushedFilters); - String partitionColumn = tableProperties.get(LanceConstant.TABLE_OPT_PARTITION_COLUMNS); - if (partitionColumn != null && !partitionColumn.trim().isEmpty()) { - partitionColumn = partitionColumn.trim(); - columnsToLoad.add(partitionColumn); - } else { - partitionColumn = null; - } - - // Load zonemap stats for all requested columns in one pass. - Map> zonemapStats = loadZonemapStats(getOrOpenDataset(), columnsToLoad); + try { + // Get statistics from manifest summary before closing dataset + ManifestSummary summary = getOrOpenDataset().getVersion().getManifestSummary(); + + // Parse and validate partition columns from TBLPROPERTIES. Nested paths and non-whitelisted + // types are rejected here with a WARN; detection falls through to null PartitionInfo. + List partitionColumns = + parsePartitionColumns(tableProperties.get(LanceConstant.TABLE_OPT_PARTITION_COLUMNS)); + + // Collect all columns that need zonemap stats: filter columns + declared partition columns. + Set columnsToLoad = extractReferencedColumns(pushedFilters); + columnsToLoad.addAll(partitionColumns); + + // Load zonemap stats for all requested columns in one pass. + Map> zonemapStats = + loadZonemapStats(getOrOpenDataset(), columnsToLoad); + + // Reject-all policy: if any declared column fails detection (missing stats, non-constant + // values, coverage mismatch), the whole scan falls back to UnknownPartitioning so SPJ + // symmetry with the joined counterpart is preserved. + ZonemapFragmentPruner.PartitionInfo partitionInfo = + detectPartitioning(partitionColumns, zonemapStats); + + // Pre-compute fragment pruning so we can (a) estimate post-pruning statistics for + // JoinSelection (BroadcastHashJoin vs SortMergeJoin) and (b) pass the cached result + // to LanceScan to avoid re-computing during planInputPartitions(). + Set survivingFragmentIds = null; + if (pushedFilters.length > 0 && !zonemapStats.isEmpty()) { + survivingFragmentIds = + ZonemapFragmentPruner.pruneFragments(pushedFilters, zonemapStats).orElse(null); + } - // Detect partition-compatible columns, gated on lance.partition.columns table property. - // Currently a partitioned column is only valid if each fragment contains only a single - // value for that column (i.e., all zonemap zones have min == max with the same value). - ZonemapFragmentPruner.PartitionInfo partitionInfo = null; - if (partitionColumn != null) { - if (!zonemapStats.containsKey(partitionColumn)) { - LOG.warn( - "Partition column '{}' declared in {} has no zonemap index or stats;" - + " partition detection disabled", - partitionColumn, - LanceConstant.TABLE_OPT_PARTITION_COLUMNS); - } else { - Map> partValues = - ZonemapFragmentPruner.computeFragmentPartitionValues(zonemapStats.get(partitionColumn)) - .orElse(null); - if (partValues != null) { - partitionInfo = new ZonemapFragmentPruner.PartitionInfo(partitionColumn, partValues); - LOG.info( - "Detected partition-compatible column '{}' with {} fragments", - partitionColumn, - partValues.size()); + // Filter pushdown may have narrowed the surviving fragment set; restrict PartitionInfo so + // the partition count reported via SPJ matches the post-pushdown size. restrictTo clears + // the softCapped flag (cap is size-dependent) — re-apply if the restricted size still + // exceeds the cap. + if (partitionInfo != null && survivingFragmentIds != null) { + partitionInfo = partitionInfo.restrictTo(survivingFragmentIds); + if (partitionInfo.size() == 0) { + partitionInfo = null; + } else if (partitionInfo.size() > readMaxReportedPartitionsConf()) { + partitionInfo = partitionInfo.withSoftCapped(); } } - } - // Pre-compute fragment pruning so we can (a) estimate post-pruning statistics for - // JoinSelection (BroadcastHashJoin vs SortMergeJoin) and (b) pass the cached result - // to LanceScan to avoid re-computing during planInputPartitions(). - Set survivingFragmentIds = null; - if (pushedFilters.length > 0 && !zonemapStats.isEmpty()) { - survivingFragmentIds = - ZonemapFragmentPruner.pruneFragments(pushedFilters, zonemapStats).orElse(null); - } - - // Scale rows and full size by the zonemap fragment-pruning ratio first, then let - // LanceStatistics.estimateProjected apply the column-width ratio on top - // (when the projected schema is narrower than the full schema). - long projectedRows = summary.getTotalRows(); - long projectedFullSize = summary.getTotalFilesSize(); - if (survivingFragmentIds != null && summary.getTotalFragments() > 0) { - double ratio = (double) survivingFragmentIds.size() / summary.getTotalFragments(); - projectedRows = (long) (projectedRows * ratio); - projectedFullSize = (long) (projectedFullSize * ratio); - } - LanceStatistics statistics = - LanceStatistics.estimateProjected(projectedRows, projectedFullSize, fullSchema, schema); - if (survivingFragmentIds != null) { - LOG.debug( - "Scan statistics after pruning: {} of {} fragments survive," - + " estimatedSize={}, estimatedRows={} (full: size={}, rows={})", - survivingFragmentIds.size(), - summary.getTotalFragments(), - statistics.sizeInBytes(), - statistics.numRows(), - summary.getTotalFilesSize(), - summary.getTotalRows()); - } - - // Close the lazily opened dataset - it's no longer needed after build - closeLazyDataset(); - - Optional whereCondition = FilterPushDown.compileFiltersToSqlWhereClause(pushedFilters); - return new LanceScan( - schema, - readOptions, - whereCondition, - limit, - offset, - topNSortOrders, - pushedAggregation, - pushedFilters, - statistics, - zonemapStats, - survivingFragmentIds, - partitionInfo, - initialStorageOptions, - namespaceImpl, - namespaceProperties); + // Scale rows and full size by the zonemap fragment-pruning ratio first, then let + // LanceStatistics.estimateProjected apply the column-width ratio on top + // (when the projected schema is narrower than the full schema). + long projectedRows = summary.getTotalRows(); + long projectedFullSize = summary.getTotalFilesSize(); + if (survivingFragmentIds != null && summary.getTotalFragments() > 0) { + double ratio = (double) survivingFragmentIds.size() / summary.getTotalFragments(); + projectedRows = (long) (projectedRows * ratio); + projectedFullSize = (long) (projectedFullSize * ratio); + } + LanceStatistics statistics = + LanceStatistics.estimateProjected(projectedRows, projectedFullSize, fullSchema, schema); + if (survivingFragmentIds != null) { + LOG.debug( + "Scan statistics after pruning: {} of {} fragments survive," + + " estimatedSize={}, estimatedRows={} (full: size={}, rows={})", + survivingFragmentIds.size(), + summary.getTotalFragments(), + statistics.sizeInBytes(), + statistics.numRows(), + summary.getTotalFilesSize(), + summary.getTotalRows()); + } + + Optional whereCondition = + FilterPushDown.compileFiltersToSqlWhereClause(pushedFilters); + return new LanceScan( + schema, + readOptions, + whereCondition, + limit, + offset, + topNSortOrders, + pushedAggregation, + pushedFilters, + statistics, + zonemapStats, + survivingFragmentIds, + partitionInfo, + initialStorageOptions, + namespaceImpl, + namespaceProperties); + } finally { + // Always close the lazily opened dataset, including on exception paths, so we don't leak + // the JNI handle when parsing/detection/pruning helpers throw. + closeLazyDataset(); + } } @Override @@ -406,4 +409,215 @@ private static Set extractReferencedColumns(Filter[] filters) { } return columns; } + + /** + * Tokenizes {@code lance.partition.columns} on {@code ,}, trims, drops empties, deduplicates, + * rejects nested paths, and validates each column's Spark type against the whitelist. Returns an + * empty list if the property is absent, empty, or any column fails validation (reject-all). + */ + // Package-private so LanceScanBuilderTest can assert dedup / ordering directly. + List parsePartitionColumns(String raw) { + // Treat null, empty, whitespace-only, and pure-delimiter values (",", ", ,", ...) all as + // "property not set" — these are the no-op cases; returning quietly avoids a spurious WARN. + if (raw == null || raw.replace(",", "").trim().isEmpty()) { + return Collections.emptyList(); + } + List tokens = new ArrayList<>(); + Set seen = new HashSet<>(); + for (String part : raw.split(",")) { + String trimmed = part.trim(); + if (trimmed.isEmpty()) { + continue; + } + if (!seen.add(trimmed)) { + LOG.warn( + "{} contains duplicate column '{}' (dropped)", + LanceConstant.TABLE_OPT_PARTITION_COLUMNS, + trimmed); + continue; + } + if (trimmed.contains(".")) { + LOG.warn("partition column '{}' has nested path; nested paths not supported", trimmed); + return Collections.emptyList(); + } + if (!isSupportedPartitionType(trimmed)) { + return Collections.emptyList(); + } + tokens.add(trimmed); + } + return tokens; + } + + /** + * Looks up the column's type on the full read schema and returns true iff it is in the partition + * whitelist. Uses {@link #fullSchema} rather than {@link #schema} so column pruning does not + * remove partition columns from the lookup; returns false with a WARN if the column is missing or + * has an unsupported type. + */ + private boolean isSupportedPartitionType(String columnName) { + int idx; + try { + idx = fullSchema.fieldIndex(columnName); + } catch (IllegalArgumentException e) { + LOG.warn( + "partition column '{}' is not in the table schema; partition detection disabled", + columnName); + return false; + } + org.apache.spark.sql.types.DataType type = fullSchema.fields()[idx].dataType(); + // Whitelist types that PartitionInfo.toSparkValue can encode into Spark's InternalRow. + // ZoneStats always returns Long for integral Arrow widths (int8/16/32 too) and for + // Date (epoch-days) / Timestamp (epoch-micros); toSparkValue narrows/wraps appropriately. + // Use .equals() rather than == so a DataType materialized from a deserialized schema + // (e.g. JSON/Avro round-trip) still matches the singleton constants. + if (DataTypes.BooleanType.equals(type) + || DataTypes.ByteType.equals(type) + || DataTypes.ShortType.equals(type) + || DataTypes.IntegerType.equals(type) + || DataTypes.LongType.equals(type) + || DataTypes.StringType.equals(type) + || DataTypes.DateType.equals(type) + || DataTypes.TimestampType.equals(type)) { + return true; + } + LOG.warn( + "partition column '{}' has unsupported type {}: whitelist is" + + " Boolean/Byte/Short/Int/Long/String/Date/Timestamp", + columnName, + type.typeName()); + return false; + } + + /** + * Runs per-column zone-constancy detection, verifies that every declared column covers the same + * fragment-id set, and assembles per-fragment partition tuples in declaration order. Returns null + * when any column fails — reject-all, so SPJ symmetry is preserved on the joined counterpart. + */ + // Package-private for unit-test access to the multi-column detection logic. + ZonemapFragmentPruner.PartitionInfo detectPartitioning( + List partitionColumns, Map> zonemapStats) { + if (partitionColumns.isEmpty()) { + return null; + } + Map>> perColumnMaps = new HashMap<>(); + for (String name : partitionColumns) { + if (!zonemapStats.containsKey(name)) { + LOG.warn("partition column '{}' has no zonemap stats; partition detection disabled", name); + return null; + } + Map> values = + ZonemapFragmentPruner.computeFragmentPartitionValues(zonemapStats.get(name)).orElse(null); + if (values == null || values.isEmpty()) { + LOG.warn( + "partition column '{}' has non-constant or null values; partition detection disabled", + name); + return null; + } + perColumnMaps.put(name, values); + } + + // Require every declared partition column to cover the same fragment-id set. A strict-subset + // intersection would leave splits for uncovered fragments with a phantom null-key tuple — + // wrong input to Spark's SPJ. Iterate in declaration order so the mismatched-column WARN is + // deterministic across runs. + Set intersection = null; + for (String name : partitionColumns) { + Set columnFragments = perColumnMaps.get(name).keySet(); + if (intersection == null) { + intersection = new HashSet<>(columnFragments); + } else if (!intersection.equals(columnFragments)) { + LOG.warn( + "partition columns {} have mismatched fragment-id coverage (column '{}' differs);" + + " partition detection disabled", + partitionColumns, + name); + return null; + } + } + if (intersection == null || intersection.isEmpty()) { + LOG.warn( + "partition columns {} have no covered fragments; partition detection disabled", + partitionColumns); + return null; + } + + // Assemble tuples in declaration order, and resolve per-column Spark types from the full + // read schema so PartitionInfo can encode each value into the right InternalRow slot + // (narrowing Long -> byte/short/int for Byte/Short/Int/Date columns, pass-through otherwise). + int width = partitionColumns.size(); + Map[]> tuples = new HashMap<>(); + for (Integer fragId : intersection) { + Comparable[] tuple = new Comparable[width]; + for (int i = 0; i < width; i++) { + tuple[i] = perColumnMaps.get(partitionColumns.get(i)).get(fragId); + } + tuples.put(fragId, tuple); + } + List columnTypes = new java.util.ArrayList<>(width); + for (String name : partitionColumns) { + columnTypes.add(fullSchema.fields()[fullSchema.fieldIndex(name)].dataType()); + } + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo(partitionColumns, columnTypes, tuples); + + // Apply soft cap based on session conf (if available) or the default. When the cap fires, + // the scan will report UnknownPartitioning — log that branch separately so operators don't + // see a success-looking "detected N fragments" INFO immediately after the soft-cap WARN. + int cap = readMaxReportedPartitionsConf(); + if (info.size() > cap) { + LOG.warn( + "partition count {} exceeds {}={}; reporting UnknownPartitioning", + info.size(), + CONF_REPORTING_MAX_PARTITIONS, + cap); + return info.withSoftCapped(); + } + LOG.info( + "lance.partition.detect cols={} columnCount={} fragments={}", + partitionColumns, + partitionColumns.size(), + info.size()); + return info; + } + + private static int readMaxReportedPartitionsConf() { + String val = null; + try { + org.apache.spark.sql.SparkSession session = org.apache.spark.sql.SparkSession.active(); + val = session.conf().get(CONF_REPORTING_MAX_PARTITIONS, null); + } catch (Exception e) { + // No active SparkSession (e.g. unit-test / offline builder usage); log at DEBUG so real + // session-level misconfiguration is diagnosable, and fall through to the default. + LOG.debug( + "Could not read {}: {}; using default", CONF_REPORTING_MAX_PARTITIONS, e.toString()); + return DEFAULT_REPORTING_MAX_PARTITIONS; + } + if (val == null) { + return DEFAULT_REPORTING_MAX_PARTITIONS; + } + try { + return Integer.parseInt(val.trim()); + } catch (NumberFormatException e) { + LOG.warn( + "Could not parse {}='{}' as an integer; using default {}", + CONF_REPORTING_MAX_PARTITIONS, + val, + DEFAULT_REPORTING_MAX_PARTITIONS); + return DEFAULT_REPORTING_MAX_PARTITIONS; + } + } + + static boolean readReportingEnabledConf() { + try { + org.apache.spark.sql.SparkSession session = org.apache.spark.sql.SparkSession.active(); + String val = session.conf().get(CONF_REPORTING_ENABLED, null); + if (val != null) { + return !"false".equalsIgnoreCase(val.trim()); + } + } catch (Exception e) { + // No active SparkSession; log at DEBUG and default to enabled. + LOG.debug("Could not read {}: {}; defaulting to true", CONF_REPORTING_ENABLED, e.toString()); + } + return true; + } } diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/SparkVersionUtil.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/SparkVersionUtil.java new file mode 100644 index 000000000..16cbc4bfe --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/SparkVersionUtil.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read; + +import org.apache.spark.package$; + +/** + * Runtime Spark-version gates for features whose behavior differs across supported Spark versions. + * + *

Kept in a single helper so pre-merge findings on multi-key SPJ (Spark 3.4 vs 3.5+) can be + * adjusted in one place rather than scattered across the codebase. + */ +final class SparkVersionUtil { + + private SparkVersionUtil() {} + + /** + * Whether the running Spark version reliably honors multi-key {@code KeyGroupedPartitioning} + * without silently falling back to a shuffle. Uses an explicit allowlist (Spark 3.5.x, 4.x+) + * rather than a denylist so custom forks and unexpected version strings default to the safe "fall + * back to UnknownPartitioning" behavior. + */ + static boolean supportsMultiKeySpj() { + return supportsMultiKeySpj(package$.MODULE$.SPARK_VERSION()); + } + + /** + * Package-private pure-function overload for unit tests. Inputs the version string directly so + * the parse/allowlist logic can be exercised without a running SparkContext. + */ + static boolean supportsMultiKeySpj(String version) { + if (version == null) { + return false; + } + // Accept "3.5.x" or any 4.x+ build. Reject everything else (3.4.x, 3.3.x, custom strings). + if (version.startsWith("3.5.")) { + return true; + } + int dot = version.indexOf('.'); + if (dot <= 0) { + return false; + } + try { + int major = Integer.parseInt(version.substring(0, dot)); + return major >= 4; + } catch (NumberFormatException e) { + return false; + } + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/ZonemapFragmentPruner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/ZonemapFragmentPruner.java index 06b5a6deb..0ad67b3a2 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/read/ZonemapFragmentPruner.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/read/ZonemapFragmentPruner.java @@ -29,6 +29,8 @@ import org.apache.spark.sql.sources.LessThanOrEqual; import org.apache.spark.sql.sources.Not; import org.apache.spark.sql.sources.Or; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.unsafe.types.UTF8String; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -343,47 +345,174 @@ private enum ComparisonType { } /** - * Result of partition detection: the partition column name and a map from fragment ID to the - * partition value for that fragment. + * Result of partition detection: the ordered list of partition column names, a parallel list of + * Spark {@link DataType}s (used to encode each fragment's tuple into an {@link InternalRow}), and + * a map from fragment ID to the partition tuple (one value per declared column, in declaration + * order). + * + *

Width invariants (enforced by the constructor): {@code columnTypes.size()} equals {@code + * columnNames.size()}; every tuple has length {@code columnNames.size()}; column names are + * distinct and non-empty. */ public static final class PartitionInfo implements Serializable { - private static final long serialVersionUID = 1L; + // Bumped from the single-column shape (1L on upstream main): adds columnTypes and uses + // tuple storage (Map[]>) for multi-column type-aware encoding. + private static final long serialVersionUID = 2L; + + private final List columnNames; + private final List columnTypes; + private final Map[]> fragmentPartitionKeys; + private final boolean softCapped; + + public PartitionInfo( + List columnNames, + List columnTypes, + Map[]> fragmentPartitionKeys) { + this(columnNames, columnTypes, fragmentPartitionKeys, false); + } - private final String columnName; - private final Map> fragmentPartitionValues; + public PartitionInfo( + List columnNames, + List columnTypes, + Map[]> fragmentPartitionKeys, + boolean softCapped) { + if (columnNames == null || columnNames.isEmpty()) { + throw new IllegalArgumentException("columnNames must be non-empty"); + } + if (columnTypes == null || columnTypes.size() != columnNames.size()) { + throw new IllegalArgumentException("columnTypes must have the same size as columnNames"); + } + if (new HashSet<>(columnNames).size() != columnNames.size()) { + throw new IllegalArgumentException("columnNames must be distinct: " + columnNames); + } + int width = columnNames.size(); + Map[]> copy = new HashMap<>(); + for (Map.Entry[]> e : fragmentPartitionKeys.entrySet()) { + Comparable[] tuple = e.getValue(); + if (tuple == null || tuple.length != width) { + throw new IllegalArgumentException( + "tuple for fragment " + e.getKey() + " must have length " + width); + } + copy.put(e.getKey(), tuple.clone()); + } + this.columnNames = Collections.unmodifiableList(new java.util.ArrayList<>(columnNames)); + this.columnTypes = Collections.unmodifiableList(new java.util.ArrayList<>(columnTypes)); + this.fragmentPartitionKeys = Collections.unmodifiableMap(copy); + this.softCapped = softCapped; + } - public PartitionInfo(String columnName, Map> fragmentPartitionValues) { - this.columnName = columnName; - this.fragmentPartitionValues = Collections.unmodifiableMap(fragmentPartitionValues); + /** + * Factory for the single-column case. Wraps each scalar partition value into a length-1 tuple + * and delegates to the list-form constructor. + */ + public static PartitionInfo forSingleColumn( + String columnName, DataType columnType, Map> valueByFragment) { + Map[]> tupleMap = new HashMap<>(); + for (Map.Entry> e : valueByFragment.entrySet()) { + tupleMap.put(e.getKey(), new Comparable[] {e.getValue()}); + } + return new PartitionInfo( + Collections.singletonList(columnName), Collections.singletonList(columnType), tupleMap); } - public String getColumnName() { - return columnName; + public List getColumnNames() { + return columnNames; } - public Map> getFragmentPartitionValues() { - return fragmentPartitionValues; + public List getColumnTypes() { + return columnTypes; } /** - * Returns a partition key {@link InternalRow} for the given fragment ID. The row contains a - * single column with the partition value, converted to a Spark-compatible type. + * Returns the fragment-id → tuple map as an unmodifiable snapshot; each tuple array is + * defensively cloned on every call so mutating the returned arrays cannot corrupt internal + * state. Prefer {@link #partitionKeyForFragment(int)} for hot paths — this getter exists for + * inspection, equality checks, and serialization round-trip tests. + */ + public Map[]> getFragmentPartitionKeys() { + Map[]> snapshot = new HashMap<>(fragmentPartitionKeys.size()); + for (Map.Entry[]> e : fragmentPartitionKeys.entrySet()) { + snapshot.put(e.getKey(), e.getValue().clone()); + } + return Collections.unmodifiableMap(snapshot); + } + + public int size() { + return fragmentPartitionKeys.size(); + } + + public boolean isSoftCapped() { + return softCapped; + } + + /** + * Returns a new PartitionInfo restricted to the given fragment-id set. Preserves column order + * and tuple shape. The {@code softCapped} flag is NOT carried over because the cap decision is + * a function of size; if the restricted size still exceeds the cap, the caller must re-apply it + * via {@link #withSoftCapped()}. Used after filter pushdown narrows the surviving fragment set. + */ + public PartitionInfo restrictTo(Set survivingFragmentIds) { + Map[]> restricted = new HashMap<>(); + for (Map.Entry[]> e : fragmentPartitionKeys.entrySet()) { + if (survivingFragmentIds.contains(e.getKey())) { + restricted.put(e.getKey(), e.getValue()); + } + } + return new PartitionInfo(columnNames, columnTypes, restricted, false); + } + + /** Marks this info as soft-capped, returning a new instance (immutability preserved). */ + public PartitionInfo withSoftCapped() { + return new PartitionInfo(columnNames, columnTypes, fragmentPartitionKeys, true); + } + + /** + * Returns a partition key {@link InternalRow} for the given fragment ID. The row contains one + * or more columns (in declaration order), each converted to a Spark-compatible type. */ public InternalRow partitionKeyForFragment(int fragmentId) { - Comparable value = fragmentPartitionValues.get(fragmentId); - Object sparkValue = toSparkValue(value); - return new GenericInternalRow(new Object[] {sparkValue}); + Comparable[] tuple = fragmentPartitionKeys.get(fragmentId); + int width = columnNames.size(); + Object[] out = new Object[width]; + if (tuple == null) { + return new GenericInternalRow(out); + } + for (int i = 0; i < width; i++) { + out[i] = toSparkValue(tuple[i], columnTypes.get(i)); + } + return new GenericInternalRow(out); } - private static Object toSparkValue(Comparable value) { + /** + * Converts a ZoneStats value to the exact Java class Spark's {@link InternalRow} expects for + * the target slot. ZoneStats returns {@code Long} for every integral Arrow width (int8/16/32 + * included) and for Date (epoch-days) / Timestamp (epoch-micros); those need explicit narrowing + * to match {@code getByte}/{@code getShort}/{@code getInt} accessors. Boolean is already typed; + * Strings are wrapped in {@link UTF8String}. + */ + private static Object toSparkValue(Comparable value, DataType type) { if (value == null) { return null; } - if (value instanceof String) { + if (DataTypes.BooleanType.equals(type)) { + return value; + } + if (DataTypes.ByteType.equals(type)) { + return ((Number) value).byteValue(); + } + if (DataTypes.ShortType.equals(type)) { + return ((Number) value).shortValue(); + } + if (DataTypes.IntegerType.equals(type) || DataTypes.DateType.equals(type)) { + return ((Number) value).intValue(); + } + if (DataTypes.LongType.equals(type) || DataTypes.TimestampType.equals(type)) { + return ((Number) value).longValue(); + } + if (DataTypes.StringType.equals(type)) { return UTF8String.fromString((String) value); } - // Long, Double, Boolean, Integer are already compatible - return value; + throw new IllegalArgumentException("Unsupported partition column type: " + type); } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanBuilderTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanBuilderTest.java index 51addf596..8aa7d8165 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanBuilderTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanBuilderTest.java @@ -13,6 +13,7 @@ */ package org.lance.spark.read; +import org.lance.index.scalar.ZoneStats; import org.lance.spark.LanceSparkReadOptions; import org.lance.spark.TestUtils; @@ -37,7 +38,11 @@ import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Test; +import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @@ -327,4 +332,149 @@ public NullOrdering nullOrdering() { return nullOrdering; } } + + // --- lance.partition.columns parsing guards --- + + private LanceScanBuilder builderWithPartitionColumns(String value) { + return new LanceScanBuilder( + TEST_SCHEMA, + TestUtils.TestTable1Config.readOptions, + Collections.emptyMap(), + null, + Collections.emptyMap(), + Collections.singletonMap("lance.partition.columns", value)); + } + + @Test + public void testPartitionColumnsUnknownColumnFallsBackCleanly() { + // Unknown column must not throw IllegalArgumentException; the builder logs a WARN and + // falls back to a scan that reports UnknownPartitioning. + Scan scan = builderWithPartitionColumns("nonexistent_column").build(); + assertInstanceOf(LanceScan.class, scan); + LanceScan ls = (LanceScan) scan; + ls.planInputPartitions(); + assertInstanceOf( + org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning.class, + ls.outputPartitioning()); + } + + @Test + public void testPartitionColumnsNestedPathFallsBackCleanly() { + // Nested field paths are not supported; builder rejects the property, scan reports Unknown. + Scan scan = builderWithPartitionColumns("outer.inner").build(); + assertInstanceOf(LanceScan.class, scan); + LanceScan ls = (LanceScan) scan; + ls.planInputPartitions(); + assertInstanceOf( + org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning.class, + ls.outputPartitioning()); + } + + @Test + public void testPartitionColumnsWhitespaceOnlyIsAbsent() { + // Whitespace-only property must be treated exactly like an absent property: no WARN about + // empty tokenization for the common "users didn't set it" path. + Scan scan = builderWithPartitionColumns(" ").build(); + assertInstanceOf(LanceScan.class, scan); + } + + @Test + public void testPartitionColumnsEmptyStringIsAbsent() { + Scan scan = builderWithPartitionColumns("").build(); + assertInstanceOf(LanceScan.class, scan); + } + + @Test + public void testPartitionColumnsDelimitersOnlyIsAbsent() { + // Pure-delimiter input (",", ", , ,") must be treated as absent — no WARN about empty + // tokenization, since the effective user intent is "no partition columns declared". + Scan scan = builderWithPartitionColumns(",").build(); + assertInstanceOf(LanceScan.class, scan); + scan = builderWithPartitionColumns(", , ,").build(); + assertInstanceOf(LanceScan.class, scan); + } + + @Test + public void testPartitionColumnsUnsupportedTypeFallsBackCleanly() { + // A column whose Spark type is outside the whitelist (Float/Double/Decimal/complex) must + // trigger reject-all: the scan still builds, but reports UnknownPartitioning. + StructType schemaWithDouble = + new StructType( + new StructField[] { + DataTypes.createStructField("x", DataTypes.LongType, true), + DataTypes.createStructField("y", DataTypes.LongType, true), + DataTypes.createStructField("b", DataTypes.LongType, true), + DataTypes.createStructField("c", DataTypes.LongType, true), + DataTypes.createStructField("score", DataTypes.DoubleType, true), + }); + LanceScanBuilder builder = + new LanceScanBuilder( + schemaWithDouble, + TestUtils.TestTable1Config.readOptions, + Collections.emptyMap(), + null, + Collections.emptyMap(), + Collections.singletonMap("lance.partition.columns", "score")); + Scan scan = builder.build(); + assertInstanceOf(LanceScan.class, scan); + LanceScan ls = (LanceScan) scan; + ls.planInputPartitions(); + assertInstanceOf( + org.apache.spark.sql.connector.read.partitioning.UnknownPartitioning.class, + ls.outputPartitioning()); + } + + // --- detectPartitioning: identical per-column fragment coverage --- + + @Test + public void testDetectPartitioningRejectsMismatchedCoverage() { + // Column "x" covers fragments {0, 1}; column "y" covers only {0}. Strict-subset coverage + // must reject detection entirely — otherwise fragment 1 would produce a phantom null tuple + // element for column "y" (same class of bug the per-column intersection used to allow). + // Column names chosen from TEST_SCHEMA (x, y, b, c) so fullSchema.fieldIndex resolves. + LanceScanBuilder builder = createBuilder(); + Map> stats = new HashMap<>(); + stats.put( + "x", Arrays.asList(new ZoneStats(0, 0, 10, 1L, 1L, 0), new ZoneStats(1, 0, 10, 2L, 2L, 0))); + stats.put("y", Collections.singletonList(new ZoneStats(0, 0, 10, 100L, 100L, 0))); + + ZonemapFragmentPruner.PartitionInfo info = + builder.detectPartitioning(Arrays.asList("x", "y"), stats); + assertNull(info, "Detection must reject when per-column fragment coverage differs"); + } + + @Test + public void testDetectPartitioningAcceptsIdenticalCoverage() { + LanceScanBuilder builder = createBuilder(); + Map> stats = new HashMap<>(); + stats.put( + "x", Arrays.asList(new ZoneStats(0, 0, 10, 1L, 1L, 0), new ZoneStats(1, 0, 10, 2L, 2L, 0))); + stats.put( + "y", + Arrays.asList( + new ZoneStats(0, 0, 10, 100L, 100L, 0), new ZoneStats(1, 0, 10, 200L, 200L, 0))); + + ZonemapFragmentPruner.PartitionInfo info = + builder.detectPartitioning(Arrays.asList("x", "y"), stats); + assertNotNull(info); + assertEquals(Arrays.asList("x", "y"), info.getColumnNames()); + // Types resolved from TEST_SCHEMA (both x and y are LongType). + assertEquals(Arrays.asList(DataTypes.LongType, DataTypes.LongType), info.getColumnTypes()); + assertEquals(2, info.size()); + // Tuples are assembled in declaration order, fragment by fragment. + assertArrayEquals(new Object[] {1L, 100L}, info.getFragmentPartitionKeys().get(0)); + assertArrayEquals(new Object[] {2L, 200L}, info.getFragmentPartitionKeys().get(1)); + } + + // --- parsePartitionColumns: direct assertions on the token list --- + + @Test + public void testParsePartitionColumnsDedupesTrimsAndPreservesOrder() { + // "y, x , x, b" exercises all three behaviors together: whitespace trimming happens before + // dedup (so " x " collapses with "x"), duplicates after the first are dropped with a WARN, + // and the surviving tokens keep source-string order (not alphabetic). + LanceScanBuilder builder = createBuilder(); + List result = builder.parsePartitionColumns("y, x , x, b"); + assertEquals(Arrays.asList("y", "x", "b"), result); + } } diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanTest.java index 13fe3115b..18737e6cd 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/LanceScanTest.java @@ -199,7 +199,8 @@ public void testOutputPartitioningWithPartitionInfo() { fragValues.put(0, "east"); fragValues.put(1, "west"); ZonemapFragmentPruner.PartitionInfo partInfo = - new ZonemapFragmentPruner.PartitionInfo("region", fragValues); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", org.apache.spark.sql.types.DataTypes.StringType, fragValues); LanceScan scan = new LanceScan( @@ -243,6 +244,75 @@ public void testOutputPartitioningWithoutPartitionInfoIsUnknown() { assertInstanceOf(UnknownPartitioning.class, partitioning); } + private LanceScan buildScanWithPartitionInfo(ZonemapFragmentPruner.PartitionInfo info) { + return new LanceScan( + TEST_SCHEMA, + TestUtils.TestTable1Config.readOptions, + org.lance.spark.utils.Optional.empty(), + org.lance.spark.utils.Optional.empty(), + org.lance.spark.utils.Optional.empty(), + org.lance.spark.utils.Optional.empty(), + org.lance.spark.utils.Optional.empty(), + new Filter[0], + null, + Collections.emptyMap(), + null, + info, + Collections.emptyMap(), + null, + Collections.emptyMap()); + } + + /** On a Spark version that supports multi-key SPJ (3.5+), N keys are reported in order. */ + @Test + public void testOutputPartitioningMultiColumn() { + java.util.Map[]> tuples = new HashMap<>(); + tuples.put(0, new Comparable[] {"us", 2024L}); + tuples.put(1, new Comparable[] {"eu", 2025L}); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + java.util.Arrays.asList("region", "year"), + java.util.Arrays.asList( + org.apache.spark.sql.types.DataTypes.StringType, + org.apache.spark.sql.types.DataTypes.LongType), + tuples); + + LanceScan scan = buildScanWithPartitionInfo(info); + scan.planInputPartitions(); + Partitioning partitioning = scan.outputPartitioning(); + + // Gated off on 3.4 — skip the KGP assertion when the gate is closed. + if (!SparkVersionUtil.supportsMultiKeySpj()) { + assertInstanceOf(UnknownPartitioning.class, partitioning); + return; + } + + assertInstanceOf(KeyGroupedPartitioning.class, partitioning); + KeyGroupedPartitioning kgp = (KeyGroupedPartitioning) partitioning; + Expression[] keys = kgp.keys(); + assertEquals(2, keys.length); + assertEquals("region", ((FieldReference) keys[0]).fieldNames()[0]); + assertEquals("year", ((FieldReference) keys[1]).fieldNames()[0]); + } + + /** A soft-capped PartitionInfo must cause outputPartitioning to report Unknown. */ + @Test + public void testOutputPartitioningSoftCappedReturnsUnknown() { + java.util.Map[]> tuples = new HashMap<>(); + tuples.put(0, new Comparable[] {"us"}); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + java.util.Collections.singletonList("region"), + java.util.Collections.singletonList( + org.apache.spark.sql.types.DataTypes.StringType), + tuples) + .withSoftCapped(); + + LanceScan scan = buildScanWithPartitionInfo(info); + scan.planInputPartitions(); + assertInstanceOf(UnknownPartitioning.class, scan.outputPartitioning()); + } + // --- equals / hashCode (required for ReusedExchange) --- @Test diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/PartitionInfoTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/PartitionInfoTest.java new file mode 100644 index 000000000..0f45ff42c --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/PartitionInfoTest.java @@ -0,0 +1,321 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.unsafe.types.UTF8String; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for {@link ZonemapFragmentPruner.PartitionInfo} covering the multi-column refactor. + */ +public class PartitionInfoTest { + + private static Map[]> tuples(Object[]... entries) { + Map[]> out = new HashMap<>(); + for (int i = 0; i < entries.length; i++) { + Comparable[] tuple = new Comparable[entries[i].length]; + for (int j = 0; j < entries[i].length; j++) { + tuple[j] = (Comparable) entries[i][j]; + } + out.put(i, tuple); + } + return out; + } + + private static final List STRING_LONG = + Arrays.asList(DataTypes.StringType, DataTypes.LongType); + + private static final List STRING_ONLY = Collections.singletonList(DataTypes.StringType); + + @Test + public void rejectsEmptyColumnNames() { + assertThrows( + IllegalArgumentException.class, + () -> + new ZonemapFragmentPruner.PartitionInfo( + Collections.emptyList(), Collections.emptyList(), new HashMap<>())); + } + + @Test + public void rejectsDuplicateColumnNames() { + assertThrows( + IllegalArgumentException.class, + () -> + new ZonemapFragmentPruner.PartitionInfo( + Arrays.asList("a", "a"), + Arrays.asList(DataTypes.StringType, DataTypes.StringType), + tuples(new Object[] {"x", "y"}))); + } + + @Test + public void rejectsTupleWidthMismatch() { + Map[]> bad = new HashMap<>(); + bad.put(0, new Comparable[] {"x"}); // expects width 2 + assertThrows( + IllegalArgumentException.class, + () -> new ZonemapFragmentPruner.PartitionInfo(Arrays.asList("a", "b"), STRING_LONG, bad)); + } + + @Test + public void rejectsColumnTypesSizeMismatch() { + assertThrows( + IllegalArgumentException.class, + () -> + new ZonemapFragmentPruner.PartitionInfo( + Arrays.asList("a", "b"), + Collections.singletonList(DataTypes.StringType), + tuples(new Object[] {"x", 1L}))); + } + + @Test + public void constructorDefensivelyCopiesTuples() { + Comparable[] tuple = new Comparable[] {"east", 2024L}; + Map[]> input = new HashMap<>(); + input.put(0, tuple); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + Arrays.asList("region", "year"), STRING_LONG, input); + tuple[0] = "west"; // mutate caller's array + assertEquals("east", info.getFragmentPartitionKeys().get(0)[0]); + } + + @Test + public void getFragmentPartitionKeysIsUnmodifiable() { + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + Collections.singletonList("region"), STRING_ONLY, tuples(new Object[] {"east"})); + assertThrows( + UnsupportedOperationException.class, + () -> info.getFragmentPartitionKeys().put(1, new Comparable[] {"west"})); + } + + @Test + public void partitionKeyForFragmentMultiColumn() { + Map[]> map = new HashMap<>(); + map.put(7, new Comparable[] {"us", 2024L}); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo(Arrays.asList("region", "year"), STRING_LONG, map); + + InternalRow row = info.partitionKeyForFragment(7); + assertEquals(2, row.numFields()); + assertEquals(UTF8String.fromString("us"), row.get(0, DataTypes.StringType)); + assertEquals(2024L, row.getLong(1)); + } + + @Test + public void partitionKeyForMissingFragmentReturnsNullRow() { + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + Arrays.asList("a", "b"), STRING_LONG, tuples(new Object[] {"x", 1L})); + InternalRow row = info.partitionKeyForFragment(999); + assertEquals(2, row.numFields()); + assertTrue(row.isNullAt(0)); + assertTrue(row.isNullAt(1)); + } + + @Test + public void forSingleColumnMatchesListForm() { + Map> scalarMap = new HashMap<>(); + scalarMap.put(0, "east"); + scalarMap.put(1, "west"); + ZonemapFragmentPruner.PartitionInfo factory = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", DataTypes.StringType, scalarMap); + + Map[]> listMap = new HashMap<>(); + listMap.put(0, new Comparable[] {"east"}); + listMap.put(1, new Comparable[] {"west"}); + ZonemapFragmentPruner.PartitionInfo direct = + new ZonemapFragmentPruner.PartitionInfo( + Collections.singletonList("region"), STRING_ONLY, listMap); + + assertEquals(direct.getColumnNames(), factory.getColumnNames()); + assertEquals(direct.size(), factory.size()); + // partitionKeyForFragment output must match for every fragment id. + for (int fragId : new int[] {0, 1}) { + InternalRow a = direct.partitionKeyForFragment(fragId); + InternalRow b = factory.partitionKeyForFragment(fragId); + assertEquals(a.numFields(), b.numFields()); + assertEquals(a.get(0, DataTypes.StringType), b.get(0, DataTypes.StringType)); + } + } + + @Test + public void restrictToSubsetsFragments() { + Map[]> m = new HashMap<>(); + m.put(0, new Comparable[] {"us", 2024L}); + m.put(1, new Comparable[] {"us", 2025L}); + m.put(2, new Comparable[] {"eu", 2024L}); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo(Arrays.asList("region", "year"), STRING_LONG, m); + ZonemapFragmentPruner.PartitionInfo narrowed = + info.restrictTo(new HashSet<>(Arrays.asList(0, 2))); + assertNotSame(info, narrowed); + assertEquals(2, narrowed.size()); + assertEquals(3, info.size()); // original unchanged + assertTrue(narrowed.getFragmentPartitionKeys().containsKey(0)); + assertTrue(narrowed.getFragmentPartitionKeys().containsKey(2)); + assertFalse(narrowed.getFragmentPartitionKeys().containsKey(1)); + // Column types survive restriction. + assertEquals(STRING_LONG, narrowed.getColumnTypes()); + } + + @Test + public void withSoftCappedCarriesFlagAndTypes() { + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + Collections.singletonList("a"), STRING_ONLY, tuples(new Object[] {"x"})); + assertFalse(info.isSoftCapped()); + ZonemapFragmentPruner.PartitionInfo capped = info.withSoftCapped(); + assertTrue(capped.isSoftCapped()); + // Original untouched. + assertFalse(info.isSoftCapped()); + // Data and types preserved. + assertEquals(info.getColumnNames(), capped.getColumnNames()); + assertEquals(info.getColumnTypes(), capped.getColumnTypes()); + assertEquals(info.size(), capped.size()); + } + + @Test + public void javaSerializationRoundTrip() throws Exception { + Map[]> m = new HashMap<>(); + m.put(0, new Comparable[] {"us", 2024L}); + m.put(1, new Comparable[] {"eu", 2025L}); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo( + Arrays.asList("region", "year"), STRING_LONG, m, true); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(info); + } + ZonemapFragmentPruner.PartitionInfo restored; + try (ObjectInputStream ois = + new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) { + restored = (ZonemapFragmentPruner.PartitionInfo) ois.readObject(); + } + + assertEquals(Arrays.asList("region", "year"), restored.getColumnNames()); + assertEquals(STRING_LONG, restored.getColumnTypes()); + assertEquals(2, restored.size()); + assertTrue(restored.isSoftCapped()); + assertArrayEquals(new Object[] {"us", 2024L}, restored.getFragmentPartitionKeys().get(0)); + } + + @Test + public void columnNamesAreImmutableView() { + List names = new java.util.ArrayList<>(Arrays.asList("a", "b")); + ZonemapFragmentPruner.PartitionInfo info = + new ZonemapFragmentPruner.PartitionInfo(names, STRING_LONG, tuples(new Object[] {"x", 1L})); + names.add("c"); // mutate caller's list after construction + assertEquals(Arrays.asList("a", "b"), info.getColumnNames()); + assertThrows(UnsupportedOperationException.class, () -> info.getColumnNames().add("c")); + } + + // --- Type-aware narrowing (ZoneStats returns Long for every integral Arrow width) --- + + @Test + public void byteColumnNarrowsLongToByte() { + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "b", DataTypes.ByteType, Collections.singletonMap(0, 5L)); + InternalRow row = info.partitionKeyForFragment(0); + assertEquals((byte) 5, row.getByte(0)); + } + + @Test + public void shortColumnNarrowsLongToShort() { + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "s", DataTypes.ShortType, Collections.singletonMap(0, 1234L)); + assertEquals((short) 1234, info.partitionKeyForFragment(0).getShort(0)); + } + + @Test + public void intColumnNarrowsLongToInt() { + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "i", DataTypes.IntegerType, Collections.singletonMap(0, 100_000L)); + assertEquals(100_000, info.partitionKeyForFragment(0).getInt(0)); + } + + @Test + public void dateColumnEncodesAsEpochDaysInt() { + // ZoneStats returns epoch-days as Long (e.g. 19737 == 2024-01-15); Spark's InternalRow + // for DateType holds an int. Narrow without loss of information. + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "d", DataTypes.DateType, Collections.singletonMap(0, 19737L)); + assertEquals(19737, info.partitionKeyForFragment(0).getInt(0)); + } + + @Test + public void timestampColumnEncodesAsEpochMicrosLong() { + // ZoneStats returns epoch-micros as Long; Spark's InternalRow for TimestampType holds long. + long micros = 1_705_276_800_000_000L; + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "t", DataTypes.TimestampType, Collections.singletonMap(0, micros)); + assertEquals(micros, info.partitionKeyForFragment(0).getLong(0)); + } + + @Test + public void booleanColumnPassesThrough() { + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "b", DataTypes.BooleanType, Collections.singletonMap(0, Boolean.TRUE)); + assertTrue(info.partitionKeyForFragment(0).getBoolean(0)); + } + + @Test + public void stringColumnWrapsAsUtf8String() { + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "r", DataTypes.StringType, Collections.singletonMap(0, "east")); + assertEquals( + UTF8String.fromString("east"), + info.partitionKeyForFragment(0).get(0, DataTypes.StringType)); + } + + @Test + public void unsupportedPartitionTypeThrowsAtEncodeTime() { + // If detection is bypassed and a non-whitelisted type reaches toSparkValue, we must fail loud + // rather than hand Spark a slot that silently contains the wrong Java class. + ZonemapFragmentPruner.PartitionInfo info = + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "d", DataTypes.DoubleType, Collections.singletonMap(0, 1.5)); + assertThrows(IllegalArgumentException.class, () -> info.partitionKeyForFragment(0)); + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/SparkVersionUtilTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/SparkVersionUtilTest.java new file mode 100644 index 000000000..b227e2f26 --- /dev/null +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/SparkVersionUtilTest.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.spark.read; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for {@link SparkVersionUtil#supportsMultiKeySpj(String)} — the pure-function overload + * that the allowlist logic delegates to. Two cases: strings the allowlist accepts, and strings it + * rejects (including malformed input). + */ +public class SparkVersionUtilTest { + + @Test + public void acceptsAllowlistedVersions() { + String[] accepted = { + // Spark 3.5.x is the minimum version that reliably honors multi-key KGP. + "3.5.0", + "3.5.1", + "3.5.10", + // 3.5.x with snapshot / rc / vendor suffix: startsWith("3.5.") still matches. + "3.5.0-SNAPSHOT", + "3.5.1-rc1", + "3.5.2-databricks", + // Any 4.x+ build is accepted via the major-version branch. + "4.0.0", + "4.1.0", + "4.0.0-preview", + "5.0.0", + "10.0.0", + }; + for (String v : accepted) { + assertTrue(SparkVersionUtil.supportsMultiKeySpj(v), "expected accepted: " + v); + } + } + + @Test + public void rejectsUnsupportedAndMalformed() { + String[] rejected = { + // Pre-3.5 Spark lines: 3.4.x and earlier don't reliably honor multi-key KGP. + "3.4.0", + "3.4.1", + "3.4.0-preview", + "3.3.0", + "3.0.0", + "2.4.8", + // 3.6+ through 3.x: conservatively rejected until the allowlist is explicitly updated. + "3.6.0", + "3.9.9", + // null / empty / no-dot / leading-dot / non-numeric major: all unparseable → reject. + null, + "", + "3", + "custom", + ".5.0", + "vX.0.0", + "custom-fork.1.0", + }; + for (String v : rejected) { + assertFalse(SparkVersionUtil.supportsMultiKeySpj(v), "expected rejected: " + v); + } + } +} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/ZonemapFragmentPrunerTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/ZonemapFragmentPrunerTest.java index 35678a03d..3772840d5 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/read/ZonemapFragmentPrunerTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/read/ZonemapFragmentPrunerTest.java @@ -563,7 +563,8 @@ public void testPartitionKeyForFragmentString() { values.put(0, "east"); values.put(1, "west"); ZonemapFragmentPruner.PartitionInfo info = - new ZonemapFragmentPruner.PartitionInfo("region", values); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", org.apache.spark.sql.types.DataTypes.StringType, values); InternalRow row0 = info.partitionKeyForFragment(0); assertNotNull(row0); @@ -583,7 +584,8 @@ public void testPartitionKeyForFragmentLong() { values.put(0, 2023L); values.put(1, 2024L); ZonemapFragmentPruner.PartitionInfo info = - new ZonemapFragmentPruner.PartitionInfo("year", values); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "year", org.apache.spark.sql.types.DataTypes.LongType, values); InternalRow row0 = info.partitionKeyForFragment(0); assertEquals(2023L, row0.getLong(0)); @@ -597,7 +599,8 @@ public void testPartitionKeyForMissingFragment() { Map> values = new HashMap<>(); values.put(0, "east"); ZonemapFragmentPruner.PartitionInfo info = - new ZonemapFragmentPruner.PartitionInfo("region", values); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", org.apache.spark.sql.types.DataTypes.StringType, values); InternalRow row = info.partitionKeyForFragment(99); assertNotNull(row); @@ -610,7 +613,8 @@ public void testPartitionInfoIsSerializable() throws Exception { values.put(0, "east"); values.put(1, "west"); ZonemapFragmentPruner.PartitionInfo info = - new ZonemapFragmentPruner.PartitionInfo("region", values); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", org.apache.spark.sql.types.DataTypes.StringType, values); java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); java.io.ObjectOutputStream oos = new java.io.ObjectOutputStream(baos); @@ -622,9 +626,9 @@ public void testPartitionInfoIsSerializable() throws Exception { ZonemapFragmentPruner.PartitionInfo deserialized = (ZonemapFragmentPruner.PartitionInfo) ois.readObject(); - assertEquals("region", deserialized.getColumnName()); - assertEquals("east", deserialized.getFragmentPartitionValues().get(0)); - assertEquals("west", deserialized.getFragmentPartitionValues().get(1)); + assertEquals(java.util.Collections.singletonList("region"), deserialized.getColumnNames()); + assertEquals("east", deserialized.getFragmentPartitionKeys().get(0)[0]); + assertEquals("west", deserialized.getFragmentPartitionKeys().get(1)[0]); } @Test @@ -632,10 +636,11 @@ public void testPartitionInfoImmutableMap() { Map> values = new HashMap<>(); values.put(0, "east"); ZonemapFragmentPruner.PartitionInfo info = - new ZonemapFragmentPruner.PartitionInfo("region", values); + ZonemapFragmentPruner.PartitionInfo.forSingleColumn( + "region", org.apache.spark.sql.types.DataTypes.StringType, values); assertThrows( UnsupportedOperationException.class, - () -> info.getFragmentPartitionValues().put(1, "west")); + () -> info.getFragmentPartitionKeys().put(1, new Comparable[] {"west"})); } }