diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 6cc2c9673..45dff153c 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -41,6 +41,18 @@ where strategy: Strategy, } +/// A [`KNN`] variant that uses explicit post-processing during search. +#[derive(Debug)] +pub struct KNNWithPostProcessor +where + DP: provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: Strategy

, +} + impl KNN where DP: provider::DataProvider, @@ -71,6 +83,39 @@ where } } +impl KNNWithPostProcessor +where + DP: provider::DataProvider, +{ + /// Construct a new [`KNNWithPostProcessor`] searcher. + /// + /// If `strategy` or `post_processor` is one of the container variants of [`Strategy`], + /// its length must match the number of rows in `queries`. If this is the case, then the + /// strategies/processors will have a querywise correspondence (see [`search::SearchResults`]) + /// with the query matrix. + /// + /// # Errors + /// + /// Returns an error if the number of elements in `strategy` or `post_processor` is not + /// compatible with the number of rows in `queries`. + pub fn new( + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: Strategy

, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + post_processor.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor, + })) + } +} + /// Additional metrics collected during [`KNN`] search. /// /// # Note @@ -132,6 +177,59 @@ where } } +impl Search for KNNWithPostProcessor +where + DP: provider::DataProvider, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, + T: AsyncFriendly + Clone, + P: for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + + Clone + + Send + + Sync + + AsyncFriendly, +{ + type Id = DP::ExternalId; + type Parameters = graph::search::Knn; + type Output = Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { + search::IdCount::Fixed(parameters.k_value()) + } + + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> ANNResult + where + O: graph::SearchOutputBuffer + Send, + { + let context = DP::Context::default(); + let knn_search = *parameters; + let stats = self + .index + .search_with( + knn_search, + self.strategy.get(index)?, + self.post_processor.get(index)?.clone(), + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(Metrics { + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + /// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs /// returned by the provided [`Aggregator`]. /// diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs index dd57dbf69..c5e19b0e3 100644 --- a/diskann-benchmark-runner/src/any.rs +++ b/diskann-benchmark-runner/src/any.rs @@ -397,7 +397,7 @@ mod tests { let _: Type = value.convert::().unwrap(); // An invalid match should return an error. - let value = Any::new(0usize, "random-rag"); + let value = Any::new(0usize, "random-determinant-diversity"); let err = value.convert::>().unwrap_err(); let msg = err.to_string(); assert!(msg.contains("invalid dispatch"), "{}", msg); diff --git a/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json b/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json new file mode 100644 index 000000000..e4c61c24f --- /dev/null +++ b/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json @@ -0,0 +1,52 @@ +{ + "search_directories": [ + "C:/data/openai" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/data/openai/openai_index_normal" + }, + "search_phase": { + "queries": "openai_query.bin", + "groundtruth": "openai_gt_50.bin", + "search_list": [100, 200, 400], + "beam_width": 4, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/data/openai/openai_index_normal" + }, + "search_phase": { + "queries": "openai_query.bin", + "groundtruth": "openai_gt_50.bin", + "search_list": [100, 200, 400], + "beam_width": 4, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "is_determinant_diversity_search": true, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 2.0 + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 65e5804a7..ea1f0f130 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -42,6 +42,9 @@ pub(super) struct DiskSearchStats { pub(super) beam_width: usize, pub(super) recall_at: u32, pub(crate) is_flat_search: bool, + pub(crate) is_determinant_diversity_search: bool, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, pub(crate) distance: SimilarityMeasure, pub(crate) uses_vector_filters: bool, pub(super) num_nodes_to_cache: Option, @@ -276,6 +279,9 @@ where Some(search_params.beam_width), vector_filter, search_params.is_flat_search, + search_params.is_determinant_diversity_search, + search_params.determinant_diversity_eta, + search_params.determinant_diversity_power, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; @@ -341,6 +347,9 @@ where beam_width: search_params.beam_width, recall_at: search_params.recall_at, is_flat_search: search_params.is_flat_search, + is_determinant_diversity_search: search_params.is_determinant_diversity_search, + determinant_diversity_eta: search_params.determinant_diversity_eta, + determinant_diversity_power: search_params.determinant_diversity_power, distance: search_params.distance, uses_vector_filters: search_params.vector_filters_file.is_some(), num_nodes_to_cache: search_params.num_nodes_to_cache, @@ -425,6 +434,22 @@ impl fmt::Display for DiskSearchStats { writeln!(f, "Beam width, : {}", self.beam_width)?; writeln!(f, "Recall at, : {}", self.recall_at)?; writeln!(f, "Flat search, : {}", self.is_flat_search)?; + writeln!( + f, + "Det-div search, : {}", + self.is_determinant_diversity_search + )?; + writeln!( + f, + "Det-div params, : {}", + match ( + self.determinant_diversity_eta, + self.determinant_diversity_power, + ) { + (Some(eta), Some(power)) => format!("eta={eta}, power={power}"), + _ => "None".to_string(), + } + )?; writeln!(f, "Distance, : {}", self.distance)?; writeln!(f, "Vector filters, : {}", self.uses_vector_filters)?; writeln!( diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index fa4a77078..68e0c96c9 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -24,7 +24,10 @@ use diskann_benchmark_runner::{ }; use diskann_providers::{ index::diskann_async, - model::{configuration::IndexConfiguration, graph::provider::async_::common}, + model::{ + configuration::IndexConfiguration, + graph::provider::async_::{common, DeterminantDiversitySearchParams}, + }, }; use diskann_utils::{ future::AsyncFriendly, @@ -350,6 +353,8 @@ where + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, + DeterminantDiversitySearchParams: + for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + Send + Sync, { match &input { SearchPhase::Topk(search_phase) => { @@ -366,19 +371,51 @@ where let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - let steps = search::knn::SearchSteps::new( search_phase.reps, &search_phase.num_threads, &search_phase.runs, ); - let search_results = search::knn::run(&knn, &groundtruth, steps)?; + let search_results = if let (Some(eta), Some(power)) = ( + search_phase.determinant_diversity_eta, + search_phase.determinant_diversity_power, + ) { + let processor = DeterminantDiversitySearchParams::new( + search_phase + .determinant_diversity_results_k + .unwrap_or_else(|| { + search_phase + .runs + .iter() + .map(|run| run.search_n) + .max() + .unwrap_or(1) + }), + eta, + power, + ) + .map_err(|err| { + anyhow::anyhow!("Invalid determinant-diversity parameters: {err}") + })?; + + let knn = benchmark_core::search::graph::knn::KNNWithPostProcessor::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + benchmark_core::search::graph::Strategy::broadcast(processor), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + } else { + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + }; result.append(AggregatedSearchResults::Topk(search_results)); Ok(result) } diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 915b8eca6..357e982c2 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -35,6 +35,32 @@ pub(crate) fn run( groundtruth: &dyn benchmark_core::recall::Rows, steps: SearchSteps<'_>, ) -> anyhow::Result> { + run_search(runner, groundtruth, steps, |setup, search_l, search_n| { + let search_params = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + core_search::Run::new(search_params, setup) + }) +} + +type Run = core_search::Run; +pub(crate) trait Knn { + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result>; +} + +fn run_search( + runner: &dyn Knn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + builder: F, +) -> anyhow::Result> +where + F: Fn(core_search::Setup, usize, usize) -> Run, +{ let mut all = Vec::new(); for threads in steps.num_tasks.iter() { @@ -48,12 +74,7 @@ pub(crate) fn run( let parameters: Vec<_> = run .search_l .iter() - .map(|search_l| { - let search_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - core_search::Run::new(search_params, setup.clone()) - }) + .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) .collect(); all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); @@ -63,17 +84,6 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; -pub(crate) trait Knn { - fn search_all( - &self, - parameters: Vec, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result>; -} - /////////// // Impls // /////////// @@ -129,3 +139,30 @@ where Ok(results.into_iter().map(SearchResults::new).collect()) } } + +impl Knn + for Arc> +where + DP: diskann::provider::DataProvider, + core_search::graph::knn::KNNWithPostProcessor: core_search::Search< + Id = DP::InternalId, + Parameters = diskann::graph::search::Knn, + Output = core_search::graph::knn::Metrics, + >, +{ + fn search_all( + &self, + parameters: Vec>, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result> { + let results = core_search::search_all( + self.clone(), + parameters.into_iter(), + core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), + )?; + + Ok(results.into_iter().map(SearchResults::new).collect()) + } +} diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 82bb37dae..4b3e68ba1 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -66,7 +66,9 @@ mod imp { }; use diskann_providers::{ index::diskann_async::{self}, - model::graph::provider::async_::{common::NoDeletes, inmem}, + model::graph::provider::async_::{ + common::NoDeletes, inmem, DeterminantDiversitySearchParams, + }, }; use diskann_quantization::alloc::GlobalAllocator; use diskann_utils::views::Matrix; @@ -331,15 +333,44 @@ mod imp { ); for &layout in self.input.query_layouts.iter() { - let knn = benchmark_core::search::graph::KNN::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::broadcast( - inmem::spherical::Quantized::search(layout.into()), - ), - )?; - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; + let strategy = inmem::spherical::Quantized::search(layout.into()); + let search_results = if let (Some(eta), Some(power)) = ( + search_phase.determinant_diversity_eta, + search_phase.determinant_diversity_power, + ) { + let processor = DeterminantDiversitySearchParams::new( + search_phase + .determinant_diversity_results_k + .unwrap_or_else(|| { + search_phase + .runs + .iter() + .map(|run| run.search_n) + .max() + .unwrap_or(1) + }), + eta, + power, + ) + .map_err(|err| anyhow::anyhow!("Invalid determinant-diversity parameters: {err}"))?; + + let knn = benchmark_core::search::graph::knn::KNNWithPostProcessor::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast(strategy), + benchmark_core::search::graph::Strategy::broadcast(processor), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + } else { + let knn = benchmark_core::search::graph::KNN::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast(strategy), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + }; result.append(SearchRun { layout, results: AggregatedSearchResults::Topk(search_results), diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 19230977d..b160cf4b8 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -123,6 +123,9 @@ pub(crate) struct TopkSearchPhase { pub(crate) queries: InputFile, pub(crate) groundtruth: InputFile, pub(crate) reps: NonZeroUsize, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, + pub(crate) determinant_diversity_results_k: Option, // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, @@ -139,6 +142,44 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } + if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { + return Err(anyhow!( + "determinant_diversity_eta and determinant_diversity_power must either both be set or both be omitted" + )); + } + + if let Some(eta) = self.determinant_diversity_eta { + if !eta.is_finite() || eta < 0.0 { + return Err(anyhow!( + "determinant_diversity_eta must be finite and >= 0.0, got {}", + eta + )); + } + } + + if let Some(power) = self.determinant_diversity_power { + if !power.is_finite() || power < 0.0 { + return Err(anyhow!( + "determinant_diversity_power must be finite and >= 0.0, got {}", + power + )); + } + } + + if let Some(k) = self.determinant_diversity_results_k { + if k == 0 { + return Err(anyhow!("determinant_diversity_results_k must be > 0")); + } + + if self.determinant_diversity_eta.is_none() + || self.determinant_diversity_power.is_none() + { + return Err(anyhow!( + "determinant_diversity_results_k requires determinant_diversity_eta and determinant_diversity_power to both be set" + )); + } + } + Ok(()) } } @@ -164,6 +205,9 @@ impl Example for TopkSearchPhase { queries: InputFile::new("path/to/queries"), groundtruth: InputFile::new("path/to/groundtruth"), reps: REPS, + determinant_diversity_eta: None, + determinant_diversity_power: None, + determinant_diversity_results_k: None, num_threads: THREAD_COUNTS.to_vec(), runs, } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..7121e098b 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -81,6 +81,12 @@ pub(crate) struct DiskSearchPhase { pub(crate) search_list: Vec, pub(crate) recall_at: u32, pub(crate) is_flat_search: bool, + #[serde(default)] + pub(crate) is_determinant_diversity_search: bool, + #[serde(default)] + pub(crate) determinant_diversity_eta: Option, + #[serde(default)] + pub(crate) determinant_diversity_power: Option, pub(crate) distance: SimilarityMeasure, pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, @@ -224,6 +230,35 @@ impl CheckDeserialization for DiskSearchPhase { if self.num_threads == 0 { anyhow::bail!("num_threads must be positive"); } + + if self.is_determinant_diversity_search { + if self.is_flat_search { + anyhow::bail!( + "is_determinant_diversity_search is not supported when is_flat_search is true" + ); + } + + let eta = self.determinant_diversity_eta.unwrap_or(0.01); + let power = self.determinant_diversity_power.unwrap_or(2.0); + + if eta < 0.0 || !eta.is_finite() { + anyhow::bail!("determinant_diversity_eta must be >= 0.0 and finite, got {eta}"); + } + + if power < 0.0 || !power.is_finite() { + anyhow::bail!("determinant_diversity_power must be >= 0.0 and finite, got {power}"); + } + + self.determinant_diversity_eta = Some(eta); + self.determinant_diversity_power = Some(power); + } else if self.determinant_diversity_eta.is_some() + || self.determinant_diversity_power.is_some() + { + anyhow::bail!( + "determinant_diversity_eta/determinant_diversity_power may only be set when is_determinant_diversity_search is true" + ); + } + if let Some(n) = self.num_nodes_to_cache { if n == 0 { anyhow::bail!("num_nodes_to_cache must be positive if specified"); @@ -268,6 +303,9 @@ impl Example for DiskIndexOperation { recall_at: 10, num_threads: 8, is_flat_search: false, + is_determinant_diversity_search: false, + determinant_diversity_eta: None, + determinant_diversity_power: None, distance: SimilarityMeasure::SquaredL2, vector_filters_file: None, num_nodes_to_cache: None, @@ -384,6 +422,22 @@ impl DiskSearchPhase { write_field!(f, "Recall@", self.recall_at)?; write_field!(f, "Threads", self.num_threads)?; write_field!(f, "Flat Search", self.is_flat_search)?; + write_field!( + f, + "Determinant Diversity Search", + self.is_determinant_diversity_search + )?; + match ( + self.determinant_diversity_eta, + self.determinant_diversity_power, + ) { + (Some(eta), Some(power)) => write_field!( + f, + "Determinant Diversity Params", + format!("eta={eta}, power={power}") + )?, + _ => write_field!(f, "Determinant Diversity Params", "none")?, + } write_field!(f, "Distance", self.distance)?; match &self.vector_filters_file { Some(vf) => write_field!(f, "Vector Filters File", vf.display())?, diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index c7f21b682..79075592a 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1103,6 +1103,9 @@ pub(crate) mod disk_index_builder_tests { &mut associated_data, &|_| true, false, + false, + None, + None, ); diskann_providers::test_utils::assert_top_k_exactly_match( diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index c0b16beba..78cc1fec1 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -39,8 +39,10 @@ use diskann::{ use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ model::{ - compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType, - pq::quantizer_preprocess, PQData, PQScratch, + compute_pq_distance, compute_pq_distance_for_pq_coordinates, + graph::{provider::async_::determinant_diversity_post_process, traits::GraphDataType}, + pq::quantizer_preprocess, + PQData, PQScratch, }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; @@ -279,6 +281,30 @@ impl<'a> RerankAndFilter<'a> { } } +#[derive(Clone, Copy)] +pub struct DeterminantDiversityRerankAndFilter<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + top_k: usize, + eta: f64, + power: f64, +} + +impl<'a> DeterminantDiversityRerankAndFilter<'a> { + fn new( + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + top_k: usize, + eta: f64, + power: f64, + ) -> Self { + Self { + filter, + top_k, + eta, + power, + } + } +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -340,6 +366,83 @@ where } } +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + [Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DeterminantDiversityRerankAndFilter<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + _computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; + + let candidate_ids: Vec = candidates + .map(|candidate| candidate.id) + .filter(|id| (self.filter)(id)) + .collect(); + + if candidate_ids.is_empty() { + return Ok(0); + } + + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &candidate_ids)?; + + let mut candidate_vectors = Vec::with_capacity(candidate_ids.len()); + let mut associated_data = HashMap::with_capacity(candidate_ids.len()); + + for id in candidate_ids { + let vector = accessor.scratch.vertex_provider.get_vector(&id)?; + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); + let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; + let data = accessor.scratch.vertex_provider.get_associated_data(&id)?; + + candidate_vectors.push((id, distance, vector_f32.to_vec())); + associated_data.insert(id, *data); + } + + let reranked = determinant_diversity_post_process( + candidate_vectors, + &query_f32, + self.top_k, + self.eta, + self.power, + ); + + Ok( + output.extend(reranked.into_iter().filter_map(|(id, distance)| { + associated_data + .get(&id) + .copied() + .map(|data| ((id, data), distance)) + })), + ) + } +} + impl<'this, Data, ProviderFactory> SearchStrategy, [Data::VectorDataType]> for DiskSearchStrategy<'this, Data, ProviderFactory> where @@ -925,6 +1028,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -933,6 +1037,9 @@ where beam_width: Option, vector_filter: Option>, is_flat_search: bool, + is_determinant_diversity_search: bool, + determinant_diversity_eta: Option, + determinant_diversity_power: Option, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -951,6 +1058,9 @@ where &mut associated_data, &vector_filter.unwrap_or(default_vector_filter::()), is_flat_search, + is_determinant_diversity_search, + determinant_diversity_eta, + determinant_diversity_power, )?; let mut search_result = SearchResult { @@ -988,6 +1098,9 @@ where associated_data: &mut [Data::AssociatedDataType], vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, + is_determinant_diversity_search: bool, + determinant_diversity_eta: Option, + determinant_diversity_power: Option, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -1010,13 +1123,42 @@ where ))? } else { let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search( - knn_search, - &strategy, - &DefaultContext, - strategy.query, - &mut result_output_buffer, - ))? + if is_determinant_diversity_search { + let eta = determinant_diversity_eta.unwrap_or(0.01); + let power = determinant_diversity_power.unwrap_or(2.0); + + if !eta.is_finite() || eta < 0.0 { + return Err(ANNError::log_index_error(format!( + "determinant_diversity_eta must be finite and >= 0.0, got {eta}" + ))); + } + + if !power.is_finite() || power < 0.0 { + return Err(ANNError::log_index_error(format!( + "determinant_diversity_power must be finite and >= 0.0, got {power}" + ))); + } + + let processor = + DeterminantDiversityRerankAndFilter::new(vector_filter, k, eta, power); + + self.runtime.block_on(self.index.search_with( + knn_search, + &strategy, + processor, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? + } else { + self.runtime.block_on(self.index.search( + knn_search, + &strategy, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? + } }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; @@ -1438,6 +1580,9 @@ mod disk_provider_tests { &mut associated_data, &(|_| true), false, + false, + None, + None, ); // Calculate the range of the truth_result for this query @@ -1493,7 +1638,17 @@ mod disk_provider_tests { let query = &aligned_box.as_slice()[1..]; let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + false, + false, + None, + None, + ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1605,6 +1760,9 @@ mod disk_provider_tests { &mut associated_data, &|_| true, false, + false, + None, + None, ); assert!(result.is_err()); @@ -1674,6 +1832,9 @@ mod disk_provider_tests { Some(4), None, false, + false, + None, + None, ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -2013,6 +2174,9 @@ mod disk_provider_tests { &mut associated_data, &vector_filter, is_flat_search, + false, + None, + None, ); assert!(result.is_ok(), "Expected search to succeed"); @@ -2034,6 +2198,9 @@ mod disk_provider_tests { None, // beam_width Some(Box::new(vector_filter)), is_flat_search, + false, + None, + None, ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs new file mode 100644 index 000000000..066f6880d --- /dev/null +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -0,0 +1,515 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Determinant-diversity search post-processing. + +use std::future::Future; + +use diskann::{ + ANNError, + graph::{SearchOutputBuffer, glue}, + neighbor::Neighbor, + provider::BuildQueryComputer, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_vector::{ + DistanceFunction, MathematicalValue, PureDistanceFunction, distance::InnerProduct, +}; + +use super::{ + inmem::GetFullPrecision, + postprocess::{AsDeletionCheck, DeletionCheck}, +}; + +#[derive(Debug)] +pub enum DeterminantDiversityError { + InvalidTopK { top_k: usize }, + InvalidEta { eta: f64 }, + InvalidPower { power: f64 }, +} + +impl std::fmt::Display for DeterminantDiversityError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidTopK { top_k } => write!(f, "top_k must be > 0, got {top_k}"), + Self::InvalidEta { eta } => write!(f, "eta must be >= 0.0, got {eta}"), + Self::InvalidPower { power } => write!(f, "power must be >= 0.0, got {power}"), + } + } +} + +impl std::error::Error for DeterminantDiversityError {} + +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversitySearchParams { + pub top_k: usize, + pub determinant_diversity_eta: f64, + pub determinant_diversity_power: f64, +} + +impl DeterminantDiversitySearchParams { + pub fn new( + top_k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, + ) -> Result { + if top_k == 0 { + return Err(DeterminantDiversityError::InvalidTopK { top_k }); + } + + if determinant_diversity_eta < 0.0 || !determinant_diversity_eta.is_finite() { + return Err(DeterminantDiversityError::InvalidEta { + eta: determinant_diversity_eta, + }); + } + + if determinant_diversity_power < 0.0 || !determinant_diversity_power.is_finite() { + return Err(DeterminantDiversityError::InvalidPower { + power: determinant_diversity_power, + }); + } + + Ok(Self { + top_k, + determinant_diversity_eta, + determinant_diversity_power, + }) + } +} + +impl glue::SearchPostProcess for DeterminantDiversitySearchParams +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = ANNError; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let result = (|| { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let distance = full.distance(); + + let mut candidates_with_vectors = Vec::new(); + for candidate in candidates { + if checker.deletion_check(candidate.id) { + continue; + } + + let vector = unsafe { full.get_vector_sync(candidate.id.into_usize()) }; + let vector_f32 = T::as_f32(vector).map_err(Into::into)?; + let full_precision_distance = distance.evaluate_similarity(query, vector); + candidates_with_vectors.push(( + candidate.id, + full_precision_distance, + vector_f32.to_vec(), + )); + } + + let query_f32 = T::as_f32(query).map_err(Into::into)?; + + let reranked = determinant_diversity_post_process( + candidates_with_vectors, + &query_f32[..], + self.top_k, + self.determinant_diversity_eta, + self.determinant_diversity_power, + ); + + Ok(output.extend(reranked)) + })(); + + std::future::ready(result) + } +} + +pub fn determinant_diversity_post_process( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let candidates: Vec<_> = candidates + .into_iter() + .filter(|(_, _, vector)| vector.len() == query.len()) + .collect(); + + if candidates.is_empty() { + return Vec::new(); + } + + let k = k.min(candidates.len()); + if k == 0 { + return Vec::new(); + } + + if candidates[0].2.is_empty() { + return Vec::new(); + } + + if determinant_diversity_eta > 0.0 { + post_process_with_eta_f32( + candidates, + query, + k, + determinant_diversity_eta, + determinant_diversity_power, + ) + } else { + post_process_greedy_orthogonalization_f32(candidates, query, k, determinant_diversity_power) + } +} + +fn post_process_with_eta_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let eta = determinant_diversity_eta as f32; + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + if candidates[0].2.is_empty() { + return Vec::new(); + } + + let inv_sqrt_eta = 1.0 / eta.sqrt(); + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, similarity_to_query, v) in &candidates { + let similarity = *similarity_to_query; + let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(selected_index) = best_idx else { + break; + }; + + selected.push(selected_index); + available[selected_index] = false; + + if selected.len() == k { + break; + } + + let best_norm_sq = norms_sq[selected_index]; + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq = 1.0 / best_norm_sq; + let r_star_copy = residuals[selected_index].clone(); + + for i in 0..n { + if !available[i] { + projections[i] = 0.0; + } else { + let projection = dot_product(&residuals[i], &r_star_copy) * inv_norm_sq; + projections[i] = projection; + } + } + + for i in 0..n { + if !available[i] { + continue; + } + + let projection = projections[i]; + for (residual, &star) in residuals[i].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[i] = (norms_sq[i] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +fn post_process_greedy_orthogonalization_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, similarity_to_query, v) in &candidates { + let similarity = *similarity_to_query; + let scale = similarity.max(0.0).powf(power as f32); + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + let best = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let Some((best_index, _)) = best else { + break; + }; + + let best_norm_sq = norms_sq[best_index]; + selected.push(best_index); + available[best_index] = false; + + if selected.len() == k { + break; + } + + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq_star = 1.0 / best_norm_sq; + let r_star_copy = residuals[best_index].clone(); + + for j in 0..n { + if !available[j] { + projections[j] = 0.0; + } else { + let projection = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; + projections[j] = projection; + } + } + + for j in 0..n { + if !available[j] { + continue; + } + + let projection = projections[j]; + for (residual, &star) in residuals[j].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[j] = (norms_sq[j] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validation_valid_params() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, 2.0); + assert!(result.is_ok()); + + let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); + assert!(result.is_ok()); + } + + #[test] + fn test_validation_invalid_params() { + let test_cases = [ + ( + DeterminantDiversitySearchParams::new(0, 0.01, 2.0), + DeterminantDiversityError::InvalidTopK { top_k: 0 }, + ), + ( + DeterminantDiversitySearchParams::new(10, -0.01, 2.0), + DeterminantDiversityError::InvalidEta { eta: -0.01 }, + ), + ( + DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0), + DeterminantDiversityError::InvalidEta { eta: f64::NAN }, + ), + ( + DeterminantDiversitySearchParams::new(10, 0.01, -1.0), + DeterminantDiversityError::InvalidPower { power: -1.0 }, + ), + ( + DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY), + DeterminantDiversityError::InvalidPower { + power: f64::INFINITY, + }, + ), + ]; + + for (result, expected) in test_cases { + match (result, expected) { + ( + Err(DeterminantDiversityError::InvalidTopK { top_k: actual }), + DeterminantDiversityError::InvalidTopK { top_k: expected }, + ) => assert_eq!(actual, expected), + ( + Err(DeterminantDiversityError::InvalidEta { eta: actual }), + DeterminantDiversityError::InvalidEta { eta: expected }, + ) => { + if expected.is_nan() { + assert!(actual.is_nan()); + } else { + assert_eq!(actual, expected); + } + } + ( + Err(DeterminantDiversityError::InvalidPower { power: actual }), + DeterminantDiversityError::InvalidPower { power: expected }, + ) => { + if expected.is_infinite() { + assert!(actual.is_infinite()); + } else { + assert_eq!(actual, expected); + } + } + (other, expected) => { + panic!("Unexpected result {:?} for expected {:?}", other, expected) + } + } + } + } + + #[test] + fn test_determinant_diversity_post_process_with_eta() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let v3 = vec![0.0f32, 0.0, 1.0]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2), (3u32, 0.7f32, v3)]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert_eq!(result.len(), 3); + } + + #[test] + fn test_determinant_diversity_post_process_enabled_greedy() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.99f32, 0.1, 0.0]; + let v3 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2), (3u32, 0.4f32, v3)]; + let query = vec![1.0, 1.0, 0.0]; + + let result = determinant_diversity_post_process(candidates, &query, 2, 0.0, 1.0); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_determinant_diversity_post_process_empty() { + let candidates: Vec<(u32, f32, Vec)> = vec![]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert!(result.is_empty()); + } + + #[test] + fn test_determinant_diversity_post_process_k_larger_than_candidates() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2)]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_determinant_diversity_post_process_dimension_mismatch_is_skipped() { + let bad = vec![1.0f32, 0.0]; + let good = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, bad), (2u32, 0.3f32, good)]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, 2); + } +} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..774c5530d 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -9,6 +9,11 @@ pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; pub(crate) mod postprocess; +mod determinant_diversity_post_process; +pub use determinant_diversity_post_process::{ + DeterminantDiversityError, DeterminantDiversitySearchParams, determinant_diversity_post_process, +}; + pub mod distances; pub mod memory_vector_provider; diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 65bb8ffb5..d2ba92151 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -256,6 +256,9 @@ where Some(parameters.beam_width as usize), Some(vector_filter_function), parameters.is_flat_search, + false, + None, + None, ); match result { diff --git a/tmp/wiki_compare_determinant_diversity.json b/tmp/wiki_compare_determinant_diversity.json new file mode 100644 index 000000000..c2c4972c3 --- /dev/null +++ b/tmp/wiki_compare_determinant_diversity.json @@ -0,0 +1,63 @@ +{ + "search_directories": [ + "C:\\wikipedia_dataset" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:\\wikipedia_dataset\\query.bin", + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "reps": 1, + "determinant_diversity_eta": null, + "determinant_diversity_power": null, + "determinant_diversity_results_k": null, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + }, + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:\\wikipedia_dataset\\query.bin", + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "reps": 1, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 2.0, + "determinant_diversity_results_k": 10, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + } + ] +}