diff --git a/java/lance-jni/src/merge_insert.rs b/java/lance-jni/src/merge_insert.rs index df4d63bd2f6..4d332ca2c09 100644 --- a/java/lance-jni/src/merge_insert.rs +++ b/java/lance-jni/src/merge_insert.rs @@ -12,7 +12,8 @@ use jni::objects::{JObject, JString, JValueGen}; use jni::sys::jlong; use lance::dataset::scanner::ExprFilter; use lance::dataset::{ - MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, + MergeInsertBuilder, MergeStats, SourceDedupeBehavior, WhenMatched, WhenNotMatched, + WhenNotMatchedBySource, }; use lance_core::datatypes::Schema; use lance_index::mem_wal::MergedGeneration; @@ -52,6 +53,7 @@ fn inner_merge_insert<'local>( let retry_timeout_ms = extract_retry_timeout_ms(env, &jparam)?; let skip_auto_cleanup = extract_skip_auto_cleanup(env, &jparam)?; let use_index = extract_use_index(env, &jparam)?; + let source_dedupe_behavior = extract_source_dedupe_behavior(env, &jparam)?; let marked_generations = extract_marked_generations(env, &jparam)?; let (new_ds, merge_stats) = unsafe { @@ -71,6 +73,7 @@ fn inner_merge_insert<'local>( .retry_timeout(Duration::from_millis(retry_timeout_ms as u64)) .skip_auto_cleanup(skip_auto_cleanup) .use_index(use_index) + .source_dedupe_behavior(source_dedupe_behavior) .mark_generations_as_merged(marked_generations) .try_build()?; @@ -241,6 +244,29 @@ fn extract_use_index<'local>(env: &mut JNIEnv<'local>, jparam: &JObject) -> Resu Ok(use_index) } +fn extract_source_dedupe_behavior<'local>( + env: &mut JNIEnv<'local>, + jparam: &JObject, +) -> Result { + let behavior: JString = env + .call_method( + jparam, + "sourceDedupeBehaviorValue", + "()Ljava/lang/String;", + &[], + )? + .l()? + .into(); + let behavior = behavior.extract(env)?; + match behavior.as_str() { + "Fail" => Ok(SourceDedupeBehavior::Fail), + "FirstSeen" => Ok(SourceDedupeBehavior::FirstSeen), + _ => Err(Error::input_error(format!( + "Illegal source_dedupe_behavior: {behavior}", + ))), + } +} + fn extract_marked_generations<'local>( env: &mut JNIEnv<'local>, jparam: &JObject, diff --git a/java/src/main/java/org/lance/merge/MergeInsertParams.java b/java/src/main/java/org/lance/merge/MergeInsertParams.java index 2ae27b67cba..ab67e022604 100644 --- a/java/src/main/java/org/lance/merge/MergeInsertParams.java +++ b/java/src/main/java/org/lance/merge/MergeInsertParams.java @@ -39,6 +39,7 @@ public class MergeInsertParams { private long retryTimeoutMs = 30 * 1000; private boolean skipAutoCleanup = false; private boolean useIndex = true; + private SourceDedupeBehavior sourceDedupeBehavior = SourceDedupeBehavior.Fail; private List markedGenerations = Collections.emptyList(); public MergeInsertParams(List on) { @@ -244,6 +245,25 @@ public MergeInsertParams withUseIndex(boolean useIndex) { return this; } + /** + * Control how duplicate source rows that match the same target row are handled. + * + *

Default is {@link SourceDedupeBehavior#Fail}, which errors if the source contains duplicate + * join keys. Use {@link SourceDedupeBehavior#FirstSeen} to keep the first encountered row and + * skip subsequent duplicates. + * + *

If the source contains duplicates and {@code FirstSeen} behavior doesn't match your needs, + * sort the source data before passing it to the merge insert operation. + * + * @param sourceDedupeBehavior The behavior to apply when duplicate source rows are found + * @return This MergeInsertParams instance + */ + public MergeInsertParams withSourceDedupeBehavior(SourceDedupeBehavior sourceDedupeBehavior) { + Preconditions.checkNotNull(sourceDedupeBehavior); + this.sourceDedupeBehavior = sourceDedupeBehavior; + return this; + } + /** * Mark MemWAL generations as merged into the base table. * @@ -319,6 +339,14 @@ public boolean useIndex() { return useIndex; } + public SourceDedupeBehavior sourceDedupeBehavior() { + return sourceDedupeBehavior; + } + + public String sourceDedupeBehaviorValue() { + return sourceDedupeBehavior.name(); + } + @Override public String toString() { return MoreObjects.toStringHelper(this) @@ -337,6 +365,7 @@ public String toString() { .add("retryTimeoutMs", retryTimeoutMs) .add("skipAutoCleanup", skipAutoCleanup) .add("useIndex", useIndex) + .add("sourceDedupeBehavior", sourceDedupeBehavior) .toString(); } @@ -396,4 +425,18 @@ public enum WhenNotMatchedBySource { */ DeleteIf, } + + /** + * Describes how to handle duplicate source rows that match the same target row. + * + *

If the source contains duplicates and {@code FirstSeen} behavior doesn't match your needs, + * sort the source data before passing it to the merge insert operation. + */ + public enum SourceDedupeBehavior { + /** Fail the operation if duplicates are found (default). */ + Fail, + + /** Keep the first seen value and skip subsequent duplicates. */ + FirstSeen, + } } diff --git a/java/src/test/java/org/lance/MergeInsertTest.java b/java/src/test/java/org/lance/MergeInsertTest.java index b738ef8852d..540cb0fec2a 100644 --- a/java/src/test/java/org/lance/MergeInsertTest.java +++ b/java/src/test/java/org/lance/MergeInsertTest.java @@ -257,6 +257,30 @@ private VectorSchemaRoot buildSource(Schema schema, RootAllocator allocator) { return root; } + /** + * Build a source whose join key {@code id=0} appears twice ("First 0", then "Second 0"), so the + * source-dedupe behavior is exercised. Remaining ids (3, 4) are unique matches. + */ + private VectorSchemaRoot buildDuplicateKeySource(Schema schema, RootAllocator allocator) { + List sourceIds = Arrays.asList(0, 0, 3, 4); + List sourceNames = Arrays.asList("First 0", "Second 0", "Source 3", "Source 4"); + + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + + IntVector idVector = (IntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + + for (int i = 0; i < sourceIds.size(); i++) { + idVector.setSafe(i, sourceIds.get(i)); + nameVector.setSafe(i, sourceNames.get(i).getBytes(StandardCharsets.UTF_8)); + } + + root.setRowCount(sourceIds.size()); + + return root; + } + private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator allocator) throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -275,6 +299,61 @@ private ArrowArrayStream convertToStream(VectorSchemaRoot root, RootAllocator al return stream; } + @Test + public void testSourceDedupeFirstSeenKeepsFirst() throws Exception { + // Source has two rows for id=0 ("First 0" then "Second 0"). FirstSeen keeps + // the first encountered row and skips the duplicate. + + try (VectorSchemaRoot source = buildDuplicateKeySource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + MergeInsertResult result = + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsertParams.WhenNotMatched.InsertAll) + .withSourceDedupeBehavior(MergeInsertParams.SourceDedupeBehavior.FirstSeen), + sourceStream); + + Assertions.assertEquals( + "{0=First 0, 1=Person 1, 2=Person 2, 3=Source 3, 4=Source 4}", + readAll(result.dataset()).toString(), + "FirstSeen should keep the first duplicate source row (id=0) and update unique matches"); + } + } + } + + @Test + public void testSourceDedupeFailWithDuplicates() throws Exception { + // Default behavior (Fail) must error when the source contains duplicate join keys. + + try (VectorSchemaRoot source = buildDuplicateKeySource(testDataset.getSchema(), allocator)) { + try (ArrowArrayStream sourceStream = convertToStream(source, allocator)) { + String originalDataset = readAll(dataset).toString(); + + Exception ex = + Assertions.assertThrows( + Exception.class, + () -> + dataset.mergeInsert( + new MergeInsertParams(Collections.singletonList("id")) + .withMatchedUpdateAll() + .withNotMatched(MergeInsertParams.WhenNotMatched.InsertAll) + .withSourceDedupeBehavior(MergeInsertParams.SourceDedupeBehavior.Fail), + sourceStream)); + + Assertions.assertNotNull(ex.getMessage(), "exception should carry a message"); + Assertions.assertTrue( + ex.getMessage().contains("Ambiguous merge inserts are prohibited"), + "Fail should report the ambiguous-merge cause, got: " + ex.getMessage()); + + Assertions.assertEquals( + originalDataset, + readAll(dataset).toString(), + "Dataset should remain unchanged after a failed mergeInsert"); + } + } + } + @Test public void testMergeInsertWithoutIndex() throws Exception { // Verify that merge insert with useIndex=false still completes and diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 3e0d77704da..2894ef4f8da 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -129,8 +129,8 @@ pub use schema_evolution::{ pub use take::TakeBuilder; use uuid::Uuid; pub use write::merge_insert::{ - MergeInsertBuilder, MergeInsertJob, MergeStats, UncommittedMergeInsert, WhenMatched, - WhenNotMatched, WhenNotMatchedBySource, + MergeInsertBuilder, MergeInsertJob, MergeStats, SourceDedupeBehavior, UncommittedMergeInsert, + WhenMatched, WhenNotMatched, WhenNotMatchedBySource, }; use crate::dataset::index::LanceIndexStoreExt;