Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 254 additions & 23 deletions diskann-providers/src/index/wrapped_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use std::{num::NonZeroUsize, sync::Arc};

use diskann::{
ANNResult,
ANNError, ANNResult,
graph::{
self, ConsolidateKind, InplaceDeleteMethod,
glue::{
Expand All @@ -22,50 +22,66 @@ use diskann::{
utils::{ONE, async_tools::VectorIdBoxSlice},
};

use crate::storage::{LoadWith, StorageReadProvider};

/// Synchronous wrapper around [`graph::DiskANNIndex`] that owns or borrows a tokio runtime.
pub struct DiskANNIndex<DP: DataProvider> {
/// The underlying async DiskANNIndex.
pub inner: Arc<graph::DiskANNIndex<DP>>,
/// Keeps the runtime alive when `Self` owns it; `None` when using an external handle.
_runtime: Option<tokio::runtime::Runtime>,
handle: tokio::runtime::Handle,
}

/// Create a multi-threaded tokio runtime and return it together with its handle.
fn create_multi_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) {
    // Runtime construction only fails under unusual OS conditions (e.g. the
    // worker threads cannot be spawned), so panicking here is acceptable.
    #[allow(clippy::expect_used)]
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .build()
        .expect("failed to create tokio runtime");
    let runtime_handle = runtime.handle().clone();
    (runtime, runtime_handle)
}

/// Create a current-thread tokio runtime and return it together with its handle.
fn create_current_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) {
    // Runtime construction only fails under unusual OS conditions, so
    // panicking here is acceptable.
    #[allow(clippy::expect_used)]
    let runtime = tokio::runtime::Builder::new_current_thread()
        .build()
        .expect("failed to create tokio runtime");
    let runtime_handle = runtime.handle().clone();
    (runtime, runtime_handle)
}

impl<DP> DiskANNIndex<DP>
where
DP: DataProvider,
{
/// Construct a synchronous `DiskANNIndex` with its own multi-threaded `tokio::runtime::Runtime`.
///
/// A default multi-threaded runtime will be created and owned by `Self`. For a single-threaded
/// runtime use [`new_with_current_thread_runtime`](Self::new_with_current_thread_runtime), or
/// to supply an external runtime handle use [`new_with_handle`](Self::new_with_handle).
pub fn new_with_multi_thread_runtime(config: graph::Config, data_provider: DP) -> Self {
    // The runtime is stored in `Self` so it stays alive for as long as the index.
    let (rt, handle) = create_multi_thread_runtime();
    Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE))
}

/// Construct a synchronous `DiskANNIndex` with its own single-threaded `tokio::runtime::Runtime`.
///
/// A default current-thread runtime will be created and owned by `Self`. For a multi-threaded
/// runtime use [`new_with_multi_thread_runtime`](Self::new_with_multi_thread_runtime), or
/// to supply an external runtime handle use [`new_with_handle`](Self::new_with_handle).
pub fn new_with_current_thread_runtime(config: graph::Config, data_provider: DP) -> Self {
    // The runtime is stored in `Self` so it stays alive for as long as the index.
    let (rt, handle) = create_current_thread_runtime();
    Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE))
}

/// Construct a synchronous `DiskANNIndex` that uses a provided `tokio::runtime::Handle`.
///
/// The `tokio::runtime::Runtime` is owned externally and we just keep a `Handle` to it.
/// `thread_hint` is forwarded to [`graph::DiskANNIndex::new`] to size internal thread pools;
/// pass `None` to let it choose a default.
pub fn new_with_handle(
config: graph::Config,
data_provider: DP,
Expand All @@ -83,13 +99,99 @@ where
thread_hint: Option<NonZeroUsize>,
) -> Self {
let inner = Arc::new(graph::DiskANNIndex::new(config, data_provider, thread_hint));

Self {
inner,
_runtime: runtime,
handle,
}
}

/// Run an arbitrary async operation against the underlying
/// [`graph::DiskANNIndex`] using this wrapper's tokio runtime.
///
/// Catch-all escape hatch for async methods on the inner index that do not
/// (yet) have a dedicated synchronous wrapper. The closure receives an
/// `&Arc<graph::DiskANNIndex<DP>>` and should return a future, which is then
/// driven to completion on this wrapper's runtime handle.
///
/// # Example
///
/// ```ignore
/// let stats = index.run(|inner| inner.some_async_method(&ctx))?;
/// ```
pub fn run<F, Fut, R>(&self, f: F) -> R
where
    F: FnOnce(&Arc<graph::DiskANNIndex<DP>>) -> Fut,
    Fut: core::future::Future<Output = R>,
{
    let future = f(&self.inner);
    self.handle.block_on(future)
}

/// Load a prebuilt index from storage with its own multi-threaded `tokio::runtime::Runtime`.
///
/// This is the synchronous equivalent of
/// [`LoadWith::load_with`](crate::storage::LoadWith::load_with).
/// A default multi-threaded runtime is created and owned by `Self`.
/// For a single-threaded runtime use [`load_with_current_thread_runtime`](Self::load_with_current_thread_runtime),
/// or to supply an external runtime handle use [`load_with_handle`](Self::load_with_handle).
pub fn load_with_multi_thread_runtime<T, P>(provider: &P, auxiliary: &T) -> ANNResult<Self>
where
    graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
    P: StorageReadProvider,
{
    let (runtime, handle) = create_multi_thread_runtime();
    // Drive the async load to completion on the freshly created runtime.
    let loaded = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
    Ok(Self {
        inner: Arc::new(loaded),
        _runtime: Some(runtime),
        handle,
    })
}

/// Load a prebuilt index from storage with its own single-threaded `tokio::runtime::Runtime`.
///
/// This is the synchronous equivalent of
/// [`LoadWith::load_with`](crate::storage::LoadWith::load_with).
/// A default current-thread runtime is created and owned by `Self`.
/// For a multi-threaded runtime use [`load_with_multi_thread_runtime`](Self::load_with_multi_thread_runtime),
/// or to supply an external runtime handle use [`load_with_handle`](Self::load_with_handle).
pub fn load_with_current_thread_runtime<T, P>(provider: &P, auxiliary: &T) -> ANNResult<Self>
where
    graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
    P: StorageReadProvider,
{
    let (runtime, handle) = create_current_thread_runtime();
    // Drive the async load to completion on the freshly created runtime.
    let loaded = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
    Ok(Self {
        inner: Arc::new(loaded),
        _runtime: Some(runtime),
        handle,
    })
}

/// Load a prebuilt index from storage using a provided `tokio::runtime::Handle`.
///
/// This is the synchronous equivalent of
/// [`LoadWith::load_with`](crate::storage::LoadWith::load_with).
/// The `tokio::runtime::Runtime` is owned externally and we just keep a `Handle` to it.
/// For an owned runtime use [`load_with_multi_thread_runtime`](Self::load_with_multi_thread_runtime)
/// or [`load_with_current_thread_runtime`](Self::load_with_current_thread_runtime).
pub fn load_with_handle<T, P>(
    provider: &P,
    auxiliary: &T,
    handle: tokio::runtime::Handle,
) -> ANNResult<Self>
where
    graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
    P: StorageReadProvider,
{
    // Drive the async load to completion on the caller's runtime; we keep
    // only the handle, so the external runtime must outlive `Self`.
    let loaded = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
    Ok(Self {
        inner: Arc::new(loaded),
        _runtime: None,
        handle,
    })
}

pub fn insert<S, T>(
&self,
strategy: S,
Expand Down Expand Up @@ -222,7 +324,6 @@ where
.block_on(self.inner.consolidate_vector(strategy, context, vector_id))
}

#[allow(clippy::too_many_arguments, clippy::type_complexity)]
pub fn search<S, T, O, OB>(
&self,
strategy: &S,
Expand Down Expand Up @@ -320,3 +421,133 @@ where
self.handle.block_on(self.inner.get_degree_stats(accessor))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use diskann::{
graph::{self, search_output_buffer},
provider::DefaultContext,
utils::ONE,
};
use diskann_utils::test_data_root;
use diskann_vector::distance::Metric;

use super::DiskANNIndex;
use crate::{
index::diskann_async,
model::{
configuration::IndexConfiguration,
graph::provider::async_::{
common::{FullPrecision, TableBasedDeletes},
inmem::{self, CreateFullPrecision, DefaultProvider},
},
},
storage::{AsyncIndexMetadata, SaveWith, StorageReadProvider, VirtualStorageProvider},
utils::create_rnd_from_seed_in_tests,
};

// End-to-end round trip through the synchronous wrapper: build + save an
// index using the async API driven via `run`, then reload it through the
// synchronous `load_with_current_thread_runtime` entry point and search it.
#[test]
fn test_save_then_sync_load_round_trip() {
    // -- Build an index in async context and save it -----------------------
    let save_path = "/index";
    let file_path = "/sift/siftsmall_learn_256pts.fbin";

    // Read the training vectors from the checked-in SIFT test data via an
    // overlay storage provider rooted at the test-data directory.
    let train_data = {
        let storage = VirtualStorageProvider::new_overlay(test_data_root());
        let mut reader = storage.open_reader(file_path).unwrap();
        diskann_utils::io::read_bin::<f32>(&mut reader).unwrap()
    };

    // Train a small PQ table; the fixed seed keeps the test deterministic.
    let pq_bytes = 8;
    let pq_table = diskann_async::train_pq(
        train_data.as_view(),
        pq_bytes,
        &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade),
        2,
    )
    .unwrap();

    let (build_config, parameters) = diskann_async::simplified_builder(
        20,
        32,
        Metric::L2,
        train_data.ncols(),
        train_data.nrows(),
        |_| {},
    )
    .unwrap();

    let fp_precursor =
        CreateFullPrecision::new(parameters.dim, parameters.prefetch_cache_line_level);
    let data_provider =
        DefaultProvider::new_empty(parameters, fp_precursor, pq_table, TableBasedDeletes)
            .unwrap();

    // Exercise the sync wrapper itself: a current-thread runtime is enough
    // for a test this size.
    let index =
        DiskANNIndex::new_with_current_thread_runtime(build_config.clone(), data_provider);

    let storage = VirtualStorageProvider::new_memory();
    let ctx = DefaultContext;
    // Insert every training row, using the row index as the vector id.
    for (i, v) in train_data.row_iter().enumerate() {
        index.insert(FullPrecision, &ctx, &(i as u32), v).unwrap();
    }

    // `save_with` has no dedicated sync wrapper, so go through `run`.
    let save_metadata = AsyncIndexMetadata::new(save_path.to_string());
    let storage_ref = &storage;
    let metadata_ref = &save_metadata;
    index
        .run(|inner| {
            let inner = Arc::clone(inner);
            async move { inner.save_with(storage_ref, metadata_ref).await }
        })
        .unwrap();

    // -- Reload via the synchronous wrapped_async API ----------------------
    let load_config = IndexConfiguration::new(
        Metric::L2,
        train_data.ncols(),
        train_data.nrows(),
        ONE,
        1,
        build_config,
    );

    // Concrete provider type matching what `DefaultProvider::new_empty`
    // produced above, so the saved index can be deserialized into it.
    type TestProvider = inmem::FullPrecisionProvider<
        f32,
        crate::model::graph::provider::async_::FastMemoryQuantVectorProviderAsync,
        crate::model::graph::provider::async_::TableDeleteProviderAsync,
    >;

    let loaded: DiskANNIndex<TestProvider> =
        DiskANNIndex::load_with_current_thread_runtime(&storage, &(save_path, load_config))
            .unwrap();

    // -- Verify the loaded index is functional -----------------------------
    // A single search call is enough to confirm the sync wrapper loaded a
    // working index. Exhaustive search-correctness is tested elsewhere.
    let top_k = 5;
    let search_l = 20;
    let mut ids = vec![0u32; top_k];
    let mut distances = vec![0.0f32; top_k];
    let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances);

    let query = train_data.row(0);
    let search_params = graph::search::Knn::new_default(top_k, search_l).unwrap();
    let stats = loaded
        .search(
            &FullPrecision,
            &DefaultContext,
            query,
            &search_params,
            &mut output,
        )
        .unwrap();

    assert_eq!(stats.result_count, top_k as u32);
    // The query is itself in the dataset, so the nearest neighbor must be at distance 0.
    assert_eq!(ids[0], 0);
    assert_eq!(distances[0], 0.0);
}
}
Loading