diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 0b76517e1c2..4203d099d0b 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -5,8 +5,7 @@ //! the corresponding IVF partitions. use std::ops::Range; -use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use arrow::compute::concat_batches; use arrow::datatypes::UInt64Type; @@ -14,8 +13,7 @@ use arrow::{array::AsArray, compute::sort_to_indices}; use arrow_array::{RecordBatch, UInt32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use futures::{future::try_join_all, prelude::*}; -use lance_arrow::stream::rechunk_stream_by_size; -use lance_arrow::{RecordBatchExt, SchemaExt}; +use lance_arrow::{RecordBatchExt, SchemaExt, interleave_batches}; use lance_core::{ Error, Result, cache::LanceCache, @@ -341,6 +339,11 @@ pub fn create_ivf_shuffler( const DEFAULT_SHUFFLE_BATCH_BYTES: usize = 128 * 1024 * 1024; +/// Number of rows per output batch when streaming sorted data via interleave. +/// Small enough to keep the output chunk's memory footprint modest relative to +/// the accumulated source data. +const SHUFFLE_WRITE_CHUNK_ROWS: usize = 8 * 1024; + /// Limit of how much transformed data we accumulate before spilling to disk. /// /// A larger value will use more RAM but require less random access during the @@ -407,6 +410,51 @@ impl TwoFileShuffler { } } +/// `(batch_idx, row_idx)` pairs produced by [`sort_to_interleave_indices`], paired with +/// per-partition row counts. +type InterleaveResult = (Vec<(usize, usize)>, Vec); + +/// Sorts rows from multiple batches by partition ID and returns interleave indices. +/// +/// Builds a sort key of `(part_id, batch_idx, row_idx)` for every row across all +/// batches, sorts by `part_id`, then emits `(batch_idx, row_idx)` pairs in that +/// order. This avoids concatenating the full data: only the `UInt32` partition-ID +/// columns are touched here. +/// +/// Also returns per-partition row counts (derived from the same sorted keys at no +/// extra cost). +/// +/// Returns an error if any partition ID is out of range `[0, num_partitions)`. +fn sort_to_interleave_indices( + part_id_columns: &[&UInt32Array], + num_partitions: usize, +) -> Result { + let total_rows: usize = part_id_columns.iter().map(|a| a.len()).sum(); + let mut keys: Vec<(u32, u32, u32)> = Vec::with_capacity(total_rows); + for (batch_idx, col) in part_id_columns.iter().enumerate() { + let batch_idx = batch_idx as u32; + for (row_idx, &part_id) in col.values().iter().enumerate() { + keys.push((part_id, batch_idx, row_idx as u32)); + } + } + keys.sort_unstable_by_key(|k| k.0); + + let mut partition_counts = vec![0u64; num_partitions]; + let mut interleave_indices = Vec::with_capacity(total_rows); + for (part_id, batch_idx, row_idx) in &keys { + let pid = *part_id as usize; + if pid >= num_partitions { + return Err(Error::invalid_input(format!( + "partition ID {} is out of range [0, {})", + pid, num_partitions + ))); + } + partition_counts[pid] += 1; + interleave_indices.push((*batch_idx as usize, *row_idx as usize)); + } + Ok((interleave_indices, partition_counts)) +} + #[async_trait::async_trait] impl Shuffler for TwoFileShuffler { async fn shuffle( @@ -414,8 +462,7 @@ impl Shuffler for TwoFileShuffler { data: Box, ) -> Result> { let num_partitions = self.num_partitions; - let full_schema = Arc::new(data.schema().as_ref().clone()); - // No need to write partition ids since we can infer this + // No need to write partition ids since we can infer this from offsets let schema = data.schema().without_column(PART_ID_COLUMN); let offsets_schema = Arc::new(Schema::new(vec![Field::new( "offset", @@ -424,28 +471,6 @@ impl Shuffler for TwoFileShuffler { )])); let batch_size_bytes = self.batch_size_bytes; - // Extract loss from batch metadata before rechunking (concat_batches drops metadata) - let total_loss = Arc::new(Mutex::new(0.0f64)); - let loss_ref = total_loss.clone(); - let loss_stream = data.map(move |result| { - result.inspect(|batch| { - let loss = batch - .metadata() - .get(LOSS_METADATA_KEY) - .and_then(|s| s.parse::().ok()) - .unwrap_or(0.0); - *loss_ref.lock().unwrap() += loss; - }) - }); - - // Rechunk to target batch size - let rechunked = rechunk_stream_by_size( - loss_stream, - full_schema, - batch_size_bytes, - batch_size_bytes * 2, - ); - // Create data file writer let data_path = self.output_dir.clone().join("shuffle_data.lance"); let spill_path = self.output_dir.clone().join("shuffle_data.spill"); @@ -468,72 +493,63 @@ impl Shuffler for TwoFileShuffler { )? .with_page_metadata_spill(self.object_store.clone(), spill_path); - let num_batches = Arc::new(AtomicU64::new(0)); - let num_batches_ref = num_batches.clone(); + let mut num_batches: u64 = 0; let mut partition_counts: Vec = vec![0; num_partitions]; let mut global_row_count: u64 = 0; let mut rows_processed: u64 = 0; + let mut total_loss = 0.0f64; + let mut accumulated: Vec = Vec::new(); + let mut acc_bytes: usize = 0; - let mut rechunked = std::pin::pin!(rechunked); - while let Some(batch) = rechunked.next().await { - num_batches_ref.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let mut data = std::pin::pin!(data); + while let Some(batch) = data.next().await { let batch = batch?; - let np = num_partitions; - let num_rows = batch.num_rows() as u64; - - // Sort by partition ID and compute offsets on CPU - let (sorted_batch, batch_offsets) = spawn_cpu(move || { - let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive(); - let indices = sort_to_indices(part_ids, None, None)?; - let batch = batch.take(&indices)?; - - let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive(); - let batch = batch.drop_column(PART_ID_COLUMN)?; - - // Count rows per partition by scanning sorted part IDs - let mut partition_counts = vec![0u64; np]; - for i in 0..part_ids.len() { - let pid = part_ids.value(i) as usize; - if pid < np { - partition_counts[pid] += 1; - } else { - log::warn!("Partition ID {} is out of range [0, {})", pid, np); - } - } - - // Build cumulative offsets (end positions) for this batch - let mut batch_offsets = Vec::with_capacity(np); - let mut running = 0u64; - for count in &partition_counts { - running += count; - batch_offsets.push(running); + total_loss += batch + .metadata() + .get(LOSS_METADATA_KEY) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0); + acc_bytes += batch.get_array_memory_size(); + accumulated.push(batch); + + if acc_bytes >= batch_size_bytes { + let (total_rows, counts) = flush_shuffle_batch( + std::mem::take(&mut accumulated), + &mut file_writer, + &mut offsets_writer, + offsets_schema.clone(), + num_partitions, + global_row_count, + ) + .await?; + acc_bytes = 0; + for (p, c) in counts.iter().enumerate() { + partition_counts[p] += c; } + global_row_count += total_rows; + rows_processed += total_rows; + num_batches += 1; + self.progress + .stage_progress("shuffle", rows_processed) + .await?; + } + } - Ok::<(RecordBatch, Vec), Error>((batch, batch_offsets)) - }) + if !accumulated.is_empty() { + let (total_rows, counts) = flush_shuffle_batch( + accumulated, + &mut file_writer, + &mut offsets_writer, + offsets_schema, + num_partitions, + global_row_count, + ) .await?; - - // Write sorted batch to data file - file_writer.write_batch(&sorted_batch).await?; - - // Record offsets adjusted by global row count - let mut adjusted_offsets = Vec::with_capacity(batch_offsets.len()); - let mut last_offset = 0; - for (idx, offset) in batch_offsets.iter().enumerate() { - adjusted_offsets.push(global_row_count + offset); - partition_counts[idx] += offset - last_offset; - last_offset = *offset; + for (p, c) in counts.iter().enumerate() { + partition_counts[p] += c; } - global_row_count += sorted_batch.num_rows() as u64; - - // Write offsets to offsets file - let offsets_batch = RecordBatch::try_new( - offsets_schema.clone(), - vec![Arc::new(UInt64Array::from(adjusted_offsets))], - )?; - offsets_writer.write_batch(&offsets_batch).await?; - - rows_processed += num_rows; + rows_processed += total_rows; + num_batches += 1; self.progress .stage_progress("shuffle", rows_processed) .await?; @@ -543,22 +559,76 @@ impl Shuffler for TwoFileShuffler { file_writer.finish().await?; offsets_writer.finish().await?; - let num_batches = num_batches.load(std::sync::atomic::Ordering::Relaxed); - - let total_loss_val = *total_loss.lock().unwrap(); - TwoFileShuffleReader::try_new( self.object_store.clone(), self.output_dir.clone(), num_partitions, num_batches, partition_counts, - total_loss_val, + total_loss, ) .await } } +/// Sorts `accumulated` batches by partition ID and writes the result to the data +/// and offsets files. +/// +/// Returns `(total_rows_written, per_partition_row_counts)`. +async fn flush_shuffle_batch( + accumulated: Vec, + file_writer: &mut FileWriter, + offsets_writer: &mut FileWriter, + offsets_schema: Arc, + num_partitions: usize, + global_row_count: u64, +) -> Result<(u64, Vec)> { + let total_rows: u64 = accumulated.iter().map(|b| b.num_rows() as u64).sum(); + + // Clone part-id columns into the CPU task (cheap: Arc ref bump, not data copy). + let part_id_cols: Vec = accumulated + .iter() + .map(|b| { + let col: &UInt32Array = b[PART_ID_COLUMN].as_primitive(); + col.clone() + }) + .collect(); + + let np = num_partitions; + let (interleave_indices, batch_partition_counts) = + spawn_cpu(move || sort_to_interleave_indices(&part_id_cols.iter().collect::>(), np)) + .await?; + + // Drop part-id column from source batches before interleaving. + let source_batches: Vec = accumulated + .into_iter() + .map(|b| b.drop_column(PART_ID_COLUMN).map_err(Error::from)) + .collect::>()?; + + // Stream sorted output to the data file in fixed-size chunks so the peak + // memory for the interleave output stays small relative to the source data. + for chunk in interleave_indices.chunks(SHUFFLE_WRITE_CHUNK_ROWS) { + let out = interleave_batches(&source_batches, chunk)?; + file_writer.write_batch(&out).await?; + } + + // Compute cumulative end-row offsets (adjusted by global position) and write + // one offsets batch for this flush group. + let mut adjusted_offsets = Vec::with_capacity(num_partitions); + let mut running = 0u64; + for count in &batch_partition_counts { + running += count; + adjusted_offsets.push(global_row_count + running); + } + let offsets_batch = RecordBatch::try_new( + offsets_schema, + vec![Arc::new(UInt64Array::from(adjusted_offsets))], + )?; + offsets_writer.write_batch(&offsets_batch).await?; + + Ok((total_rows, batch_partition_counts)) +} + pub struct TwoFileShuffleReader { _scheduler: Arc, file_reader: FileReader, @@ -934,4 +1004,65 @@ mod tests { assert!((reader.total_loss().unwrap() - 6.0).abs() < 1e-10); } + + #[tokio::test] + async fn test_two_file_shuffler_multi_batch_single_flush() { + // All three batches fit within the default batch_size_bytes, so they + // accumulate and are interleaved in a single flush group. This exercises + // the cross-batch interleave path. + let dir = TempStrDir::default(); + let output_dir = Path::from(dir.as_ref()); + let num_partitions = 3; + + let batch1 = make_batch(&[0, 1, 2], &[10, 20, 30], None); + let batch2 = make_batch(&[2, 0, 1], &[40, 50, 60], None); + let batch3 = make_batch(&[1, 2, 0], &[70, 80, 90], None); + + // Large batch_size_bytes so all three batches flush together. + let shuffler = + TwoFileShuffler::new(output_dir, num_partitions).with_batch_size_bytes(1024 * 1024); + let stream = batches_to_stream(vec![batch1, batch2, batch3]); + let reader = shuffler.shuffle(stream).await.unwrap(); + + assert_eq!(reader.partition_size(0).unwrap(), 3); + assert_eq!(reader.partition_size(1).unwrap(), 3); + assert_eq!(reader.partition_size(2).unwrap(), 3); + + let p0 = collect_partition(reader.as_ref(), 0).await.unwrap(); + let vals: &Int32Array = p0.column_by_name("val").unwrap().as_primitive(); + let mut v: Vec = vals.iter().map(|x| x.unwrap()).collect(); + v.sort(); + assert_eq!(v, vec![10, 50, 90]); + + let p1 = collect_partition(reader.as_ref(), 1).await.unwrap(); + let vals: &Int32Array = p1.column_by_name("val").unwrap().as_primitive(); + let mut v: Vec = vals.iter().map(|x| x.unwrap()).collect(); + v.sort(); + assert_eq!(v, vec![20, 60, 70]); + + let p2 = collect_partition(reader.as_ref(), 2).await.unwrap(); + let vals: &Int32Array = p2.column_by_name("val").unwrap().as_primitive(); + let mut v: Vec = vals.iter().map(|x| x.unwrap()).collect(); + v.sort(); + assert_eq!(v, vec![30, 40, 80]); + } + + #[tokio::test] + async fn test_two_file_shuffler_out_of_range_partition_id() { + let dir = TempStrDir::default(); + let output_dir = Path::from(dir.as_ref()); + + // Row with partition ID 5 is out of range for num_partitions=3. + let batch = make_batch(&[0, 5, 1], &[10, 20, 30], None); + + let shuffler = TwoFileShuffler::new(output_dir, 3); + let stream = batches_to_stream(vec![batch]); + let Err(err) = shuffler.shuffle(stream).await else { + panic!("expected an error for out-of-range partition ID"); + }; + assert!( + err.to_string().contains("partition ID 5 is out of range"), + "unexpected error: {err}" + ); + } }