diff --git a/.gitignore b/.gitignore index 42830a6c0..db4b67463 100644 --- a/.gitignore +++ b/.gitignore @@ -336,4 +336,7 @@ target/ .history/ # Ignore local settings for claude code AI agent -.claude/settings.local.json \ No newline at end of file +.claude/settings.local.json + +# Generated test data from generate_synthetic_labels tests +rand_labels_50_10K_*.txt \ No newline at end of file diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index 65a7b0e0b..e85847861 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -48,6 +48,7 @@ vfs = { workspace = true, optional = true } approx.workspace = true criterion.workspace = true diskann-utils = { workspace = true, features = ["testing"] } +diskann = { workspace = true, features = ["testing"] } iai-callgrind.workspace = true itertools.workspace = true tempfile.workspace = true diff --git a/diskann-providers/src/model/graph/provider/async_/caching/example.rs b/diskann-providers/src/model/graph/provider/async_/caching/example.rs index f2ee121b9..8c0099cc8 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/example.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/example.rs @@ -4,17 +4,12 @@ */ use diskann::{ - graph::{AdjacencyList, workingset}, + graph::{AdjacencyList, test::provider as test_provider, workingset}, provider::{self as core_provider, DefaultContext}, }; use diskann_utils::future::AsyncFriendly; use diskann_vector::distance::Metric; -use crate::model::graph::provider::async_::{ - common::FullPrecision, - debug_provider::{self, DebugProvider}, -}; - use super::{ bf_cache::{self, Cache}, error::CacheAccessError, @@ -179,14 +174,14 @@ where // Provider Bridge // ///////////////////// -impl<'a> cache_provider::AsCacheAccessorFor<'a, debug_provider::FullAccessor<'a>> for ExampleCache { +impl<'a> cache_provider::AsCacheAccessorFor<'a, test_provider::Accessor<'a>> for ExampleCache { type Accessor = CacheAccessor<'a, 
bf_cache::VecCacher>; type Error = diskann::error::Infallible; fn as_cache_accessor_for( &'a self, - inner: debug_provider::FullAccessor<'a>, + inner: test_provider::Accessor<'a>, ) -> Result< - cache_provider::CachingAccessor, Self::Accessor>, + cache_provider::CachingAccessor, Self::Accessor>, Self::Error, > { let provider = inner.provider(); @@ -205,7 +200,7 @@ type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; type FullAccessorCache<'a> = CacheAccessor<'a, bf_cache::VecCacher>; impl<'a> cache_provider::CachedFill, WorkingSet> - for debug_provider::FullAccessor<'a> + for test_provider::Accessor<'a> { fn cached_fill<'b, Itr>( &'b mut self, @@ -232,8 +227,14 @@ impl<'a> cache_provider::CachedFill, WorkingSet> vacant.insert(Self::from_cached(element).into()); } None => { - let element = - self.get_element(i).await.map_err(CachingError::Inner)?; + let element = self.get_element(i).await.map_err(|e| match e { + test_provider::AccessError::InvalidId(e) => { + CachingError::Inner(e) + } + test_provider::AccessError::Transient(e) => { + panic!("unexpected transient error: {e}") + } + })?; cache .set_cached(i, Self::as_cached(&element)) .map_err(CachingError::Cache)?; @@ -270,51 +271,47 @@ mod tests { use rstest::rstest; use crate::{ - index::diskann_async::{self, tests as async_tests}, + index::diskann_async::tests as async_tests, model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider}, utils as crate_utils, }; const CTX: &DefaultContext = &DefaultContext; - fn test_provider( + fn make_provider( uncacheable: Option>, - ) -> CachingProvider { + ) -> CachingProvider { let dim = 2; + let start_id = u32::MAX; - let config = debug_provider::DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let table = diskann_async::train_pq( - Matrix::new(0.0, 1, dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. 
- &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, + let config = test_provider::Config::new( + Metric::L2, + 10, // max_degree + test_provider::StartPoint::new(start_id, vec![0.0; dim]), ) .unwrap(); CachingProvider::new( - DebugProvider::new(config, Arc::new(table)).unwrap(), + test_provider::DefaultContextProvider::new(test_provider::Provider::new(config)), ExampleCache::new(PowerOfTwo::new(1024 * 16).unwrap(), uncacheable), ) } #[tokio::test] async fn basic_operations_happy_path() { - let provider = test_provider(None); + let provider = make_provider(None); let ctx = &DefaultContext; // Translations do not yet exist. assert!(provider.to_external_id(ctx, 0).is_err()); assert!(provider.to_internal_id(ctx, &0).is_err()); - assert_eq!(provider.inner().data_writes.get(), 0); + assert_eq!(provider.inner().metrics().set_vector, 0); provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); - assert_eq!(provider.inner().data_writes.get(), 1 /* increased */); + assert_eq!( + provider.inner().metrics().set_vector, + 1 /* increased */ + ); assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); @@ -322,14 +319,14 @@ mod tests { // Retrieval of a valid element. let mut accessor = provider .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) .unwrap(); // Hit served from the underlying provider. - assert_eq!(provider.inner().full_reads.get(), 0); + // Note: get_vector uses a LocalCounter that flushes on drop, so we track reads + // through cache miss counts instead. let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); assert_eq!( accessor.cache().stats.get_local_misses(), 1, /* increased */ @@ -339,7 +336,6 @@ mod tests { // This time, the hit is served from the underlying cache. 
let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); assert_eq!(accessor.cache().stats.get_local_misses(), 1); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -347,18 +343,21 @@ mod tests { ); // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); + assert_eq!( + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, + 0 + ); accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, 1, /* increased */ ); let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); + assert_eq!(provider.inner().metrics().get_neighbors, 0); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 1, /* increased */ ); assert_eq!( @@ -372,7 +371,7 @@ mod tests { list.clear(); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); assert_eq!( accessor.cache().graph.stats().get_local_hits(), @@ -385,7 +384,6 @@ mod tests { let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2 /* increased */,); assert_eq!( accessor.cache().stats.get_local_misses(), 2, /* increased */ @@ -395,7 +393,6 @@ mod tests { // Once more from the cache. 
let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!(accessor.cache().stats.get_local_misses(), 2); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -406,7 +403,7 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 2, /* increased */ ); assert_eq!( @@ -420,11 +417,11 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[2, 3, 4]); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, 2, /* increased */ ); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 3, /* increased */ ); assert_eq!( @@ -438,11 +435,11 @@ mod tests { assert_eq!(&*list, &[2, 3, 4, 1]); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, 3, /* increased */ ); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 4, /* increased */ ); assert_eq!( @@ -479,7 +476,6 @@ mod tests { // Access the deleted element is still valid. 
let element = accessor.get_element(0).await.unwrap(); assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!(accessor.cache().stats.get_local_misses(), 2); assert_eq!( accessor.cache().stats.get_local_hits(), @@ -488,8 +484,11 @@ mod tests { accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[2, 3, 4, 1]); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); + assert_eq!( + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, + 3 + ); + assert_eq!(provider.inner().metrics().get_neighbors, 4); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 4); assert_eq!( accessor.cache().graph.stats().get_local_hits(), @@ -501,7 +500,6 @@ mod tests { assert!(provider.status_by_external_id(CTX, &0).await.is_err()); assert!(accessor.get_element(0).await.is_err()); - assert_eq!(provider.inner().full_reads.get(), 2); assert_eq!( accessor.cache().stats.get_local_misses(), 3 /* increased */ @@ -509,8 +507,11 @@ mod tests { assert_eq!(accessor.cache().stats.get_local_hits(), 3); assert!(accessor.get_neighbors(0, &mut list).await.is_err()); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); + assert_eq!( + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, + 3 + ); + assert_eq!(provider.inner().metrics().get_neighbors, 4); assert_eq!( accessor.cache().graph.stats().get_local_misses(), 5 /* increased */ @@ -540,11 +541,11 @@ mod tests { // Test that returning `Uncacheable` for an adjacency list is handled correctly by // the provider and a call to `set_neighbors` is not made. 
let uncacheable = u32::MAX; - let provider = test_provider(Some(vec![uncacheable])); + let provider = make_provider(Some(vec![uncacheable])); let mut accessor = provider .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) .unwrap(); provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); @@ -554,18 +555,21 @@ mod tests { //---------------// // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); + assert_eq!( + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, + 0 + ); accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, 1, /* increased */ ); let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); + assert_eq!(provider.inner().metrics().get_neighbors, 0); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 1, /* increased */ ); assert_eq!( @@ -579,7 +583,7 @@ mod tests { list.clear(); accessor.get_neighbors(0, &mut list).await.unwrap(); assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); assert_eq!( accessor.cache().graph.stats().get_local_hits(), @@ -590,21 +594,24 @@ mod tests { // Uncacheable IDs // //-----------------// - assert_eq!(provider.inner().neighbor_writes.get(), 1); + assert_eq!( + provider.inner().metrics().set_neighbors + provider.inner().metrics().append_neighbors, + 1 + ); accessor.set_neighbors(uncacheable, &[4, 5]).await.unwrap(); assert_eq!( - provider.inner().neighbor_writes.get(), + provider.inner().metrics().set_neighbors + 
provider.inner().metrics().append_neighbors, 2, /* increased */ ); // The retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 1); + assert_eq!(provider.inner().metrics().get_neighbors, 1); accessor .get_neighbors(uncacheable, &mut list) .await .unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 2, /* increased */ ); assert_eq!( @@ -615,13 +622,13 @@ mod tests { assert_eq!(&*list, &[4, 5]); // Again, retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 2); + assert_eq!(provider.inner().metrics().get_neighbors, 2); accessor .get_neighbors(uncacheable, &mut list) .await .unwrap(); assert_eq!( - provider.inner().neighbor_reads.get(), + provider.inner().metrics().get_neighbors, 3, /* increased */ ); assert_eq!( @@ -659,24 +666,19 @@ mod tests { .build() .unwrap(); - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), + let provider_config = test_provider::Config::new( metric, - }; - - let mut vectors = ::generate_grid(dim, grid_size); - let table = diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), ) .unwrap(); + let mut vectors = ::generate_grid(dim, grid_size); + let provider = CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), + test_provider::DefaultContextProvider::new(test_provider::Provider::new( + provider_config, + )), ExampleCache::new(cache_size, None), ); let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); @@ -690,15 +692,19 @@ mod tests { assert_eq!(adjacency_lists.len(), num_points); assert_eq!(vectors.len(), num_points); - let strategy = 
cache_provider::Cached::new(FullPrecision); + let strategy = cache_provider::Cached::new(test_provider::DefaultContextStrategy::new()); async_tests::populate_data(index.provider(), CTX, &vectors).await; { // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - &[f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider< + test_provider::DefaultContextProvider, + ExampleCache, + >, + &[f32], + >>::search_accessor(&strategy, index.provider(), CTX) + .unwrap(); async_tests::populate_graph(&mut accessor, &adjacency_lists).await; accessor @@ -728,14 +734,17 @@ mod tests { async_tests::check_grid_search(&index, &vectors, &paged_tests, strategy, strategy).await; } - fn check_stats(caching: &CachingProvider) { + fn check_stats(caching: &CachingProvider) { let provider = caching.inner(); let cache = caching.cache(); - println!("neighbor reads: {}", provider.neighbor_reads.get()); - println!("neighbor writes: {}", provider.neighbor_writes.get()); - println!("vector reads: {}", provider.full_reads.get()); - println!("vector writes: {}", provider.data_writes.get()); + println!("neighbor reads: {}", provider.metrics().get_neighbors); + println!( + "neighbor writes: {}", + provider.metrics().set_neighbors + provider.metrics().append_neighbors + ); + println!("vector reads: {}", provider.metrics().get_vector); + println!("vector writes: {}", provider.metrics().set_vector); println!("neighbor hits: {}", cache.neighbor_stats.get_hits()); println!("neighbor misses: {}", cache.neighbor_stats.get_misses()); @@ -744,12 +753,15 @@ mod tests { // Neighbors assert_eq!( - provider.neighbor_reads.get(), + provider.metrics().get_neighbors, cache.neighbor_stats.get_misses() ); // Vectors - assert_eq!(provider.full_reads.get(), cache.vector_stats.get_misses()); + assert_eq!( + provider.metrics().get_vector, + 
cache.vector_stats.get_misses() + ); } #[rstest] @@ -775,15 +787,6 @@ mod tests { let num_points = (grid_size).pow(dim as u32); let mut vectors = ::generate_grid(dim, grid_size); - let table = Arc::new( - diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(), - ); let index_config = diskann::graph::config::Builder::new_with( max_degree, @@ -797,12 +800,13 @@ mod tests { .build() .unwrap(); - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), + let provider_config = test_provider::Config::new( metric, - }; + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); + assert_eq!(vectors.len(), num_points); // This is a little subtle, but we need `vectors` to contain the start point as @@ -814,13 +818,15 @@ mod tests { // Initialize an index for a new round of building. 
let init_index = || { let provider = CachingProvider::new( - DebugProvider::new(test_config.clone(), table.clone()).unwrap(), + test_provider::DefaultContextProvider::new(test_provider::Provider::new( + provider_config.clone(), + )), ExampleCache::new(cache_size, None), ); Arc::new(DiskANNIndex::new(index_config.clone(), provider, None)) }; - let strategy = cache_provider::Cached::new(FullPrecision); + let strategy = cache_provider::Cached::new(test_provider::DefaultContextStrategy::new()); // Build with full-precision single insert { @@ -859,11 +865,10 @@ mod tests { // create small index instance let metric = Metric::L2; let num_points = 4; - let strategy = cache_provider::Cached::new(FullPrecision); + let strategy = cache_provider::Cached::new(test_provider::DefaultContextStrategy::new()); let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); let start_id = num_points as u32; let start_point = vec![0.5, 0.5]; - let dim = start_point.len(); let index_config = diskann::graph::config::Builder::new( 4, // target_degree @@ -874,27 +879,19 @@ mod tests { .build() .unwrap(); - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), + let provider_config = test_provider::Config::new( metric, - }; - - // The contents of the table don't matter for this test because we use full - // precision only. 
- let table = diskann_async::train_pq( - Matrix::new(0.5, 1, dim).as_view(), - dim, - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), ) .unwrap(); let index = DiskANNIndex::new( index_config, CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), + test_provider::DefaultContextProvider::new(test_provider::Provider::new( + provider_config, + )), ExampleCache::new(cache_size, None), ), None, @@ -918,15 +915,22 @@ mod tests { ]; // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - &[f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider< + test_provider::DefaultContextProvider, + ExampleCache, + >, + &[f32], + >>::search_accessor(&strategy, index.provider(), CTX) + .unwrap(); async_tests::populate_data(index.provider(), CTX, &vectors).await; async_tests::populate_graph(&mut accessor, &adjacency_lists).await; + // Drop the accessor before inplace_delete so the cache is cleanly separated. + std::mem::drop(accessor); + index .inplace_delete( strategy, @@ -957,6 +961,18 @@ mod tests { // and replaced with edges to points 0 and 1 // vertices 0 and 1 should add an edge pointing to 2. // vertex 3 should be dropped + + // Create a fresh accessor to read the post-deletion state. 
+ let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider< + test_provider::DefaultContextProvider, + ExampleCache, + >, + &[f32], + >>::search_accessor(&strategy, index.provider(), CTX) + .unwrap(); + { let mut list = AdjacencyList::new(); accessor.get_neighbors(4, &mut list).await.unwrap(); diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs deleted file mode 100644 index dea11e7a8..000000000 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ /dev/null @@ -1,1459 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - collections::{HashMap, hash_map}, - sync::{ - Arc, RwLock, RwLockReadGuard, RwLockWriteGuard, - atomic::{AtomicUsize, Ordering}, - }, -}; - -use diskann::{ - ANNError, ANNErrorKind, ANNResult, default_post_processor, - graph::{ - AdjacencyList, - glue::{ - self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InplaceDeleteStrategy, - InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, - }, - workingset::{self, map}, - }, - provider::{ - self, Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultAccessor, - DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, - NeighborAccessorMut, - }, - tracked_warn, - utils::VectorRepr, -}; -use diskann_quantization::CompressInto; -use diskann_utils::views::Matrix; -use diskann_vector::distance::Metric; -use thiserror::Error; - -use crate::{ - model::{ - FixedChunkPQTable, - distance::{DistanceComputer, QueryComputer}, - graph::provider::async_::{ - common::{FullPrecision, Panics, Quantized}, - distances::{ - self, - pq::{Hybrid, HybridMap}, - }, - postprocess, - }, - pq, - }, - utils::BridgeErr, -}; - -#[derive(Debug, Clone)] -pub struct DebugConfig { - pub start_id: u32, - pub start_point: Vec, - pub max_degree: usize, - pub metric: Metric, -} - 
-/// A version of `DebugConfig` that has the compressed representation of `start_point` in -/// addition to the full-precision representation. -#[derive(Debug, Clone)] -struct InternalConfig { - start_id: u32, - start_point: Datum, - max_degree: usize, - metric: Metric, -} - -/// A combined full-precision and PQ quantized vector. -#[derive(Debug, Default, Clone)] -pub struct Datum { - full: Vec, - quant: Vec, -} - -impl Datum { - /// Create a new `Datum`. - fn new(full: Vec, quant: Vec) -> Self { - Self { full, quant } - } - - /// Return a reference to the full-precision vector. - fn full(&self) -> &[f32] { - &self.full - } - - /// Return a reference to the quantized vector. - fn quant(&self) -> &[u8] { - &self.quant - } -} - -/// A container for `Datum`s within the `Debug` provider. -/// -/// This tracks whether items are valid or have been marked as deleted inline. -#[derive(Debug, Clone)] -pub enum Vector { - Valid(Datum), - Deleted(Datum), -} - -impl Vector { - /// Change `self` to be `Self::Deleted`, leaving the internal `Datum` unchanged. - fn mark_deleted(&mut self) { - *self = match self.take() { - Self::Valid(v) => Self::Deleted(v), - Self::Deleted(v) => Self::Deleted(v), - } - } - - /// Take the internal `Datum` and construct a new instance of `Self`. - /// - /// Leave the caller with an empty `Datum`. - fn take(&mut self) -> Self { - match self { - Self::Valid(v) => Self::Valid(std::mem::take(v)), - Self::Deleted(v) => Self::Deleted(std::mem::take(v)), - } - } - - /// Return `true` if `self` has been marked as deleted. Otherwise, return `false`. - fn is_deleted(&self) -> bool { - matches!(self, Self::Deleted(_)) - } -} - -impl std::ops::Deref for Vector { - type Target = Datum; - fn deref(&self) -> &Datum { - match self { - Self::Valid(v) => v, - Self::Deleted(v) => v, - } - } -} - -/// A simple increment-only thread-safe counter. -#[derive(Debug)] -pub struct Counter(AtomicUsize); - -impl Counter { - /// Construct a new counter with a count of 0. 
- fn new() -> Self { - Self(AtomicUsize::new(0)) - } - - /// Increment the counter by 1. - pub(crate) fn increment(&self) { - self.0.fetch_add(1, Ordering::Relaxed); - } - - /// Return the current value of the counter. - pub(crate) fn get(&self) -> usize { - self.0.load(Ordering::Relaxed) - } -} - -pub struct DebugProvider { - config: InternalConfig, - - pub pq_table: Arc, - pub data: RwLock>, - pub neighbors: RwLock>>, - - // Counters - pub full_reads: Counter, - pub quant_reads: Counter, - pub neighbor_reads: Counter, - pub data_writes: Counter, - pub neighbor_writes: Counter, - - // Track whether the `insert_search_accessor` is invoked. - pub insert_search_accessor_calls: Counter, -} - -impl DebugProvider { - pub fn new(config: DebugConfig, pq_table: Arc) -> ANNResult { - // Compress the start point. - let mut pq = vec![0u8; pq_table.get_num_chunks()]; - pq_table - .compress_into(config.start_point.as_slice(), pq.as_mut_slice()) - .bridge_err()?; - - let config = InternalConfig { - start_id: config.start_id, - start_point: Datum::new(config.start_point, pq), - max_degree: config.max_degree, - metric: config.metric, - }; - - let mut data = HashMap::new(); - data.insert(config.start_id, Vector::Valid(config.start_point.clone())); - - let mut neighbors = HashMap::new(); - neighbors.insert(config.start_id, Vec::new()); - - Ok(Self { - config, - pq_table: pq_table.clone(), - data: RwLock::new(data), - neighbors: RwLock::new(neighbors), - full_reads: Counter::new(), - quant_reads: Counter::new(), - neighbor_reads: Counter::new(), - data_writes: Counter::new(), - neighbor_writes: Counter::new(), - insert_search_accessor_calls: Counter::new(), - }) - } - - /// Return the dimension of the full-precision data. - pub fn dim(&self) -> usize { - self.config.start_point.full().len() - } - - /// Return the maximum degree that can be held by this graph. 
- pub fn max_degree(&self) -> usize { - self.config.max_degree - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data(&self) -> RwLockReadGuard<'_, HashMap> { - self.data.read().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data_mut(&self) -> RwLockWriteGuard<'_, HashMap> { - self.data.write().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors(&self) -> RwLockReadGuard<'_, HashMap>> { - self.neighbors - .read() - .expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors_mut(&self) -> RwLockWriteGuard<'_, HashMap>> { - self.neighbors - .write() - .expect("cannot recover from lock poison") - } - - fn is_deleted(&self, id: u32) -> Result { - match self.data().get(&id) { - Some(element) => Ok(element.is_deleted()), - None => Err(InvalidId::Internal(id)), - } - } -} - -/// Light-weight error type for reporting access to an invalid ID. 
-#[derive(Debug, Clone, Copy, Error)] -pub enum InvalidId { - #[error("internal id {0} not initialized")] - Internal(u32), - #[error("external id {0} not initialized")] - External(u32), - #[error("is start point {0}")] - IsStartPoint(u32), -} - -diskann::always_escalate!(InvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: InvalidId) -> ANNError { - ANNError::opaque(err) - } -} - -////////////////// -// DataProvider // -////////////////// - -impl DataProvider for DebugProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u32; - type Guard = provider::NoopGuard; - type Error = InvalidId; - - fn to_internal_id( - &self, - _context: &DefaultContext, - gid: &Self::ExternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(gid); - if valid { - Ok(*gid) - } else { - Err(InvalidId::External(*gid)) - } - } - - fn to_external_id( - &self, - _context: &DefaultContext, - id: Self::InternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(&id); - if valid { - Ok(id) - } else { - Err(InvalidId::External(id)) - } - } -} - -impl Delete for DebugProvider { - async fn delete( - &self, - _context: &Self::Context, - gid: &Self::ExternalId, - ) -> Result<(), Self::Error> { - if *gid == self.config.start_id { - return Err(InvalidId::IsStartPoint(*gid)); - } - - let mut guard = self.data_mut(); - match guard.entry(*gid) { - hash_map::Entry::Occupied(mut occupied) => { - occupied.get_mut().mark_deleted(); - Ok(()) - } - hash_map::Entry::Vacant(_) => Err(InvalidId::External(*gid)), - } - } - - async fn release( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result<(), Self::Error> { - if id == self.config.start_id { - return Err(InvalidId::IsStartPoint(id)); - } - - // NOTE: acquire the locks in the order `data` then `neighbors`. 
- let mut data = self.data_mut(); - let mut neighbors = self.neighbors_mut(); - - let v = data.remove(&id); - let u = neighbors.remove(&id); - - if v.is_none() || u.is_none() { - Err(InvalidId::Internal(id)) - } else { - Ok(()) - } - } - - async fn status_by_internal_id( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result { - if self.is_deleted(id)? { - Ok(provider::ElementStatus::Deleted) - } else { - Ok(provider::ElementStatus::Valid) - } - } - - fn status_by_external_id( - &self, - context: &Self::Context, - gid: &Self::ExternalId, - ) -> impl Future> + Send { - self.status_by_internal_id(context, *gid) - } -} - -impl provider::SetElement<&[f32]> for DebugProvider { - type SetError = ANNError; - - fn set_element( - &self, - _context: &Self::Context, - id: &Self::ExternalId, - element: &[f32], - ) -> impl Future> + Send { - #[derive(Debug, Clone, Copy, Error)] - #[error("vector id {0} is already assigned")] - pub struct AlreadyAssigned(u32); - - diskann::always_escalate!(AlreadyAssigned); - - impl From for ANNError { - #[track_caller] - fn from(err: AlreadyAssigned) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } - } - - // NOTE: acquire the locks in the order `vectors` then `neighbors`. 
- let result = match self.data_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(data) => match self.neighbors_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(neighbors) => { - self.data_writes.increment(); - - let mut pq = vec![0u8; self.pq_table.get_num_chunks()]; - match self - .pq_table - .compress_into(element, pq.as_mut_slice()) - .bridge_err() - { - Ok(()) => { - data.insert(Vector::Valid(Datum::new(element.into(), pq))); - neighbors.insert(Vec::new()); - Ok(provider::NoopGuard::new(*id)) - } - Err(err) => Err(ANNError::from(err)), - } - } - }, - }; - - std::future::ready(result) - } -} - -impl postprocess::DeletionCheck for DebugProvider { - fn deletion_check(&self, id: u32) -> bool { - match self.is_deleted(id) { - Ok(is_deleted) => is_deleted, - Err(err) => { - tracked_warn!("Deletion post-process failed with error {err} - continuing"); - true - } - } - } -} - -/////////////// -// Accessors // -/////////////// - -#[derive(Debug, Clone, Copy, Error)] -#[error("Attempt to access an invalid id: {0}")] -pub struct AccessedInvalidId(u32); - -diskann::always_escalate!(AccessedInvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: AccessedInvalidId) -> Self { - Self::opaque(err) - } -} - -impl DefaultAccessor for DebugProvider { - type Accessor<'a> = DebugNeighborAccessor<'a>; - - fn default_accessor(&self) -> Self::Accessor<'_> { - DebugNeighborAccessor::new(self) - } -} - -#[derive(Clone, Copy)] -pub struct DebugNeighborAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> DebugNeighborAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for DebugNeighborAccessor<'_> { - type Id = u32; -} - -impl NeighborAccessor for DebugNeighborAccessor<'_> { - fn get_neighbors( - self, - id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - let 
result = match self.provider.neighbors().get(&id) { - Some(v) => { - self.provider.neighbor_reads.increment(); - neighbors.overwrite_trusted(v); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -impl NeighborAccessorMut for DebugNeighborAccessor<'_> { - fn set_neighbors( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - assert!(neighbors.len() <= self.provider.config.max_degree); - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - self.provider.neighbor_writes.increment(); - v.clear(); - v.extend_from_slice(neighbors); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } - - fn append_vector( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - assert!( - v.len().checked_add(neighbors.len()).unwrap() - <= self.provider.config.max_degree, - "current = {:?}, new = {:?}, id = {}", - v, - neighbors, - id - ); - - let check = neighbors.iter().try_for_each(|n| { - if v.contains(n) { - Err(ANNError::message( - ANNErrorKind::Opaque, - format!("id {} is duplicated", n), - )) - } else { - Ok(()) - } - }); - - match check { - Ok(()) => { - self.provider.neighbor_writes.increment(); - v.extend_from_slice(neighbors); - Ok(self) - } - Err(err) => Err(err), - } - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -//---------------// -// Full Accessor // -//---------------// - -pub struct FullAccessor<'a> { - provider: &'a DebugProvider, - buffer: Box<[f32]>, -} - -impl<'a> FullAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - let buffer = (0..provider.dim()).map(|_| 0.0).collect(); - Self { provider, buffer } - } - - pub fn provider(&self) -> &DebugProvider { - self.provider - } -} - -impl HasId for FullAccessor<'_> { 
- type Id = u32; -} - -impl Accessor for FullAccessor<'_> { - type Element<'a> - = &'a [f32] - where - Self: 'a; - type ElementRef<'a> = &'a [f32]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - self.buffer.copy_from_slice(v.full()); - Ok(&*self.buffer) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl diskann::provider::CacheableAccessor for FullAccessor<'_> { - type Map = diskann_utils::lifetime::Slice; - - fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] - where - Self: 'a + 'b, - { - element - } - - fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] - where - Self: 'a, - { - element - } -} - -impl SearchExt for FullAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for FullAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for FullAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = ::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(f32::distance( - self.provider.config.metric, - Some(self.provider.dim()), - )) - } -} - -impl BuildQueryComputer<&[f32]> for FullAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = ::QueryDistance; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, self.provider.config.metric)) - } -} - -impl ExpandBeam<&[f32]> for FullAccessor<'_> {} - -impl postprocess::AsDeletionCheck for FullAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} 
- -//----------------// -// Quant Accessor // -//----------------// - -pub struct QuantAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> QuantAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for QuantAccessor<'_> { - type Id = u32; -} - -impl Accessor for QuantAccessor<'_> { - type Element<'a> - = Vec - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.quant_reads.increment(); - Ok(v.quant().to_owned()) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for QuantAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for QuantAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildQueryComputer<&[f32]> for QuantAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = pq::distance::QueryComputer>; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(QueryComputer::new( - self.provider.pq_table.clone(), - self.provider.config.metric, - from, - None, - ) - .unwrap()) - } -} - -impl ExpandBeam<&[f32]> for QuantAccessor<'_> {} - -impl postprocess::AsDeletionCheck for QuantAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} - -//-----------------// -// Hybrid Accessor // -//-----------------// - -pub struct HybridAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> HybridAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for 
HybridAccessor<'_> { - type Id = u32; -} - -impl Accessor for HybridAccessor<'_> { - type Element<'a> - = Hybrid, Vec> - where - Self: 'a; - type ElementRef<'a> = Hybrid<&'a [f32], &'a [u8]>; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - Ok(Hybrid::Full(v.full().to_owned())) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for HybridAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for HybridAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for HybridAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = distances::pq::HybridComputer; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(distances::pq::HybridComputer::new( - DistanceComputer::new(self.provider.pq_table.clone(), self.provider.config.metric) - .unwrap(), - f32::distance(self.provider.config.metric, Some(self.provider.dim())), - )) - } -} - -impl workingset::Fill> for HybridAccessor<'_> { - type Error = diskann::error::Infallible; - type View<'a> - = distances::pq::View<'a, f32, u8> - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - state: &'a mut HybridMap, - itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - let map = state.get_mut(); - map.prepare(itr.clone()); - - let threshold = 1; // one full vec per fill - let data = self.provider.data(); - itr.enumerate().for_each(|(i, id)| match map.entry(id) { - map::Entry::Seeded(_) | map::Entry::Occupied(_) => {} - 
map::Entry::Vacant(v) => { - let element = data.get(&id).unwrap(); - if i < threshold { - v.insert(Hybrid::Full(element.full().to_owned())) - } else { - v.insert(Hybrid::Quant(element.quant().to_owned())) - } - } - }); - - Ok(map.view()) - } -} - -//////////////// -// Strategies // -//////////////// - -impl SearchStrategy for FullPrecision { - type QueryComputer = ::QueryDistance; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -impl DefaultPostProcessor for FullPrecision { - default_post_processor!(Pipeline); -} - -impl SearchStrategy for Quantized { - type QueryComputer = pq::distance::QueryComputer>; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } -} - -impl DefaultPostProcessor for Quantized { - default_post_processor!(Pipeline); -} - -impl PruneStrategy for FullPrecision { - type DistanceComputer = ::Distance; - type PruneAccessor<'a> = FullAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - type WorkingSet = map::Map, map::Ref<[f32]>>; - - fn prune_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { - map::Builder::new(map::Capacity::Default).build(capacity) - } -} - -impl PruneStrategy for Quantized { - type DistanceComputer = distances::pq::HybridComputer; - type PruneAccessor<'a> = HybridAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - type WorkingSet = HybridMap; - - fn prune_accessor<'a>( - &'a self, - 
provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(HybridAccessor::new(provider)) - } - - fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { - HybridMap::with_capacity(capacity) - } -} - -impl InsertStrategy for FullPrecision { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InsertStrategy for Quantized { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InplaceDeleteStrategy for FullPrecision { - type DeleteElement<'a> = &'a [f32]; - type DeleteElementGuard = Box<[f32]>; - type DeleteElementError = Panics; - type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a>; - type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - *self - } - - fn search_post_processor(&self) -> Self::SearchPostProcessor { - postprocess::RemoveDeletedIdsAndCopy - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().into()) - } -} - -impl InplaceDeleteStrategy for Quantized { - type DeleteElement<'a> = &'a [f32]; - type DeleteElementGuard = Box<[f32]>; - type 
DeleteElementError = Panics; - type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = QuantAccessor<'a>; - type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - *self - } - - fn search_post_processor(&self) -> Self::SearchPostProcessor { - postprocess::RemoveDeletedIdsAndCopy - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().into()) - } -} - -impl glue::MultiInsertStrategy> for FullPrecision { - type Seed = map::Builder>; - type WorkingSet = map::Map, map::Ref<[f32]>>; - type FinishError = diskann::error::Infallible; - type InsertStrategy = Self; - - fn insert_strategy(&self) -> Self::InsertStrategy { - *self - } - - fn finish( - &self, - _provider: &DebugProvider, - _ctx: &DefaultContext, - batch: &Arc>, - ids: Itr, - ) -> impl std::future::Future> + Send - where - Itr: ExactSizeIterator + Send, - { - std::future::ready(Ok(map::Builder::new(map::Capacity::Default) - .with_overlay(map::Overlay::from_batch(batch.clone(), ids)))) - } -} - -impl glue::MultiInsertStrategy> for Quantized { - type Seed = map::Overlay>; - type WorkingSet = HybridMap; - type FinishError = diskann::error::Infallible; - type InsertStrategy = Self; - - fn insert_strategy(&self) -> Self::InsertStrategy { - *self - } - - fn finish( - &self, - _provider: &DebugProvider, - _ctx: &DefaultContext, - batch: &Arc>, - ids: Itr, - ) -> impl std::future::Future> + Send - where - Itr: ExactSizeIterator + Send, - { - std::future::ready(Ok(map::Overlay::from_batch(batch.clone(), ids))) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use diskann::{ - graph::{self, DiskANNIndex}, - provider::{Guard, SetElement}, - }; - use diskann_utils::views::Matrix; - 
use diskann_vector::{PureDistanceFunction, distance::SquaredL2}; - use rstest::rstest; - - use super::*; - use crate::{ - index::diskann_async::{ - tests::{ - GenerateGrid, PagedSearch, check_grid_search, populate_data, populate_graph, squish, - }, - train_pq, - }, - test_utils::groundtruth, - utils, - }; - - #[tokio::test] - async fn basic_operations() { - let dim = 2; - let ctx = &DefaultContext; - - let debug_config = DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let vectors = [vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; - let pq_table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - let provider = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - - provider - .set_element(ctx, &0, &[1.0, 1.0]) - .await - .unwrap() - .complete() - .await; - - // internal id = external id - assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); - assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); - - let mut accessor = FullAccessor::new(&provider); - - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 1); - - let mut neighbors = AdjacencyList::new(); - - let accessor = provider.default_accessor(); - let res = accessor.get_neighbors(0, &mut neighbors).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_reads.get(), 1); - - let accessor = provider.default_accessor(); - let res = accessor.set_neighbors(0, &[1, 2, 3]).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_writes.get(), 1); - - // delete and release vector 0 - let res = provider.delete(&DefaultContext, &0).await; - assert!(res.is_ok()); - assert_eq!( - ElementStatus::Deleted, - provider - .status_by_external_id(&DefaultContext, &0) - .await - .unwrap() - ); - - let mut accessor = 
FullAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 2); - - let mut accessor = HybridAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 3); - - // Releasing should make the element unreachable. - let res = provider.release(&DefaultContext, 0).await; - assert!(res.is_ok()); - assert!( - provider - .status_by_external_id(&DefaultContext, &0) - .await - .is_err() - ); - } - - pub fn new_quant_index( - index_config: graph::Config, - debug_config: DebugConfig, - pq_table: FixedChunkPQTable, - ) -> Arc> { - let data = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - Arc::new(DiskANNIndex::new(index_config, data, None)) - } - - #[rstest] - #[case(1, 100)] - #[case(3, 7)] - #[case(4, 5)] - #[tokio::test] - async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - let start_id = u32::MAX; - - let index_config = graph::config::Builder::new( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree, - metric: Metric::L2, - }; - - let adjacency_lists = match dim { - 1 => utils::generate_1d_grid_adj_list(grid_size as u32), - 3 => utils::genererate_3d_grid_adj_list(grid_size as u32), - 4 => utils::generate_4d_grid_adj_list(grid_size as u32), - _ => panic!("Unsupported number of dimensions"), - }; - let mut vectors = f32::generate_grid(dim, grid_size); - - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. 
- &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - let index = new_quant_index(index_config, debug_config, table); - { - let mut neighbor_accessor = index.provider().default_accessor(); - populate_data(index.provider(), &DefaultContext, &vectors).await; - populate_graph(&mut neighbor_accessor, &adjacency_lists).await; - - // Set the adjacency list for the start point. - neighbor_accessor - .set_neighbors(start_id, &[num_points as u32 - 1]) - .await - .unwrap(); - } - - // The corpus of actual vectors consists of all but the last point, which we use - // as the start point. - // - // So, when we compute the corpus used during groundtruth generation, we take all - // but this last point. - let corpus: diskann_utils::views::Matrix = - squish(vectors.iter().take(num_points), dim); - - let mut paged_tests = Vec::new(); - - // Test with the zero query. - let query = vec![0.0; dim]; - let gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query, gt)); - - // Test with the start point to ensure it is filtered out. - let query = vectors.last().unwrap(); - let gt = groundtruth(corpus.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query.clone(), gt)); - - // Unfortunately - this is needed for the `check_grid_search` test. - vectors.push(index.provider().config.start_point.full().to_owned()); - check_grid_search(&index, &vectors, &paged_tests, FullPrecision, Quantized).await; - } - - #[rstest] - #[tokio::test] - async fn grid_search_with_build( - #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize), - ) { - let dim = dim_and_size.0; - let grid_size = dim_and_size.1; - let start_id = u32::MAX; - - let l = 10; - - // NOTE: Be careful changing `max_degree`. 
It needs to be high enough that the - // graph is navigable, but low enough that the batch parallel handling inside - // `multi_insert` is needed for the multi-insert graph to be navigable. - // - // With the current configured values, removing the other elements in the batch - // from the visited set during `multi_insert` results in a graph failure. - let max_degree = 2 * dim; - - let num_points = (grid_size).pow(dim as u32); - - let index_config = graph::config::Builder::new_with( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - |b| { - b.max_minibatch_par(10); - }, - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree: index_config.max_degree().into(), - metric: Metric::L2, - }; - - let mut vectors = f32::generate_grid(dim, grid_size); - assert_eq!(vectors.len(), num_points); - - // This is a little subtle, but we need `vectors` to contain the start point as - // its last element, but we **don't** want to include it in the index build. - // - // This basically means that we need to be careful with index initialization. - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - // Initialize an index for a new round of building. - let init_index = - || new_quant_index(index_config.clone(), debug_config.clone(), table.clone()); - - // Build with full-precision single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(FullPrecision, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. 
- assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(Quantized, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with full-precision multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch = Arc::new(squish(vectors.iter().take(num_points), dim)); - let ids: Arc<[u32]> = (0..num_points as u32).collect(); - - index - .multi_insert::<_, Matrix>(FullPrecision, &ctx, batch, ids) - .await - .unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch = Arc::new(squish(vectors.iter().take(num_points), dim)); - let ids: Arc<[u32]> = (0..num_points as u32).collect(); - - index - .multi_insert::<_, Matrix>(Quantized, &ctx, batch, ids) - .await - .unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. 
- assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - } -} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..cf719e730 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -39,7 +39,3 @@ pub mod bf_tree; // Caching proxy provider to accelerate slow providers. #[cfg(feature = "bf_tree")] pub mod caching; - -// Debug provider for testing. -#[cfg(test)] -pub mod debug_provider; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index d2542dfdb..185f468ec 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -984,6 +984,11 @@ pub struct Accessor<'a> { } impl<'a> Accessor<'a> { + /// Return the underlying [`Provider`] reference. + pub fn provider(&self) -> &Provider { + self.provider + } + /// Creates an accessor with no flaky behavior (backward-compatible). pub fn new(provider: &'a Provider) -> Self { Self::new_inner(provider, None) @@ -1082,6 +1087,24 @@ impl glue::SearchExt for Accessor<'_> { impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} +impl provider::CacheableAccessor for Accessor<'_> { + type Map = diskann_utils::lifetime::Slice; + + fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] + where + Self: 'a, + { + element + } + + fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] + where + Self: 'a + 'b, + { + element + } +} + #[derive(Debug, Clone, Copy)] pub struct Strategy { // Set this flag to enable reuse within the [`workingset::Map`]. 
For multi-threaded @@ -1235,6 +1258,329 @@ impl glue::InplaceDeleteStrategy for Strategy { } } +// ============================================================================= +// DefaultContext Wrapper +// ============================================================================= + +/// A wrapper around [`Provider`] that implements [`DataProvider`] with +/// [`DefaultContext`] instead of the test-specific [`Context`]. +/// +/// This exists so the test provider can be used with infrastructure in +/// `diskann-providers` that requires `Context = DefaultContext` (e.g., +/// `check_grid_search`). +/// +/// All methods delegate directly to `Provider` — the context parameter is +/// ignored, just as it is in `Provider`'s own implementations. +pub struct DefaultContextProvider { + inner: Provider, +} + +impl DefaultContextProvider { + /// Create a new wrapper around the given [`Provider`]. + pub fn new(inner: Provider) -> Self { + Self { inner } + } + + /// Access the underlying [`Provider`]. 
+ pub fn provider(&self) -> &Provider { + &self.inner + } +} + +impl std::ops::Deref for DefaultContextProvider { + type Target = Provider; + fn deref(&self) -> &Provider { + &self.inner + } +} + +impl provider::DataProvider for DefaultContextProvider { + type Context = provider::DefaultContext; + type InternalId = u32; + type ExternalId = u32; + type Error = InvalidId; + type Guard = provider::NoopGuard; + + fn to_internal_id( + &self, + _context: &provider::DefaultContext, + gid: &u32, + ) -> Result { + let valid = self.inner.terms.contains_key(gid); + if valid { + Ok(*gid) + } else { + Err(InvalidId::External(*gid)) + } + } + + fn to_external_id( + &self, + _context: &provider::DefaultContext, + id: u32, + ) -> Result { + let valid = self.inner.terms.contains_key(&id); + if valid { + Ok(id) + } else { + Err(InvalidId::Internal(id)) + } + } +} + +impl provider::Delete for DefaultContextProvider { + async fn delete( + &self, + _context: &provider::DefaultContext, + gid: &u32, + ) -> Result<(), InvalidId> { + if self.inner.is_start_point(*gid) { + return Err(InvalidId::IsStartPoint(*gid)); + } + + match self.inner.terms.entry(*gid) { + Entry::Occupied(mut occupied) => { + occupied.get_mut().mark_deleted(); + Ok(()) + } + Entry::Vacant(_) => Err(InvalidId::External(*gid)), + } + } + + async fn release(&self, _context: &provider::DefaultContext, id: u32) -> Result<(), InvalidId> { + if self.inner.is_start_point(id) { + return Err(InvalidId::IsStartPoint(id)); + } + + if self.inner.terms.remove(&id).is_none() { + Err(InvalidId::Internal(id)) + } else { + Ok(()) + } + } + + async fn status_by_internal_id( + &self, + _context: &provider::DefaultContext, + id: u32, + ) -> Result { + if self.inner.is_deleted(id)? 
{ + Ok(provider::ElementStatus::Deleted) + } else { + Ok(provider::ElementStatus::Valid) + } + } + + fn status_by_external_id( + &self, + context: &provider::DefaultContext, + gid: &u32, + ) -> impl std::future::Future> + Send { + self.status_by_internal_id(context, *gid) + } +} + +impl provider::SetElement<&[f32]> for DefaultContextProvider { + type SetError = ANNError; + + async fn set_element( + &self, + _context: &provider::DefaultContext, + id: &u32, + element: &[f32], + ) -> Result, ANNError> { + let ctx = Context::new(); + self.inner.set_element(&ctx, id, element).await + } +} + +impl provider::DefaultAccessor for DefaultContextProvider { + type Accessor<'a> = NeighborAccessor<'a>; + + fn default_accessor(&self) -> Self::Accessor<'_> { + self.inner.default_accessor() + } +} + +/// Strategy adapter for [`DefaultContextProvider`]. +/// +/// Wraps the test [`Strategy`] to work with [`DefaultContextProvider`] instead of +/// [`Provider`]. +#[derive(Debug, Clone, Copy, Default)] +pub struct DefaultContextStrategy { + inner: Strategy, +} + +impl DefaultContextStrategy { + pub fn new() -> Self { + Self::default() + } +} + +impl glue::SearchStrategy for DefaultContextStrategy { + type QueryComputer = ::QueryDistance; + type SearchAccessor<'a> = Accessor<'a>; + type SearchAccessorError = Infallible; + + fn search_accessor<'a>( + &'a self, + provider: &'a DefaultContextProvider, + _context: &'a provider::DefaultContext, + ) -> Result, Self::SearchAccessorError> { + Ok(Accessor::new(&provider.inner)) + } +} + +impl glue::DefaultPostProcessor for DefaultContextStrategy { + default_post_processor!( + glue::Pipeline> + ); +} + +impl glue::PruneStrategy for DefaultContextStrategy { + type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; + type DistanceComputer = ::Distance; + type PruneAccessor<'a> = Accessor<'a>; + type PruneAccessorError = Infallible; + + fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { + 
self.inner.create_working_set(capacity) + } + + fn prune_accessor<'a>( + &'a self, + provider: &'a DefaultContextProvider, + _context: &'a provider::DefaultContext, + ) -> Result, Self::PruneAccessorError> { + Ok(Accessor::new(&provider.inner)) + } +} + +impl glue::InsertStrategy for DefaultContextStrategy { + type PruneStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DefaultContextProvider, + _context: &'a provider::DefaultContext, + ) -> Result, Self::SearchAccessorError> { + Ok(Accessor::new(&provider.inner)) + } +} + +impl glue::MultiInsertStrategy> for DefaultContextStrategy { + type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; + type Seed = workingset::map::Builder>; + type FinishError = Infallible; + type InsertStrategy = Self; + + fn insert_strategy(&self) -> Self::InsertStrategy { + *self + } + + fn finish( + &self, + _provider: &DefaultContextProvider, + _context: &provider::DefaultContext, + batch: &Arc>, + ids: Itr, + ) -> impl std::future::Future> + Send + where + Itr: ExactSizeIterator + Send, + { + use workingset::map::{Builder, Capacity, Overlay}; + + let capacity = if self.inner.working_set_reuse { + Capacity::Default + } else { + Capacity::None + }; + + std::future::ready(Ok( + Builder::new(capacity).with_overlay(Overlay::from_batch(batch.clone(), ids)) + )) + } +} + +/// A [`glue::SearchPostProcessStep`] that filters out deleted IDs from the candidate stream. +/// +/// This is used in [`DefaultContextStrategy`]'s `InplaceDeleteStrategy` implementation, +/// since `inplace_delete` relies on post-processing to exclude deleted nodes from results. 
+#[derive(Default)] +pub struct FilterDeletedIds; + +impl<'a, T> glue::SearchPostProcessStep, T> for FilterDeletedIds +where + T: Copy + Send + Sync, + Accessor<'a>: provider::BuildQueryComputer + provider::HasId, +{ + type Error + = NextError + where + NextError: crate::error::StandardError; + + type NextAccessor = Accessor<'a>; + + async fn post_process_step( + &self, + next: &Next, + accessor: &mut Accessor<'a>, + query: T, + computer: & as provider::BuildQueryComputer>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result> + where + I: Iterator as provider::HasId>::Id>> + Send, + B: crate::graph::SearchOutputBuffer< as provider::HasId>::Id> + Send + ?Sized, + Next: glue::SearchPostProcess, T> + Sync, + { + let filtered = candidates.filter(|n| !accessor.provider.is_deleted(n.id).unwrap_or(false)); + next.post_process(accessor, query, computer, filtered, output) + .await + } +} + +impl glue::InplaceDeleteStrategy for DefaultContextStrategy { + type DeleteElement<'a> = &'a [f32]; + type DeleteElementGuard = Box<[f32]>; + type DeleteElementError = AccessedInvalidId; + type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = Accessor<'a>; + type SearchStrategy = Self; + type SearchPostProcessor = glue::Pipeline; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn search_strategy(&self) -> Self::SearchStrategy { + *self + } + + fn search_post_processor(&self) -> Self::SearchPostProcessor { + glue::Pipeline::new(FilterDeletedIds, glue::CopyIds) + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a DefaultContextProvider, + _context: &'a provider::DefaultContext, + id: u32, + ) -> Result { + provider + .inner + .terms + .get(&id) + .map(|v| (*v.data).into()) + .ok_or(AccessedInvalidId(id)) + } +} + /////////// // Tests // /////////// @@ -1896,4 +2242,163 @@ mod tests { assert_message_contains!(msg, "cannot delete start point"); assert!(!provider.is_deleted(0).unwrap()); } + + // 
    // =========================================================================
    // DefaultContextProvider / DefaultContextStrategy tests
    // =========================================================================

    // Convenience: wraps the shared test provider in a DefaultContextProvider.
    fn create_default_context_provider() -> DefaultContextProvider {
        DefaultContextProvider::new(create_test_provider())
    }

    // Deref (and the explicit accessor) expose the inner provider's
    // properties unchanged.
    #[test]
    fn default_context_provider_deref() {
        let dcp = create_default_context_provider();
        // Deref should expose the inner Provider's properties.
        assert_eq!(dcp.dim(), 2);
        assert_eq!(dcp.max_degree(), 4);
        assert_eq!(dcp.distance_metric(), Metric::Cosine);
        // Explicit accessor method should agree.
        assert_eq!(dcp.provider().dim(), 2);
    }

    // Internal and external IDs round-trip as the identity for known IDs,
    // and unknown IDs map to the matching `InvalidId` variants.
    #[test]
    fn default_context_provider_id_conversion() {
        use provider::DataProvider;

        let dcp = create_default_context_provider();
        let ctx = provider::DefaultContext;

        for id in 0u32..4u32 {
            let internal = dcp.to_internal_id(&ctx, &id).unwrap();
            assert_eq!(internal, id);
            let external = dcp.to_external_id(&ctx, internal).unwrap();
            assert_eq!(external, id);
        }

        // Unknown IDs produce the expected error variants.
        let err = dcp.to_internal_id(&ctx, &99).unwrap_err();
        assert!(matches!(err, InvalidId::External(99)));

        let err = dcp.to_external_id(&ctx, 99).unwrap_err();
        assert!(matches!(err, InvalidId::Internal(99)));
    }

    // Delete/release lifecycle: deletion flips status to Deleted (visible via
    // both ID spaces), start points can be neither deleted nor released, and
    // released IDs become invalid (idempotently so).
    #[test]
    fn default_context_provider_delete_and_status() {
        use provider::{Delete, ElementStatus};

        let dcp = create_default_context_provider();
        let rt = current_thread_runtime();
        let ctx = provider::DefaultContext;

        // Node 2 starts as valid.
        let status = rt.block_on(dcp.status_by_internal_id(&ctx, 2)).unwrap();
        assert_eq!(status, ElementStatus::Valid);

        // Delete it.
        rt.block_on(dcp.delete(&ctx, &2)).unwrap();
        let status = rt.block_on(dcp.status_by_internal_id(&ctx, 2)).unwrap();
        assert_eq!(status, ElementStatus::Deleted);

        // status_by_external_id agrees.
        let status = rt.block_on(dcp.status_by_external_id(&ctx, &2)).unwrap();
        assert_eq!(status, ElementStatus::Deleted);

        // Deleting a start point fails.
        let err = rt.block_on(dcp.delete(&ctx, &0)).unwrap_err();
        assert!(matches!(err, InvalidId::IsStartPoint(0)));

        // Releasing a start point fails.
        let err = rt.block_on(dcp.release(&ctx, 0)).unwrap_err();
        assert!(matches!(err, InvalidId::IsStartPoint(0)));

        // Release a non-start-point node.
        rt.block_on(dcp.release(&ctx, 3)).unwrap();
        let err = rt.block_on(dcp.status_by_internal_id(&ctx, 3)).unwrap_err();
        assert!(matches!(err, InvalidId::Internal(3)));

        // Releasing an already-released node fails.
        let err = rt.block_on(dcp.release(&ctx, 3)).unwrap_err();
        assert!(matches!(err, InvalidId::Internal(3)));
    }

    // set_element stores a vector (guard must be completed) and the stored
    // data is readable back through the inner provider's accessor.
    #[test]
    fn default_context_provider_set_element() {
        use provider::{Accessor as _, Guard, SetElement};

        let dcp = create_default_context_provider();
        let rt = current_thread_runtime();
        let ctx = provider::DefaultContext;

        let v = vec![0.42f32, 0.58];
        rt.block_on(async {
            // The write only takes effect once the guard is completed.
            let guard = dcp.set_element(&ctx, &10, &v).await.unwrap();
            guard.complete().await;
        });

        // Verify via inner provider accessor.
        let mut accessor = super::Accessor::new(dcp.provider());
        let element = rt.block_on(accessor.get_element(10)).unwrap();
        assert_eq!(*element, v[..]);
    }

    // The default accessor exposes neighbor lists from the inner provider.
    #[test]
    fn default_context_provider_default_accessor() {
        use provider::{DefaultAccessor, NeighborAccessor};

        let dcp = create_default_context_provider();
        let na = dcp.default_accessor();
        let rt = current_thread_runtime();

        // Node 0 has neighbors [1, 2, 3] per create_test_provider.
        let mut neighbors = AdjacencyList::new();
        rt.block_on(na.get_neighbors(0, &mut neighbors)).unwrap();
        assert_eq!(neighbors.as_ref(), &[1, 2, 3]);
    }

    // NOTE(review): `size_of_val` of two values of the same type is always
    // equal, so this assertion is vacuous — it cannot distinguish
    // `default()` from `new()`. Consider deriving `PartialEq` on
    // `DefaultContextStrategy` (or comparing observable behavior) instead.
    #[test]
    fn default_context_strategy_default_impl() {
        let s = DefaultContextStrategy::default();
        let s2 = DefaultContextStrategy::new();
        // Both should be identical (inner Strategy is Default).
        assert_eq!(std::mem::size_of_val(&s), std::mem::size_of_val(&s2),);
    }

    // The search accessor produced by the strategy wraps the inner provider.
    #[test]
    fn default_context_strategy_search_accessor() {
        use glue::SearchStrategy;

        let dcp = create_default_context_provider();
        let strategy = DefaultContextStrategy::new();
        let ctx = provider::DefaultContext;

        let accessor = strategy.search_accessor(&dcp, &ctx).unwrap();
        // The accessor should reference the inner provider.
        assert_eq!(accessor.provider().dim(), 2);
    }

    // Accessor::provider() hands back the provider it was constructed over.
    #[test]
    fn accessor_provider_method() {
        let provider = create_test_provider();
        let accessor = Accessor::new(&provider);
        assert_eq!(accessor.provider().dim(), provider.dim());
        assert_eq!(
            accessor.provider().distance_metric(),
            provider.distance_metric()
        );
    }

    // as_cached / from_cached round-trip a slice without altering its
    // contents.
    // NOTE(review): the qualified paths below appear to have lost their
    // `<Accessor ...` prefix to extraction (they read `as CacheableAccessor>`)
    // — restore from the original source before compiling.
    #[test]
    fn cacheable_accessor_round_trip() {
        use provider::CacheableAccessor;

        let original: &[f32] = &[1.0, 2.0, 3.0];

        let cached = as CacheableAccessor>::as_cached(&original);
        assert_eq!(*cached, original);

        let restored = as CacheableAccessor>::from_cached(cached);
        assert_eq!(restored, original);
    }
}