Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 107 additions & 21 deletions rust/lance-index/src/scalar/inverted/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use arrow::datatypes;
use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use bitpacking::{BitPacker, BitPacker4x};
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
use datafusion::execution::{
RecordBatchStream, SendableRecordBatchStream,
memory_pool::{FairSpillPool, MemoryConsumer, MemoryPool, MemoryReservation},
};
use deepsize::DeepSizeOf;
use fst::Streamer;
use futures::{Stream, StreamExt, TryStreamExt};
Expand Down Expand Up @@ -86,6 +89,7 @@ fn resolve_num_workers(params: &InvertedIndexParams) -> usize {
.clamp(1, max_workers)
}

#[cfg(test)]
fn resolve_worker_memory_limit_bytes(params: &InvertedIndexParams, num_workers: usize) -> u64 {
let default_worker_memory_limit_bytes = *LANCE_FTS_PARTITION_SIZE << 20;
params
Expand All @@ -94,6 +98,14 @@ fn resolve_worker_memory_limit_bytes(params: &InvertedIndexParams, num_workers:
.unwrap_or(default_worker_memory_limit_bytes)
}

/// Total memory budget, in bytes, for the entire FTS build (all workers combined).
///
/// If `params.memory_limit_mb` is set, it is interpreted as a MiB count and
/// converted to bytes. Otherwise the budget defaults to
/// `LANCE_FTS_PARTITION_SIZE` MiB multiplied by `num_workers`.
///
/// Both conversions saturate at `u64::MAX` instead of overflowing: a plain
/// `<< 20` would silently drop high bits for very large limits, and the
/// per-worker multiplication could panic in debug builds or wrap in release.
fn resolve_total_memory_limit_bytes(params: &InvertedIndexParams, num_workers: usize) -> u64 {
    const BYTES_PER_MIB: u64 = 1 << 20;

    params
        .memory_limit_mb
        .map(|mb| mb.saturating_mul(BYTES_PER_MIB))
        .unwrap_or_else(|| {
            (*LANCE_FTS_PARTITION_SIZE)
                .saturating_mul(BYTES_PER_MIB)
                .saturating_mul(num_workers as u64)
        })
}

fn merge_all_tail_partitions(tails: Vec<TailPartition>) -> Result<Option<InnerBuilder>> {
if tails.is_empty() {
return Ok(None);
Expand Down Expand Up @@ -233,14 +245,15 @@ impl InvertedIndexBuilder {
let num_workers = resolve_num_workers(&self.params);
let tokenizer = self.params.build()?;
let with_position = self.params.with_position;
let worker_memory_limit_bytes =
resolve_worker_memory_limit_bytes(&self.params, num_workers);
let total_memory_bytes = resolve_total_memory_limit_bytes(&self.params, num_workers);
let memory_pool: Arc<dyn MemoryPool> =
Arc::new(FairSpillPool::new(total_memory_bytes.max(1) as usize));
let worker_config = IndexWorkerConfig {
with_position,
format_version: self.format_version,
fragment_mask: self.fragment_mask,
token_set_format: self.token_set_format,
worker_memory_limit_bytes,
memory_pool,
};
let next_id = self.partitions.iter().map(|id| id + 1).max().unwrap_or(0);
let id_alloc = Arc::new(AtomicU64::new(next_id));
Expand All @@ -255,6 +268,7 @@ impl InvertedIndexBuilder {
let id_alloc = id_alloc.clone();
let progress = self.progress.clone();
let tokenized_count = tokenized_count.clone();
let worker_config = worker_config.clone();
index_tasks.push(tokio::task::spawn(async move {
let mut worker =
IndexWorker::new(tokenizer, dest_store, id_alloc, worker_config).await?;
Expand Down Expand Up @@ -856,7 +870,7 @@ struct IndexWorker {
partitions: Vec<u64>,
schema: SchemaRef,
memory_size: u64,
worker_memory_limit_bytes: u64,
reservation: MemoryReservation,
total_doc_length: usize,
fragment_mask: Option<u64>,
token_set_format: TokenSetFormat,
Expand All @@ -873,13 +887,13 @@ struct WorkerOutput {
tail_partition: Option<TailPartition>,
}

#[derive(Debug, Clone, Copy)]
#[derive(Clone)]
struct IndexWorkerConfig {
with_position: bool,
format_version: InvertedListFormatVersion,
fragment_mask: Option<u64>,
token_set_format: TokenSetFormat,
worker_memory_limit_bytes: u64,
memory_pool: Arc<dyn MemoryPool>,
}

impl IndexWorker {
Expand Down Expand Up @@ -924,6 +938,7 @@ impl IndexWorker {
config: IndexWorkerConfig,
) -> Result<Self> {
let schema = inverted_list_schema_for_version(config.with_position, config.format_version);
let reservation = MemoryConsumer::new("FTSIndexWorker").register(&config.memory_pool);

Ok(Self {
tokenizer,
Expand All @@ -939,7 +954,7 @@ impl IndexWorker {
id_alloc,
schema,
memory_size: 0,
worker_memory_limit_bytes: config.worker_memory_limit_bytes,
reservation,
total_doc_length: 0,
fragment_mask: config.fragment_mask,
token_set_format: config.token_set_format,
Expand Down Expand Up @@ -1119,16 +1134,26 @@ impl IndexWorker {
self.temporary_memory_size(),
);

if self.builder.docs.len() == 1 && self.memory_size > self.worker_memory_limit_bytes {
return Err(Error::invalid_input(format!(
"single document row_id={} exceeds worker memory limit: {} > {} bytes",
row_id, self.memory_size, self.worker_memory_limit_bytes
)));
// Sync the memory pool reservation with the tracked memory usage.
// try_grow returning Err is the spill signal: flush the current partition.
let reserved = self.reservation.size() as u64;
if self.memory_size > reserved {
let delta = self.memory_size - reserved;
if self.reservation.try_grow(delta as usize).is_err() {
if builder_was_empty {
return Err(Error::invalid_input(format!(
"single document row_id={} exceeds memory budget: {} bytes",
row_id, self.memory_size
)));
}
self.flush().await?;
}
} else if self.memory_size < reserved {
self.reservation
.shrink((reserved - self.memory_size) as usize);
}

if self.builder.docs.len() as u32 == u32::MAX
|| (!builder_was_empty && self.memory_size >= self.worker_memory_limit_bytes)
{
if self.builder.docs.len() as u32 == u32::MAX {
self.flush().await?;
}
}
Expand All @@ -1147,6 +1172,13 @@ impl IndexWorker {
self.memory_size / (1024 * 1024)
);
self.memory_size = self.temporary_memory_size();
// Release all reserved memory, then re-reserve just for the remaining
// temporary buffers so the pool invariant stays consistent.
self.reservation.free();
if self.memory_size > 0 {
// Best-effort: the pool is shared with the other workers, so this can fail if
// another worker claims the freed bytes first. The result is ignored here.
let _ = self.reservation.try_grow(self.memory_size as usize);
}
let with_position = self.has_position();
let format_version = self.builder.format_version;
let builder = std::mem::replace(
Expand Down Expand Up @@ -1766,6 +1798,7 @@ mod tests {
use arrow_array::{RecordBatch, StringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use datafusion::execution::memory_pool::UnboundedMemoryPool;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::stream;
use lance_core::ROW_ID;
Expand Down Expand Up @@ -1932,7 +1965,7 @@ mod tests {
format_version: InvertedListFormatVersion::V1,
fragment_mask: None,
token_set_format,
worker_memory_limit_bytes: u64::MAX,
memory_pool: Arc::new(UnboundedMemoryPool::default()),
},
)
.await?;
Expand All @@ -1955,7 +1988,7 @@ mod tests {
format_version: InvertedListFormatVersion::V1,
fragment_mask: None,
token_set_format,
worker_memory_limit_bytes: u64::MAX,
memory_pool: Arc::new(UnboundedMemoryPool::default()),
},
)
.await?;
Expand Down Expand Up @@ -2356,6 +2389,59 @@ mod tests {
);
}

#[tokio::test]
async fn test_memory_pool_spills_on_tight_budget() -> Result<()> {
    // Drive a single worker against a deliberately small `FairSpillPool` so that
    // a failed `try_grow` forces a flush, and check that at least one completed
    // partition is written instead of everything ending up in the tail partition.
    //
    // Every document contributes 300 tokens that appear in no other document.
    // NOTE(review): the 72 KiB budget is sized on the original author's estimate
    // that the first document tracks roughly 69 KiB; that figure is not
    // re-derived here, so revisit the pool size if memory accounting changes.
    let index_dir = TempDir::default();
    let store = Arc::new(LanceIndexStore::new(
        ObjectStore::local().into(),
        index_dir.obj_path(),
        Arc::new(LanceCache::no_cache()),
    ));

    let tokenizer = InvertedIndexParams::default().num_workers(1).build()?;
    let config = IndexWorkerConfig {
        with_position: false,
        format_version: InvertedListFormatVersion::V1,
        fragment_mask: None,
        token_set_format: TokenSetFormat::default(),
        memory_pool: Arc::new(FairSpillPool::new(72 * 1024)),
    };
    let id_alloc = Arc::new(AtomicU64::new(0));
    let mut worker = IndexWorker::new(tokenizer, store.clone(), id_alloc, config).await?;

    // Feed the documents one at a time; each one is a single-row batch.
    let total_docs = 20u64;
    let unique_tokens = 300;
    for row_id in 0..total_docs {
        let mut words = Vec::new();
        for t in 0..unique_tokens {
            words.push(format!("tok_{row_id}_{t}"));
        }
        let text = words.join(" ");
        worker.process_batch(make_doc_batch(&text, row_id)).await?;
    }

    let output = worker.finish().await?;
    let completed = output.partitions.len();
    // Only reported in the failure message, to help diagnose where the docs went.
    let tail = usize::from(output.tail_partition.is_some());
    assert!(
        completed >= 1,
        "expected pool-triggered spill to produce at least one completed partition \
         (completed={completed}, tail={tail})"
    );
    Ok(())
}

#[tokio::test]
async fn test_worker_trims_position_temp_buffers() -> Result<()> {
let tokenizer = InvertedIndexParams::default().with_position(true).build()?;
Expand All @@ -2370,7 +2456,7 @@ mod tests {
format_version: InvertedListFormatVersion::V1,
fragment_mask: None,
token_set_format: TokenSetFormat::default(),
worker_memory_limit_bytes: u64::MAX,
memory_pool: Arc::new(UnboundedMemoryPool::default()),
},
)
.await?;
Expand Down Expand Up @@ -2401,7 +2487,7 @@ mod tests {
format_version: InvertedListFormatVersion::V1,
fragment_mask: None,
token_set_format: TokenSetFormat::default(),
worker_memory_limit_bytes: u64::MAX,
memory_pool: Arc::new(UnboundedMemoryPool::default()),
},
)
.await?;
Expand Down Expand Up @@ -2439,7 +2525,7 @@ mod tests {
format_version: InvertedListFormatVersion::V1,
fragment_mask: None,
token_set_format: TokenSetFormat::default(),
worker_memory_limit_bytes: u64::MAX,
memory_pool: Arc::new(UnboundedMemoryPool::default()),
},
)
.await?;
Expand Down
Loading