diff --git a/rust/README.md b/rust/README.md
index 34d434a..795a3c7 100644
--- a/rust/README.md
+++ b/rust/README.md
@@ -1,9 +1,19 @@
 # Rust extension for pymde
 
-This directory contains a Rust implementation of brute-force exact L2
-k-nearest neighbor search, exposed to Python via [PyO3](https://pyo3.rs).
-It is a self-contained, solution that uses platform-native BLAS
-(Accelerate on macOS, OpenBLAS on Linux) for accelerated matrix operations.
+This directory contains Rust implementations of core pymde algorithms,
+exposed to Python via [PyO3](https://pyo3.rs):
+
+- **Exact kNN** (`knn_l2`): brute-force L2 k-nearest neighbor search using
+  platform-native BLAS (Accelerate on macOS, OpenBLAS on Linux). Can be
+  extremely fast and competitive with approximate kNN algorithms on
+  machines with enough cores.
+- **Approximate kNN** (`nn_descent`): NN-Descent algorithm for building
+  approximate k-nearest neighbor graphs. Uses RP-tree initialization and
+  iterative local joins to converge on high-recall neighbor graphs. Originally
+  based on the pynndescent implementation. For machines with only a few cores,
+  or on very large datasets, this can be much faster than the exact alternative.
+- **BFS** (`breadth_first_directed`): breadth-first search on directed CSR
+  graphs.
 
 ## Prerequisites
 
@@ -33,7 +43,7 @@ pip install -e '.[dev]'
 ```
 
 This compiles the Rust code in release mode and places the resulting shared
-library (`_knn.*.so` / `_knn.*.pyd`) into `pymde/`.
+library (`_native.*.so` / `_native.*.pyd`) into `pymde/`.
 
 To rebuild after editing Rust code, run the same command again. Only changed
 files are recompiled.
@@ -46,48 +56,73 @@ files are recompiled.
 cd rust && cargo test
 ```
 
-This runs the native Rust tests (for `insert_topk`, `sgemm_nn_t`, and
-`knn_blas_tiled`) without needing Python. No extra setup beyond the Rust
-toolchain is required.
+This runs the native Rust tests without needing Python. No extra setup
+beyond the Rust toolchain is required.
 
 ### Python integration tests
 
 ```sh
-pytest pymde/test_knn.py -v
+pytest pymde/test_knn.py -v                    # exact kNN
+pytest pymde/preprocess/test_nndescent.py -v   # approximate kNN
 ```
 
 ## Project layout
 
 ```
 rust/
-├── Cargo.toml   # Package metadata and dependencies
-├── Cargo.lock   # Pinned dependency versions (committed for reproducibility)
+├── Cargo.toml          # Package metadata and dependencies
+├── Cargo.lock          # Pinned dependency versions (committed for reproducibility)
 └── src/
-    └── lib.rs   # All Rust source code (single file)
+    ├── lib.rs          # PyO3 module definition and exports
+    ├── knn.rs          # Exact kNN (BLAS-accelerated brute force)
+    ├── blas.rs         # BLAS FFI bindings (sgemm)
+    ├── nndescent.rs    # NN-Descent approximate kNN algorithm
+    ├── heap.rs         # Thread-safe neighbor heaps with AtomicBool try-locks
+    ├── candidates.rs   # Candidate tracking for NN-Descent iterations
+    ├── distance.rs     # L2 distance kernels (with NEON intrinsics on aarch64)
+    ├── rng.rs          # Fast deterministic PRNG (SplitMix64)
+    └── bfs.rs          # Breadth-first search on directed CSR graphs
 ```
 
 ## How it works
 
-The module exposes one Python function: `pymde._knn.knn_l2(data, k)`.
+### Exact kNN (`knn_l2`)
 
-The algorithm:
+`pymde._native.knn_l2(data, k)` — brute-force exact search.
 
 1. Precompute squared norms `||x_i||^2` for every row.
 2. Tile the data matrix into query blocks and database blocks.
-3. For each tile pair, compute pairwise inner products using BLAS `sgemm`
-   (the fastest way to do dense matrix multiply).
+3. For each tile pair, compute pairwise inner products using BLAS `sgemm`.
 4. Recover squared distances via `||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a · b`.
 5. Maintain a sorted top-k list per query row, keeping only the closest neighbors.
 
 Query tiles are processed in parallel using [rayon](https://docs.rs/rayon).
 
+### Approximate kNN (`nn_descent`)
+
+`pymde._native.nn_descent(data, n_neighbors)` — approximate search via
+NN-Descent, much faster than exact search for large datasets.
+
+1. **RP-Tree Init**: Build random projection trees to get an initial neighbor
+   graph. Points in the same leaf node become candidate neighbors.
+2. **NN-Descent Loop**: Iteratively refine the graph using local joins — for
+   each point, compare its neighbors' neighbors as potential new neighbors.
+   Repeat until convergence (few updates per iteration).
+3. **Finalize**: Sort heaps, apply sqrt to distances, return
+   `(neighbors, distances)`.
+
+Thread safety uses per-point `AtomicBool` try-locks for concurrent heap
+updates, skipping on contention rather than blocking.
+
 ## Key dependencies
 
 | Crate | Purpose |
 |-------|---------|
 | [pyo3](https://pyo3.rs) | Rust ↔ Python bindings (function signatures, type conversions, GIL management) |
 | [numpy](https://docs.rs/numpy) | Zero-copy access to NumPy arrays from Rust |
-| [rayon](https://docs.rs/rayon) | Data-parallel iteration (parallelizes across query tiles) |
+| [rayon](https://docs.rs/rayon) | Data-parallel iteration (parallelizes across query tiles and NN-Descent joins) |
+| [rand](https://docs.rs/rand) | Random number generation (RP-tree construction) |
+| [rand_chacha](https://docs.rs/rand_chacha) | Deterministic seeded RNG for reproducibility |
 
 BLAS is linked directly via `extern "C"` — no Rust BLAS crate is used.
 
@@ -104,11 +139,11 @@ BLAS is linked directly via `extern "C"` — no Rust BLAS crate is used.
    }
    ```
 
-2. Export it from the module at the bottom of `lib.rs`:
+2. Export it from the module in `lib.rs`:
 
    ```rust
    #[pymodule]
-   mod _knn {
+   mod _native {
        #[pymodule_export]
        use super::my_function;
    }