Add DiskAnnPy Python wheel (diskannpy-rust) and PyPI publish workflow #872

YuanyuanTian-hh wants to merge 10 commits into main from
Conversation
Pull request overview
This PR introduces a new DiskAnnPy Python package (Rust pyo3 extension + Python wrapper) into the DiskANN workspace and adds CI automation to build/test/publish wheels.
Changes:
- Adds a new Rust DiskAnnPy crate exposing DiskANN APIs to Python via pyo3 (static disk index, async memory index, BfTree index, quantizers, shared utils).
- Adds a Python wrapper package (DiskAnnPy/python/diskannpy) with builders, index classes, file utilities, and type hints.
- Adds Python unit tests + fixtures and a GitHub Actions workflow to build/test/publish wheels to PyPI.
Reviewed changes
Copilot reviewed 48 out of 58 changed files in this pull request and generated 11 comments.
Show a summary per file
| File | Description |
|---|---|
| DiskAnnPy/tests/test_static_disk_index.py | Adds StaticDiskIndex unit tests (recall, validation, relative paths). |
| DiskAnnPy/tests/test_quantization.py | Adds MinMax/Product quantizer workflow and accuracy tests. |
| DiskAnnPy/tests/test_builder.py | Adds build_disk_index input validation tests. |
| DiskAnnPy/tests/test_bftree_index.py | Adds BfTreeIndex tests including on-disk save/load and PQ paths. |
| DiskAnnPy/tests/test_async_index.py | Adds AsyncDiskIndex tests, including PQ and delete/insert flows. |
| DiskAnnPy/tests/fixtures/recall.py | Provides recall computation + PQ recall cutoff. |
| DiskAnnPy/tests/fixtures/create_test_data.py | Utilities to generate/write vector test data. |
| DiskAnnPy/tests/fixtures/build_disk_index.py | Fixture to build a StaticDiskIndex test dataset/index. |
| DiskAnnPy/tests/fixtures/build_bftree_index.py | Fixtures to build in-memory/on-disk BfTree indexes (PQ and non-PQ). |
| DiskAnnPy/tests/fixtures/build_async_index.py | Fixture to build async memory indexes (PQ and non-PQ). |
| DiskAnnPy/tests/fixtures/__init__.py | Re-exports fixtures for tests. |
| DiskAnnPy/tests/data/query_preexisting_vector.bin | Adds binary test data. |
| DiskAnnPy/tests/data/query_insert_vector.bin | Adds binary test data. |
| DiskAnnPy/tests/data/ann_metadata.bin | Adds binary test data. |
| DiskAnnPy/src/utils/search_result.rs | Adds Rust structs for Python-facing search results + stats helpers. |
| DiskAnnPy/src/utils/parallel_tasks.rs | Adds async “worker pool over iterator” helper. |
| DiskAnnPy/src/utils/mod.rs | Wires new utils modules into the Rust crate. |
| DiskAnnPy/src/utils/metric_py.rs | Adds Python-exposed Metric enum and parsing. |
| DiskAnnPy/src/utils/index_build_utils.rs | Adds runtime init + common error helpers. |
| DiskAnnPy/src/utils/graph_data_types.rs | Defines GraphDataType bindings for f32/u8/i8. |
| DiskAnnPy/src/utils/dataset_utils.rs | Adds dataset loading/alignment utilities + recall/truthset helpers. |
| DiskAnnPy/src/utils/data_type.rs | Adds Python-exposed DataType enum and parsing. |
| DiskAnnPy/src/utils/convert_py_array.rs | Adds helper to convert PyArray2 to Vec<Vec> (+ unit test). |
| DiskAnnPy/src/utils/ann_result_py.rs | Adds Python exception wrapper type for ANNError. |
| DiskAnnPy/src/static_disk_index.rs | Adds Rust implementation + Python bindings for StaticDiskIndex. |
| DiskAnnPy/src/quantization/README.md | Adds user documentation for quantizers. |
| DiskAnnPy/src/quantization/product.rs | Adds ProductQuantizer Python bindings. |
| DiskAnnPy/src/quantization/mod.rs | Exposes quantization modules/types. |
| DiskAnnPy/src/quantization/minmax.rs | Adds MinMaxQuantizer Python bindings. |
| DiskAnnPy/src/quantization/metric.rs | Adds metric parsing/translation for quantization. |
| DiskAnnPy/src/quantization/base.rs | Adds base quantizer class for shared metadata. |
| DiskAnnPy/src/lib.rs | Defines the _diskannpy Python module and registers exported classes/functions. |
| DiskAnnPy/src/build_disk_index.rs | Adds Rust entrypoint for building disk indexes via Python. |
| DiskAnnPy/src/build_async_memory_index.rs | Adds Rust entrypoint for building async memory indexes via Python. |
| DiskAnnPy/README.md | Adds DiskAnnPy build/test instructions (maturin-based). |
| DiskAnnPy/python/diskannpy/_static_disk_index.py | Adds Python wrapper for StaticDiskIndex. |
| DiskAnnPy/python/diskannpy/_files.py | Adds vector/tag file I/O helpers. |
| DiskAnnPy/python/diskannpy/_defaults.py | Adds Python-side defaults/constants. |
| DiskAnnPy/python/diskannpy/_common.py | Adds shared validation/helpers + metadata read/write. |
| DiskAnnPy/python/diskannpy/_builder.py | Adds Python builder functions for disk + async indexes. |
| DiskAnnPy/python/diskannpy/_bftree_diskann.py | Adds Python wrapper for BfTreeIndex (incl. save/load). |
| DiskAnnPy/python/diskannpy/_async_diskann.py | Adds Python wrapper for AsyncDiskIndex. |
| DiskAnnPy/python/diskannpy/__init__.py | Adds public API surface + type aliases for the package. |
| DiskAnnPy/pyproject.toml | Adds packaging metadata and dependencies for the Python project. |
| DiskAnnPy/py.typed | Marks package as typed. |
| DiskAnnPy/Cargo.toml | Adds Rust crate config/dependencies for the Python extension. |
| DiskAnnPy/.gitignore | Adds local ignores for the new Python/Rust build outputs. |
| Cargo.toml | Adds DiskAnnPy to the workspace members. |
| Cargo.lock | Updates lockfile for the new crate/dependencies. |
| .github/workflows/publish-diskannpy.yml | Adds wheel build/test/publish workflow using maturin + trusted publishing. |
```python
    if dtype == np.uint8:
        return np.uint8
    if dtype == np.int8:
        return np.int8
    if dtype == np.float32:
        return np.float32


def valid_dap_dtype(dtype: Type) -> dap.DataType:
    _assert_dtype(dtype)
    if dtype == np.uint8:
        return dap.DataType.Uint8
    if dtype == np.int8:
        return dap.DataType.Int8
    if dtype == np.float32:
        return dap.DataType.Float
```
valid_dtype() allows dtypes that are castable to the supported set (e.g., np.single, np.byte, np.ubyte) via _assert_dtype, but then only returns for exact np.float32/np.int8/np.uint8. For castable aliases it falls off the end and returns None, which will later break index building/loading. Map aliases to canonical dtypes (e.g., treat np.single as np.float32, np.byte as np.int8, np.ubyte as np.uint8) and ensure the function always returns a dtype after _assert_dtype passes.
Suggested change (canonicalize before the equality checks, replacing the exact-match code above):

```python
    canonical_dtype = np.dtype(dtype).type
    if canonical_dtype == np.uint8:
        return np.uint8
    if canonical_dtype == np.int8:
        return np.int8
    if canonical_dtype == np.float32:
        return np.float32
    # This should be unreachable because _assert_dtype ensures castability to _VALID_DTYPES
    raise ValueError(f"Unsupported dtype after canonicalization: {dtype!r}")


def valid_dap_dtype(dtype: Type) -> dap.DataType:
    _assert_dtype(dtype)
    canonical_dtype = np.dtype(dtype).type
    if canonical_dtype == np.uint8:
        return dap.DataType.Uint8
    if canonical_dtype == np.int8:
        return dap.DataType.Int8
    if canonical_dtype == np.float32:
        return dap.DataType.Float
    # This should be unreachable because _assert_dtype ensures castability to _VALID_DTYPES
    raise ValueError(f"Unsupported dtype after canonicalization: {dtype!r}")
```
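As a quick sanity check on the canonicalization step the suggestion relies on, `np.dtype(...).type` resolves numpy's scalar aliases to the exact types the equality checks expect:

```python
import numpy as np

# Aliases such as np.single/np.byte/np.ubyte resolve to the canonical
# scalar types, so the equality checks match after canonicalization.
for alias, canonical in [(np.single, np.float32),
                         (np.byte, np.int8),
                         (np.ubyte, np.uint8)]:
    assert np.dtype(alias).type is canonical

# String and np.dtype inputs canonicalize the same way:
assert np.dtype("float32").type is np.float32
assert np.dtype(np.dtype("int8")).type is np.int8
```

This is why the suggested code compares `np.dtype(dtype).type` rather than the raw `dtype` argument.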
```python
__all__ = [
    "build_disk_index",
    "StaticDiskIndex",
    "build_async_index",
    "AsyncDiskIndex",
    "BfTreeIndex",
    "defaults",
    "DistanceMetric",
    "VectorDType",
    "QueryResponse",
    "QueryResponseBatch",
    "VectorIdentifier",
    "VectorIdentifierBatch",
    "VectorLike",
    "VectorLikeBatch",
    "Metadata",
    "vectors_metadata_from_file",
    "vectors_to_file",
    "vectors_from_file",
    "tags_to_file",
    "tags_from_file",
    "valid_dtype",
]
```
__all__ exports "QueryResponseBatch", but no such symbol is defined in this module (the type is named QueryResponseBatchWithStats). This makes from diskannpy import QueryResponseBatch fail and also contradicts the docstring list above. Either rename/alias QueryResponseBatchWithStats to QueryResponseBatch, or update __all__ and docs to the correct name.
```python
warnings.warn(
    f"k_neighbors={k_value} asked for, but list_size={l_value} was smaller. Increasing {l_value} to {k_value}"
)
complexity = k_value
```
When k_value > l_value, the code warns that it is increasing l_value but assigns to a new variable complexity that is never used. As a result, the call to self._index.search(... l_value=l_value ...) still uses the too-small l_value. Update this to set l_value = k_value (or otherwise ensure the native call receives the adjusted value).
Suggested change:

```python
l_value = k_value
```
```rust
zipped.for_each_in_pool(
    &pool,
    |(((query, query_result_ids), distance_results), query_stats)| {
        let search_result = self
            .search_internal(query, recall_at, l_value, query_stats)
            .unwrap();
        *query_result_ids = search_result.ids;
        *distance_results = search_result.distances;
    },
);
```
batch_search calls search_internal(...).unwrap() inside the rayon loop. Any search error will panic, which will abort the Python process instead of returning a Python exception. Propagate errors out of the parallel loop (e.g., collect Results and return the first error, or use a shared error slot and short-circuit).
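The same pattern translates to the Python side of the codebase: submit the per-query work to a pool, let worker exceptions stay attached to their futures, and surface the first failure to the caller instead of crashing inside a worker. A minimal sketch with illustrative names (`search_one` stands in for the native search call):

```python
from concurrent.futures import ThreadPoolExecutor

def search_one(query: int) -> int:
    # Stand-in for the native search; raises instead of unwrapping/panicking.
    if query < 0:
        raise ValueError(f"bad query: {query}")
    return query * 2

def batch_search(queries):
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(search_one, q) for q in queries]
    # result() re-raises the worker's exception in the caller's thread,
    # so errors become ordinary exceptions instead of aborting the process.
    return [f.result() for f in futures]

assert batch_search([1, 2, 3]) == [2, 4, 6]
```

In Rust the analogue is collecting `Result`s (or writing into a shared error slot) inside the rayon closure and returning the first error after the pool finishes.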
```rust
pub async fn run<I, T, F, Fut>(iterator: I, num_tasks: usize, task_fn: F)
where
    I: Iterator<Item = T> + Send + 'static,
    T: Send + 'static,
    F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
    Fut: Future + Send + 'static,
{
    let mut tasks = JoinSet::new();
    let iterator = Arc::new(Mutex::new(iterator));

    for _ in 0..num_tasks {
        let iterator_clone = iterator.clone();
        let task_fn = task_fn.clone();

        tasks.spawn(async move {
            loop {
                let item = {
                    let mut guard = iterator_clone.lock().await;
                    guard.next()
                };

                match item {
                    Some(item) => {
                        let _ = task_fn(item).await;
                    }
                    None => break,
                }
            }
        });
```
The helper drops the output of task_fn(item).await (let _ = ...). Callers in this repo pass async blocks that return ANNResult, so failures are silently ignored and operations may report success despite partial failure. Consider constraining Fut::Output to a Result and returning an aggregated error (or at least capturing/logging failures) instead of discarding the result.
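The "gather outcomes, then aggregate" shape suggested here can be sketched in Python with `asyncio.gather(return_exceptions=True)`, which collects every task's result or exception rather than silently discarding failures (names are illustrative):

```python
import asyncio

async def task_fn(item: str) -> str:
    # Stand-in for the per-item async work; fails on one input.
    if item == "bad":
        raise RuntimeError(f"task failed on {item!r}")
    return item.upper()

async def run_pool(items):
    # Collect every task's outcome instead of dropping it...
    results = await asyncio.gather(
        *(task_fn(i) for i in items), return_exceptions=True
    )
    # ...then surface an error if any task failed, so callers never see
    # a spurious success after a partial failure.
    errors = [r for r in results if isinstance(r, Exception)]
    if errors:
        raise errors[0]
    return results

assert asyncio.run(run_pool(["a", "b"])) == ["A", "B"]
```

The Rust equivalent is constraining `Fut::Output` to a `Result` and joining the `JoinSet` outputs into an aggregated `ANNResult`.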
DiskAnnPy/src/build_disk_index.rs (Outdated)
```rust
) -> ANNResultPy<()> {
    println!(
        "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {} num_of_pq_chunks: {} build_DRAM_budget: {}",
        graph_degree, complexity, ALPHA, num_threads, num_of_pq_chunks, build_dram_budget
```
The log line prints ALPHA (the compile-time default) instead of the alpha parameter passed into build_disk_index. This can be misleading when callers override alpha; print the actual argument value.
Suggested change:

```rust
graph_degree, complexity, alpha, num_threads, num_of_pq_chunks, build_dram_budget
```
```python
@classmethod
def setUpClass(cls) -> None:
    cls._test_matrix = [
        build_random_vectors_and_async_index(np.float32, "l2"),
        build_random_vectors_and_async_index(np.uint8, "l2"),
        build_random_vectors_and_async_index(np.int8, "l2"),
        build_random_vectors_and_async_index(np.float32, "cosine"),
    ]
    cls._test_matrix_with_pq = [
        build_random_vectors_and_async_index(np.float32, "l2", use_pq=True, num_pq_bytes=5, use_opq=False),
        build_random_vectors_and_async_index(np.uint8, "l2", use_pq=True, num_pq_bytes=5, use_opq=False),
        build_random_vectors_and_async_index(np.int8, "l2", use_pq=True, num_pq_bytes=5, use_opq=False),
        build_random_vectors_and_async_index(np.int8, "cosine", use_pq=True, num_pq_bytes=5, use_opq=False),
    ]
```
These tests build multiple full indices in setUpClass using 10,000-point datasets (and test_async_index builds 8 of them). This is likely to make CI runs for wheel validation very slow/flaky. Consider reducing vector counts for unit tests, reusing a single built index across classes, and/or gating the large benchmarks behind an env var (e.g., only run the large-build matrix when explicitly requested).
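One way to gate the expensive builds, sketched with an assumed environment-variable name (`DISKANNPY_LARGE_TESTS` is illustrative, not part of the PR):

```python
import os
import unittest

# Illustrative opt-in flag; the real variable name would be a project choice.
RUN_LARGE = os.environ.get("DISKANNPY_LARGE_TESTS") == "1"

@unittest.skipUnless(RUN_LARGE, "set DISKANNPY_LARGE_TESTS=1 to run large index builds")
class TestAsyncIndexLarge(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # The expensive 10,000-point index builds would go here; with the
        # class-level skip they never run unless explicitly requested.
        cls._test_matrix = []

    def test_placeholder(self) -> None:
        self.assertIsInstance(self._test_matrix, list)
```

A class-level `skipUnless` also skips `setUpClass`, so CI wheel validation pays nothing for the large matrix by default.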
```python
from typing import Optional

import numpy as np
from typing import Optional
```
Duplicate from typing import Optional import. Remove the redundant import to avoid lint noise and keep imports tidy.
Suggested change: delete the second `from typing import Optional` line.
```rust
// Allow the iterator to be shared
let dataset_iter = Arc::new(Mutex::new(
    VectorDataIterator::<StorageType, AdHoc<T>>::new(
        &data_path,
        Option::None,
        storage_provider,
    )?
    .enumerate(),
));
```
dataset_iter is protected by std::sync::Mutex but is locked from within async tasks (tasks.spawn(async move { ... lock() ... })). A blocking mutex in async code can stall Tokio worker threads under contention. Prefer tokio::sync::Mutex (or a channel/stream-based producer) for the shared iterator to avoid blocking the runtime.
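The hazard has a direct Python analogue: a `threading.Lock` held inside a coroutine blocks the entire event-loop thread, while `asyncio.Lock` suspends only the waiting task. A sketch of the cooperative shared-iterator pattern (all names illustrative):

```python
import asyncio

async def worker(name, lock, items, out):
    while True:
        # asyncio.Lock suspends this coroutine while another worker holds
        # the lock, instead of blocking the event-loop thread.
        async with lock:
            item = next(items, None)  # shared iterator, one consumer at a time
        if item is None:
            return
        await asyncio.sleep(0)  # stand-in for real async work per item
        out.append((name, item))

async def main():
    lock = asyncio.Lock()
    items = iter(range(6))
    out = []
    await asyncio.gather(*(worker(i, lock, items, out) for i in range(3)))
    # Every item is consumed exactly once across the three workers.
    return sorted(item for _, item in out)

assert asyncio.run(main()) == [0, 1, 2, 3, 4, 5]
```

`tokio::sync::Mutex` plays the `asyncio.Lock` role here: its `lock().await` yields to the runtime rather than parking a worker thread.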
```python
    if dtype == np.uint8:
        return np.uint8
    if dtype == np.int8:
        return np.int8
    if dtype == np.float32:
        return np.float32


def valid_dap_dtype(dtype: Type) -> dap.DataType:
    _assert_dtype(dtype)
    if dtype == np.uint8:
        return dap.DataType.Uint8
    if dtype == np.int8:
        return dap.DataType.Int8
    if dtype == np.float32:
        return dap.DataType.Float
```
valid_dap_dtype() has the same issue as valid_dtype(): _assert_dtype permits castable aliases (e.g., np.single), but this function only matches exact np.float32/np.int8/np.uint8 and otherwise returns None. Canonicalize aliases before the equality checks (or compare via np.dtype(dtype)), and ensure this function always returns a dap.DataType once _assert_dtype passes.
Suggested change (canonicalize before the equality checks, replacing the exact-match code above):

```python
    canonical = np.dtype(dtype).type
    if canonical == np.uint8:
        return np.uint8
    if canonical == np.int8:
        return np.int8
    if canonical == np.float32:
        return np.float32
    # This should be unreachable if _assert_dtype and _VALID_DTYPES stay in sync,
    # but is kept as a defensive check.
    raise ValueError(f"Unsupported vector dtype: {dtype!r}")


def valid_dap_dtype(dtype: Type) -> dap.DataType:
    """
    Utility method to determine the corresponding dap.DataType for a supported numpy dtype.
    """
    _assert_dtype(dtype)
    canonical = np.dtype(dtype).type
    if canonical == np.uint8:
        return dap.DataType.Uint8
    if canonical == np.int8:
        return dap.DataType.Int8
    if canonical == np.float32:
        return dap.DataType.Float
    # This should be unreachable if _assert_dtype and _VALID_DTYPES stay in sync,
    # but is kept as a defensive check.
    raise ValueError(f"Unsupported dap vector dtype: {dtype!r}")
```
…y-slice guard, unused import, duplicate import, __all__ exports; clean up CI triggers
@YuanyuanTian-hh please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

Contributor License Agreement

This Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #872      +/-   ##
==========================================
- Coverage   90.45%   89.31%   -1.14%
==========================================
  Files         442      445       +3
  Lines       83248    84424    +1176
==========================================
+ Hits        75301    75407     +106
- Misses       7947     9017    +1070
```

Flags with carried forward coverage won't be shown.
Summary
Add DiskAnnPy, a PyO3-based Python wheel that exposes DiskANN's Rust index types to Python, and a GitHub Actions workflow to build and publish it to PyPI as diskannpy-rust.
What's included
- Rust bindings (DiskAnnPy/src/)
- Python layer (DiskAnnPy/python/diskannpy/)
- Build & publish (.github/workflows/publish-diskannpy.yml)
- Tests (DiskAnnPy/tests/)
CI changes
Test results
Wheel builds (all passed):
Wheel tests (all passed):
PyPI publish: successfully published diskannpy-rust==0.49.1
Workflow run: https://github.com/microsoft/DiskANN/actions/runs/23782386152
Known limitations
How to install

```shell
pip install diskannpy-rust
```
Copilot review fixes addressed
- Canonicalize dtype aliases (e.g., np.single → np.float32) and raise on unmapped types instead of returning None