Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 160 additions & 90 deletions rust/lance-index/src/vector/v3/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
//! the corresponding IVF partitions.

use std::ops::Range;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use arrow::compute::concat_batches;
use arrow::datatypes::UInt64Type;
use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use futures::{future::try_join_all, prelude::*};
use lance_arrow::stream::rechunk_stream_by_size;
use lance_arrow::{RecordBatchExt, SchemaExt};
use lance_arrow::{RecordBatchExt, SchemaExt, interleave_batches};
use lance_core::{
Error, Result,
cache::LanceCache,
Expand Down Expand Up @@ -341,6 +339,11 @@ pub fn create_ivf_shuffler(

const DEFAULT_SHUFFLE_BATCH_BYTES: usize = 128 * 1024 * 1024;

/// Number of rows per output batch when streaming sorted data via interleave.
/// Small enough to keep the output chunk's memory footprint modest relative to
/// the accumulated source data.
const SHUFFLE_WRITE_CHUNK_ROWS: usize = 8 * 1024;

/// Limit of how much transformed data we accumulate before spilling to disk.
///
/// A larger value will use more RAM but require less random access during the
Expand Down Expand Up @@ -407,15 +410,55 @@ impl TwoFileShuffler {
}
}

/// Sorts rows from multiple batches by partition ID and returns interleave indices.
///
/// Builds a sort key of `(part_id, batch_idx, row_idx)` for every row across all
/// batches, sorts by `part_id`, then emits `(batch_idx, row_idx)` pairs in that
/// order. This avoids concatenating the full data: only the `UInt32` partition-ID
/// columns are touched here.
///
/// Also returns per-partition row counts (derived from the same sorted keys at no
/// extra cost).
fn sort_to_interleave_indices(
part_id_columns: &[&UInt32Array],
num_partitions: usize,
) -> (Vec<(usize, usize)>, Vec<u64>) {
let total_rows: usize = part_id_columns.iter().map(|a| a.len()).sum();
let mut keys: Vec<(u32, u32, u32)> = Vec::with_capacity(total_rows);
for (batch_idx, col) in part_id_columns.iter().enumerate() {
let batch_idx = batch_idx as u32;
for (row_idx, &part_id) in col.values().iter().enumerate() {
keys.push((part_id, batch_idx, row_idx as u32));
}
}
keys.sort_unstable_by_key(|k| k.0);

let mut partition_counts = vec![0u64; num_partitions];
let mut interleave_indices = Vec::with_capacity(total_rows);
for (part_id, batch_idx, row_idx) in &keys {
Comment on lines +432 to +444
let pid = *part_id as usize;
if pid < num_partitions {
partition_counts[pid] += 1;
} else {
log::warn!(
"Partition ID {} is out of range [0, {})",
pid,
num_partitions
);
}
interleave_indices.push((*batch_idx as usize, *row_idx as usize));
Comment on lines +445 to +453
}
(interleave_indices, partition_counts)
}

#[async_trait::async_trait]
impl Shuffler for TwoFileShuffler {
async fn shuffle(
&self,
data: Box<dyn RecordBatchStream + Unpin + 'static>,
) -> Result<Box<dyn ShuffleReader>> {
let num_partitions = self.num_partitions;
let full_schema = Arc::new(data.schema().as_ref().clone());
// No need to write partition ids since we can infer this
// No need to write partition ids since we can infer this from offsets
let schema = data.schema().without_column(PART_ID_COLUMN);
let offsets_schema = Arc::new(Schema::new(vec![Field::new(
"offset",
Expand All @@ -424,28 +467,6 @@ impl Shuffler for TwoFileShuffler {
)]));
let batch_size_bytes = self.batch_size_bytes;

// Extract loss from batch metadata before rechunking (concat_batches drops metadata)
let total_loss = Arc::new(Mutex::new(0.0f64));
let loss_ref = total_loss.clone();
let loss_stream = data.map(move |result| {
result.inspect(|batch| {
let loss = batch
.metadata()
.get(LOSS_METADATA_KEY)
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(0.0);
*loss_ref.lock().unwrap() += loss;
})
});

// Rechunk to target batch size
let rechunked = rechunk_stream_by_size(
loss_stream,
full_schema,
batch_size_bytes,
batch_size_bytes * 2,
);

// Create data file writer
let data_path = self.output_dir.clone().join("shuffle_data.lance");
let spill_path = self.output_dir.clone().join("shuffle_data.spill");
Expand All @@ -468,72 +489,63 @@ impl Shuffler for TwoFileShuffler {
)?
.with_page_metadata_spill(self.object_store.clone(), spill_path);

let num_batches = Arc::new(AtomicU64::new(0));
let num_batches_ref = num_batches.clone();
let mut num_batches: u64 = 0;
let mut partition_counts: Vec<u64> = vec![0; num_partitions];
let mut global_row_count: u64 = 0;
let mut rows_processed: u64 = 0;
let mut total_loss = 0.0f64;
let mut accumulated: Vec<RecordBatch> = Vec::new();
let mut acc_bytes: usize = 0;

let mut rechunked = std::pin::pin!(rechunked);
while let Some(batch) = rechunked.next().await {
num_batches_ref.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let mut data = std::pin::pin!(data);
while let Some(batch) = data.next().await {
Comment on lines +500 to +505

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If an incoming batch blows past memory limits, we can't do much about it. It's already in-memory. Slicing it and making a smaller copy just increases memory temporarily, so it's not helpful. If input chunks are too large, we need instead a way to tell the reader to use smaller chunks.

let batch = batch?;
let np = num_partitions;
let num_rows = batch.num_rows() as u64;

// Sort by partition ID and compute offsets on CPU
let (sorted_batch, batch_offsets) = spawn_cpu(move || {
let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
let indices = sort_to_indices(part_ids, None, None)?;
let batch = batch.take(&indices)?;

let part_ids: &UInt32Array = batch[PART_ID_COLUMN].as_primitive();
let batch = batch.drop_column(PART_ID_COLUMN)?;

// Count rows per partition by scanning sorted part IDs
let mut partition_counts = vec![0u64; np];
for i in 0..part_ids.len() {
let pid = part_ids.value(i) as usize;
if pid < np {
partition_counts[pid] += 1;
} else {
log::warn!("Partition ID {} is out of range [0, {})", pid, np);
}
}

// Build cumulative offsets (end positions) for this batch
let mut batch_offsets = Vec::with_capacity(np);
let mut running = 0u64;
for count in &partition_counts {
running += count;
batch_offsets.push(running);
total_loss += batch
.metadata()
.get(LOSS_METADATA_KEY)
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(0.0);
acc_bytes += batch.get_array_memory_size();
accumulated.push(batch);

if acc_bytes >= batch_size_bytes {
let (total_rows, counts) = flush_shuffle_batch(
std::mem::take(&mut accumulated),
&mut file_writer,
&mut offsets_writer,
offsets_schema.clone(),
num_partitions,
global_row_count,
)
.await?;
acc_bytes = 0;
for (p, c) in counts.iter().enumerate() {
partition_counts[p] += c;
}
global_row_count += total_rows;
rows_processed += total_rows;
num_batches += 1;
self.progress
.stage_progress("shuffle", rows_processed)
.await?;
}
}

Ok::<(RecordBatch, Vec<u64>), Error>((batch, batch_offsets))
})
if !accumulated.is_empty() {
let (total_rows, counts) = flush_shuffle_batch(
accumulated,
&mut file_writer,
&mut offsets_writer,
offsets_schema,
num_partitions,
global_row_count,
)
.await?;

// Write sorted batch to data file
file_writer.write_batch(&sorted_batch).await?;

// Record offsets adjusted by global row count
let mut adjusted_offsets = Vec::with_capacity(batch_offsets.len());
let mut last_offset = 0;
for (idx, offset) in batch_offsets.iter().enumerate() {
adjusted_offsets.push(global_row_count + offset);
partition_counts[idx] += offset - last_offset;
last_offset = *offset;
for (p, c) in counts.iter().enumerate() {
partition_counts[p] += c;
}
global_row_count += sorted_batch.num_rows() as u64;

// Write offsets to offsets file
let offsets_batch = RecordBatch::try_new(
offsets_schema.clone(),
vec![Arc::new(UInt64Array::from(adjusted_offsets))],
)?;
offsets_writer.write_batch(&offsets_batch).await?;

rows_processed += num_rows;
rows_processed += total_rows;
num_batches += 1;
self.progress
.stage_progress("shuffle", rows_processed)
.await?;
Expand All @@ -543,22 +555,80 @@ impl Shuffler for TwoFileShuffler {
file_writer.finish().await?;
offsets_writer.finish().await?;

let num_batches = num_batches.load(std::sync::atomic::Ordering::Relaxed);

let total_loss_val = *total_loss.lock().unwrap();

TwoFileShuffleReader::try_new(
self.object_store.clone(),
self.output_dir.clone(),
num_partitions,
num_batches,
partition_counts,
total_loss_val,
total_loss,
)
.await
}
}

/// Sorts `accumulated` batches by partition ID and writes the result to the data
/// and offsets files.
///
/// Returns `(total_rows_written, per_partition_row_counts)`.
async fn flush_shuffle_batch(
accumulated: Vec<RecordBatch>,
file_writer: &mut FileWriter,
offsets_writer: &mut FileWriter,
offsets_schema: Arc<Schema>,
Comment on lines +574 to +582
num_partitions: usize,
global_row_count: u64,
) -> Result<(u64, Vec<u64>)> {
let total_rows: u64 = accumulated.iter().map(|b| b.num_rows() as u64).sum();

// Clone part-id columns into the CPU task (cheap: Arc ref bump, not data copy).
let part_id_cols: Vec<UInt32Array> = accumulated
.iter()
.map(|b| {
let col: &UInt32Array = b[PART_ID_COLUMN].as_primitive();
col.clone()
})
.collect();

let np = num_partitions;
let (interleave_indices, batch_partition_counts) = spawn_cpu(move || {
Ok::<_, Error>(sort_to_interleave_indices(
&part_id_cols.iter().collect::<Vec<_>>(),
np,
))
})
.await?;

// Drop part-id column from source batches before interleaving.
let source_batches: Vec<RecordBatch> = accumulated
.into_iter()
.map(|b| b.drop_column(PART_ID_COLUMN).map_err(Error::from))
.collect::<Result<_>>()?;

// Stream sorted output to the data file in fixed-size chunks so the peak
// memory for the interleave output stays small relative to the source data.
for chunk in interleave_indices.chunks(SHUFFLE_WRITE_CHUNK_ROWS) {
let out = interleave_batches(&source_batches, chunk)?;
file_writer.write_batch(&out).await?;
}

// Compute cumulative end-row offsets (adjusted by global position) and write
// one offsets batch for this flush group.
let mut adjusted_offsets = Vec::with_capacity(num_partitions);
let mut running = 0u64;
for count in &batch_partition_counts {
running += count;
adjusted_offsets.push(global_row_count + running);
}
let offsets_batch = RecordBatch::try_new(
offsets_schema,
vec![Arc::new(UInt64Array::from(adjusted_offsets))],
)?;
offsets_writer.write_batch(&offsets_batch).await?;

Ok((total_rows, batch_partition_counts))
}

pub struct TwoFileShuffleReader {
_scheduler: Arc<ScanScheduler>,
file_reader: FileReader,
Expand Down
Loading