diff --git a/java/lance-jni/src/index.rs b/java/lance-jni/src/index.rs
index 1e533eed9fc..6cb64a05a81 100644
--- a/java/lance-jni/src/index.rs
+++ b/java/lance-jni/src/index.rs
@@ -173,6 +173,8 @@ fn determine_index_type<'local>(
             Some("ZONEMAP")
         } else if lower.contains("bloomfilter") {
             Some("BLOOM_FILTER")
+        } else if lower.contains("rtree") { // substring test, like the neighbouring branches
+            Some("RTREE") // same spelling as the Java IndexType.RTREE constant
         } else if lower.contains("ivfhnsw") {
             if lower.contains("sq") {
                 Some("IVF_HNSW_SQ")
diff --git a/java/src/main/java/org/lance/index/IndexType.java b/java/src/main/java/org/lance/index/IndexType.java
index 3a03934effd..1fff86fc7e0 100644
--- a/java/src/main/java/org/lance/index/IndexType.java
+++ b/java/src/main/java/org/lance/index/IndexType.java
@@ -24,6 +24,7 @@ public enum IndexType {
   MEM_WAL(7),
   ZONEMAP(8),
   BLOOM_FILTER(9),
+  RTREE(10), // next id after BLOOM_FILTER(9); the VECTOR/IVF_* ids start at 100
   VECTOR(100),
   IVF_FLAT(101),
   IVF_SQ(102),
diff --git a/java/src/main/java/org/lance/index/scalar/ScalarIndexParams.java b/java/src/main/java/org/lance/index/scalar/ScalarIndexParams.java
index 345a55f20b2..b3408e2d68d 100644
--- a/java/src/main/java/org/lance/index/scalar/ScalarIndexParams.java
+++ b/java/src/main/java/org/lance/index/scalar/ScalarIndexParams.java
@@ -31,7 +31,7 @@ private ScalarIndexParams(Builder builder) {
    * Create a new ScalarIndexParams with the given index type and no parameters.
    *
    * @param indexType the index type (e.g., "btree", "zonemap", "bitmap", "inverted", "labellist",
-   *     "ngram")
+   *     "ngram", "rtree")
    * @return ScalarIndexParams
    */
   public static ScalarIndexParams create(String indexType) {
@@ -42,7 +42,7 @@ public static ScalarIndexParams create(String indexType) {
    * Create a new ScalarIndexParams with the given index type and JSON parameters.
    *
    * @param indexType the index type (e.g., "btree", "zonemap", "bitmap", "inverted", "labellist",
-   *     "ngram")
+   *     "ngram", "rtree")
    * @param jsonParams JSON string containing index-specific parameters
    * @return ScalarIndexParams
    */
@@ -58,7 +58,7 @@ public static class Builder {
      * Create a new builder for scalar index parameters.
      *
      * @param indexType the index type (e.g., "btree", "zonemap", "bitmap", "inverted", "labellist",
-     *     "ngram")
+     *     "ngram", "rtree")
      */
     public Builder(String indexType) {
       this.indexType = indexType;
diff --git a/java/src/test/java/org/lance/index/ScalarIndexTest.java b/java/src/test/java/org/lance/index/ScalarIndexTest.java
index b993a7e8a5f..cb090e7c955 100644
--- a/java/src/test/java/org/lance/index/ScalarIndexTest.java
+++ b/java/src/test/java/org/lance/index/ScalarIndexTest.java
@@ -25,14 +25,18 @@
 import org.apache.arrow.c.Data;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.UInt8Vector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -318,4 +322,78 @@ public void testCreateZonemapIndex(@TempDir Path tempDir) throws Exception {
       }
     }
   }
+
+  @Test // Writes three geoarrow.point rows, reads them back, then builds an RTREE index.
+  public void testCreateRTreeIndex(@TempDir Path tempDir) throws Exception {
+    String datasetPath = tempDir.resolve("rtree_test").toString();
+    ArrowType f64 = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+    Field geometryField = // struct of two doubles (x, y) tagged with the geoarrow.point extension
+        new Field(
+            "geometry",
+            new FieldType(
+                true, // nullable
+                new ArrowType.Struct(),
+                null, // no dictionary encoding
+                Collections.singletonMap("ARROW:extension:name", "geoarrow.point")),
+            Arrays.asList(Field.notNullable("x", f64), Field.notNullable("y", f64)));
+    Schema schema = new Schema(Collections.singletonList(geometryField), null);
+
+    int rowCount = 3; // row i holds the point (i, 2 * i)
+    try (RootAllocator allocator = new RootAllocator();
+        VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      root.allocateNew();
+      StructVector geometry = (StructVector) root.getVector("geometry");
+      Float8Vector x = (Float8Vector) geometry.getChild("x");
+      Float8Vector y = (Float8Vector) geometry.getChild("y");
+      for (int i = 0; i < rowCount; i++) {
+        geometry.setIndexDefined(i); // mark struct slot i as non-null
+        x.setSafe(i, (double) i);
+        y.setSafe(i, i * 2.0);
+      }
+      geometry.setValueCount(rowCount);
+      root.setRowCount(rowCount);
+
+      ByteArrayOutputStream out = new ByteArrayOutputStream(); // receives the Arrow stream
+      try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) {
+        writer.start();
+        writer.writeBatch(); // serializes the rows currently held by root
+        writer.end();
+      }
+
+      try (ArrowStreamReader reader = // replays the serialized batch as the write source
+              new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray()), allocator);
+          Dataset dataset =
+              Dataset.write()
+                  .reader(reader)
+                  .uri(datasetPath)
+                  .allocator(allocator)
+                  .mode(WriteParams.WriteMode.CREATE)
+                  .execute()) {
+        // The point data round-trips through Lance: row 2 must read back as (2.0, 4.0).
+        assertEquals(rowCount, dataset.countRows());
+        try (ArrowReader scan = dataset.newScan(new ScanOptions.Builder().build()).scanBatches()) {
+          assertTrue(scan.loadNextBatch()); // the get(2) calls assume it holds all 3 rows
+          StructVector readGeometry =
+              (StructVector) scan.getVectorSchemaRoot().getVector("geometry");
+          assertEquals(2.0, ((Float8Vector) readGeometry.getChild("x")).get(2));
+          assertEquals(4.0, ((Float8Vector) readGeometry.getChild("y")).get(2));
+        }
+
+        // Creating and listing an RTree index via the typed IndexType works end-to-end.
+        Index index =
+            dataset.createIndex(
+                Collections.singletonList("geometry"),
+                IndexType.RTREE,
+                Optional.of("rtree_geometry_index"), // index name asserted on below
+                IndexParams.builder()
+                    .setScalarIndexParams(ScalarIndexParams.create("rtree"))
+                    .build(),
+                true); // presumably the "replace" flag - TODO confirm against Dataset.createIndex
+        assertEquals(IndexType.RTREE, index.indexType());
+        assertTrue(
+            dataset.listIndexes().contains("rtree_geometry_index"),
+            "Expected 'rtree_geometry_index' in: " + dataset.listIndexes());
+      }
+    }
+  }
 }