diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index f752ef0e68b..3f45ba2ac8b 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -15,7 +15,10 @@ use arrow::datatypes; use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use bitpacking::{BitPacker, BitPacker4x}; -use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion::execution::{ + RecordBatchStream, SendableRecordBatchStream, + memory_pool::{FairSpillPool, MemoryConsumer, MemoryPool, MemoryReservation}, +}; use deepsize::DeepSizeOf; use fst::Streamer; use futures::{Stream, StreamExt, TryStreamExt}; @@ -86,6 +89,7 @@ fn resolve_num_workers(params: &InvertedIndexParams) -> usize { .clamp(1, max_workers) } +#[cfg(test)] fn resolve_worker_memory_limit_bytes(params: &InvertedIndexParams, num_workers: usize) -> u64 { let default_worker_memory_limit_bytes = *LANCE_FTS_PARTITION_SIZE << 20; params @@ -94,6 +98,14 @@ fn resolve_worker_memory_limit_bytes(params: &InvertedIndexParams, num_workers: .unwrap_or(default_worker_memory_limit_bytes) } +/// Total memory budget for the entire FTS build (all workers combined). +fn resolve_total_memory_limit_bytes(params: &InvertedIndexParams, num_workers: usize) -> u64 { + params + .memory_limit_mb + .map(|mb| mb << 20) + .unwrap_or_else(|| (*LANCE_FTS_PARTITION_SIZE << 20) * num_workers as u64) +} + fn merge_all_tail_partitions(tails: Vec) -> Result> { if tails.is_empty() { return Ok(None); @@ -233,14 +245,15 @@ impl InvertedIndexBuilder { let num_workers = resolve_num_workers(&self.params); let tokenizer = self.params.build()?; let with_position = self.params.with_position; - let worker_memory_limit_bytes = - resolve_worker_memory_limit_bytes(&self.params, num_workers); + let total_memory_bytes = resolve_total_memory_limit_bytes(&self.params, num_workers); + let memory_pool: Arc = + Arc::new(FairSpillPool::new(total_memory_bytes.max(1) as usize)); let worker_config = IndexWorkerConfig { with_position, format_version: self.format_version, fragment_mask: self.fragment_mask, token_set_format: self.token_set_format, - worker_memory_limit_bytes, + memory_pool, }; let next_id = self.partitions.iter().map(|id| id + 1).max().unwrap_or(0); let id_alloc = Arc::new(AtomicU64::new(next_id)); @@ -255,6 +268,7 @@ impl InvertedIndexBuilder { let id_alloc = id_alloc.clone(); let progress = self.progress.clone(); let tokenized_count = tokenized_count.clone(); + let worker_config = worker_config.clone(); index_tasks.push(tokio::task::spawn(async move { let mut worker = IndexWorker::new(tokenizer, dest_store, id_alloc, worker_config).await?; @@ -856,7 +870,7 @@ struct IndexWorker { partitions: Vec, schema: SchemaRef, memory_size: u64, - worker_memory_limit_bytes: u64, + reservation: MemoryReservation, total_doc_length: usize, fragment_mask: Option, token_set_format: TokenSetFormat, @@ -873,13 +887,13 @@ struct WorkerOutput { tail_partition: Option, } -#[derive(Debug, Clone, Copy)] +#[derive(Clone)] struct IndexWorkerConfig { with_position: bool, format_version: InvertedListFormatVersion, fragment_mask: Option, token_set_format: TokenSetFormat, - worker_memory_limit_bytes: u64, + memory_pool: Arc, } impl IndexWorker { @@ -924,6 +938,7 @@ impl IndexWorker { config: IndexWorkerConfig, ) -> Result { let schema = inverted_list_schema_for_version(config.with_position, config.format_version); + let reservation = MemoryConsumer::new("FTSIndexWorker").register(&config.memory_pool); Ok(Self { tokenizer, @@ -939,7 +954,7 @@ impl IndexWorker { id_alloc, schema, memory_size: 0, - worker_memory_limit_bytes: config.worker_memory_limit_bytes, + reservation, total_doc_length: 0, fragment_mask: config.fragment_mask, token_set_format: config.token_set_format, @@ -1119,16 +1134,26 @@ impl IndexWorker { self.temporary_memory_size(), ); - if self.builder.docs.len() == 1 && self.memory_size > self.worker_memory_limit_bytes { - return Err(Error::invalid_input(format!( - "single document row_id={} exceeds worker memory limit: {} > {} bytes", - row_id, self.memory_size, self.worker_memory_limit_bytes - ))); + // Sync the memory pool reservation with the tracked memory usage. + // try_grow returning Err is the spill signal: flush the current partition. + let reserved = self.reservation.size() as u64; + if self.memory_size > reserved { + let delta = self.memory_size - reserved; + if self.reservation.try_grow(delta as usize).is_err() { + if builder_was_empty { + return Err(Error::invalid_input(format!( + "single document row_id={} exceeds memory budget: {} bytes", + row_id, self.memory_size + ))); + } + self.flush().await?; + } + } else if self.memory_size < reserved { + self.reservation + .shrink((reserved - self.memory_size) as usize); } - if self.builder.docs.len() as u32 == u32::MAX - || (!builder_was_empty && self.memory_size >= self.worker_memory_limit_bytes) - { + if self.builder.docs.len() as u32 == u32::MAX { self.flush().await?; } } @@ -1147,6 +1172,13 @@ impl IndexWorker { self.memory_size / (1024 * 1024) ); self.memory_size = self.temporary_memory_size(); + // Release all reserved memory, then re-reserve just for the remaining + // temporary buffers so the pool invariant stays consistent. + self.reservation.free(); + if self.memory_size > 0 { + // This should always succeed: we just freed all prior memory. + let _ = self.reservation.try_grow(self.memory_size as usize); + } let with_position = self.has_position(); let format_version = self.builder.format_version; let builder = std::mem::replace( @@ -1766,6 +1798,7 @@ mod tests { use arrow_array::{RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; + use datafusion::execution::memory_pool::UnboundedMemoryPool; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use futures::stream; use lance_core::ROW_ID; @@ -1932,7 +1965,7 @@ mod tests { format_version: InvertedListFormatVersion::V1, fragment_mask: None, token_set_format, - worker_memory_limit_bytes: u64::MAX, + memory_pool: Arc::new(UnboundedMemoryPool::default()), }, ) .await?; @@ -1955,7 +1988,7 @@ mod tests { format_version: InvertedListFormatVersion::V1, fragment_mask: None, token_set_format, - worker_memory_limit_bytes: u64::MAX, + memory_pool: Arc::new(UnboundedMemoryPool::default()), }, ) .await?; @@ -2356,6 +2389,59 @@ mod tests { ); } + #[tokio::test] + async fn test_memory_pool_spills_on_tight_budget() -> Result<()> { + // Build a worker with a tight pool so try_grow signals spill, producing + // multiple completed partitions rather than one large tail partition. + // Each doc adds 300 unique tokens; the first doc uses ~69 KiB of tracked + // memory. A pool of 72 KiB fits the first doc but triggers a spill as soon + // as any growth occurs on the second doc, producing multiple partitions. + let index_dir = TempDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + index_dir.obj_path(), + Arc::new(LanceCache::no_cache()), + )); + + let doc_count = 20u64; + let tokens_per_doc = 300; + let batches: Vec<_> = (0..doc_count) + .map(|row_id| { + let doc: String = (0..tokens_per_doc) + .map(|t| format!("tok_{row_id}_{t}")) + .collect::>() + .join(" "); + make_doc_batch(&doc, row_id) + }) + .collect(); + + let params = InvertedIndexParams::default().num_workers(1); + let tokenizer = params.build()?; + let pool: Arc = Arc::new(FairSpillPool::new(72 * 1024)); + let id_alloc = Arc::new(AtomicU64::new(0)); + let config = IndexWorkerConfig { + with_position: false, + format_version: InvertedListFormatVersion::V1, + fragment_mask: None, + token_set_format: TokenSetFormat::default(), + memory_pool: pool, + }; + let mut worker = IndexWorker::new(tokenizer, store.clone(), id_alloc, config).await?; + for batch in &batches { + worker.process_batch(batch.clone()).await?; + } + let output = worker.finish().await?; + + let completed = output.partitions.len(); + let tail = output.tail_partition.is_some() as usize; + assert!( + completed >= 1, + "expected pool-triggered spill to produce at least one completed partition \ + (completed={completed}, tail={tail})" + ); + Ok(()) + } + #[tokio::test] async fn test_worker_trims_position_temp_buffers() -> Result<()> { let tokenizer = InvertedIndexParams::default().with_position(true).build()?; @@ -2370,7 +2456,7 @@ mod tests { format_version: InvertedListFormatVersion::V1, fragment_mask: None, token_set_format: TokenSetFormat::default(), - worker_memory_limit_bytes: u64::MAX, + memory_pool: Arc::new(UnboundedMemoryPool::default()), }, ) .await?; @@ -2401,7 +2487,7 @@ mod tests { format_version: InvertedListFormatVersion::V1, fragment_mask: None, token_set_format: TokenSetFormat::default(), - worker_memory_limit_bytes: u64::MAX, + memory_pool: Arc::new(UnboundedMemoryPool::default()), }, ) .await?; @@ -2439,7 +2525,7 @@ mod tests { format_version: InvertedListFormatVersion::V1, fragment_mask: None, token_set_format: TokenSetFormat::default(), - worker_memory_limit_bytes: u64::MAX, + memory_pool: Arc::new(UnboundedMemoryPool::default()), }, ) .await?;