diff --git a/diskann-benchmark/example/async-multihop-high-selectivity-small.json b/diskann-benchmark/example/async-multihop-high-selectivity-small.json new file mode 100644 index 000000000..402829728 --- /dev/null +++ b/diskann-benchmark/example/async-multihop-high-selectivity-small.json @@ -0,0 +1,52 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "num_start_points": 1, + "num_insert_attempts": 1, + "saturate_inserts": false, + "start_point_strategy": "medoid" + }, + "search_phase": { + "search-type": "topk-high-selectivity-multihop-filter", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "gt_small_filter.bin", + "query_predicates": "query.10.label.jsonl", + "data_labels": "data.256.label.jsonl", + "reps": 5, + "num_threads": [ + 1 + ], + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40, + 50, + 100, + 200 + ], + "recall_k": 10 + } + ] + } + } + } + ] + } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index d38332ee1..3de782dd7 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -388,6 +388,44 @@ where .collect(), )?; + let search_results = search::knn::run(&multihop, &groundtruth, steps)?; + result.append(AggregatedSearchResults::Topk(search_results)); + Ok(result) + } + SearchPhase::TopkHighSelectivityMultihopFilter(search_phase) => { + // Handle MultiHop Topk search phase with high-selectivity optimization + // This uses RejectAndNeedExpand to enable exploration queue + let mut result = BuildResult::new_topk(build_stats); + + // Save construction stats before 
running queries.
+            checkpoint.checkpoint(&result)?;
+
+            let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(
+                &search_phase.queries,
+            ))?);
+
+            let groundtruth =
+                datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?;
+
+            let steps = search::knn::SearchSteps::new(
+                search_phase.reps,
+                &search_phase.num_threads,
+                &search_phase.runs,
+            );
+
+            let bit_maps =
+                generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?;
+
+            let multihop = benchmark_core::search::graph::MultiHop::new(
+                index,
+                queries,
+                benchmark_core::search::graph::Strategy::broadcast(search_strategy),
+                bit_maps
+                    .into_iter()
+                    .map(utils::filters::as_high_selectivity_query_label_provider)
+                    .collect(),
+            )?;
+
+            let search_results = search::knn::run(&multihop, &groundtruth, steps)?;
+            result.append(AggregatedSearchResults::Topk(search_results));
+            Ok(result)
+        }
diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs
index 33cb2e2fe..fede4731b 100644
--- a/diskann-benchmark/src/backend/index/spherical.rs
+++ b/diskann-benchmark/src/backend/index/spherical.rs
@@ -460,6 +460,56 @@ mod imp {
                 writeln!(output, "\n\n{}", result)?;
                 Ok(result)
             }
+            SearchPhase::TopkHighSelectivityMultihopFilter(search_phase) => {
+                // Handle MultiHop Topk search with high-selectivity optimization
+
+                // Save construction stats before running queries.
+ _checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset( + datafiles::BinFile(&search_phase.queries), + )?); + + let groundtruth = datafiles::load_groundtruth(datafiles::BinFile( + &search_phase.groundtruth, + ))?; + + let steps = search::knn::SearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let bit_maps = generate_bitmaps( + &search_phase.query_predicates, + &search_phase.data_labels, + )?; + + let bit_map_filters: Arc<[_]> = bit_maps + .into_iter() + .map(utils::filters::as_high_selectivity_query_label_provider) + .collect(); + + for &layout in self.input.query_layouts.iter() { + let multihop = benchmark_core::search::graph::MultiHop::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast( + inmem::spherical::Quantized::search(layout.into()), + ), + bit_map_filters.clone(), + )?; + + let search_results = + search::knn::run(&multihop, &groundtruth, steps)?; + result.append(SearchRun { + layout, + results: AggregatedSearchResults::Topk(search_results), + }); + } + writeln!(output, "\n\n{}", result)?; + Ok(result) + } } } } diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index c76fdb594..e4374e689 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -263,6 +263,43 @@ impl CheckDeserialization for MultiHopSearchPhase { } } +/// Multi-hop search phase with high-selectivity optimization enabled. +/// +/// This search type uses `RejectAndNeedExpand` for rejected nodes, enabling +/// the exploration queue mechanism. This is beneficial when the filter has +/// high selectivity (few matching vectors), as it allows the search to +/// continue exploring through non-matching nodes even when the primary +/// queue is exhausted. 
+#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct HighSelectivityMultiHopSearchPhase { + pub(crate) queries: InputFile, + pub(crate) query_predicates: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) reps: NonZeroUsize, + pub(crate) data_labels: InputFile, + // Enable sweeping threads + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, +} + +impl CheckDeserialization for HighSelectivityMultiHopSearchPhase { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + // Check the validity of the input files. + self.queries.check_deserialization(checker)?; + + self.query_predicates.check_deserialization(checker)?; + self.data_labels.check_deserialization(checker)?; + + self.groundtruth.check_deserialization(checker)?; + for (i, run) in self.runs.iter_mut().enumerate() { + run.check_deserialization(checker) + .with_context(|| format!("search run {}", i))?; + } + + Ok(()) + } +} + /// A one-to-one correspondence with [`diskann::index::config::IntraBatchCandidates`]. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] @@ -333,6 +370,8 @@ pub(crate) enum SearchPhase { Range(RangeSearchPhase), TopkBetaFilter(BetaSearchPhase), TopkMultihopFilter(MultiHopSearchPhase), + /// Multi-hop search with high-selectivity optimization (exploration queue enabled). 
+ TopkHighSelectivityMultihopFilter(HighSelectivityMultiHopSearchPhase), } impl CheckDeserialization for SearchPhase { @@ -342,6 +381,9 @@ impl CheckDeserialization for SearchPhase { SearchPhase::Range(phase) => phase.check_deserialization(checker), SearchPhase::TopkBetaFilter(phase) => phase.check_deserialization(checker), SearchPhase::TopkMultihopFilter(phase) => phase.check_deserialization(checker), + SearchPhase::TopkHighSelectivityMultihopFilter(phase) => { + phase.check_deserialization(checker) + } } } } diff --git a/diskann-benchmark/src/utils/filters.rs b/diskann-benchmark/src/utils/filters.rs index 43d528c18..72b1bc132 100644 --- a/diskann-benchmark/src/utils/filters.rs +++ b/diskann-benchmark/src/utils/filters.rs @@ -5,8 +5,14 @@ use bit_set::BitSet; use std::fmt::Debug; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; -use diskann::{graph::index::QueryLabelProvider, utils::VectorId}; +use diskann::{ + graph::index::{QueryLabelProvider, QueryVisitDecision}, + neighbor::Neighbor, + utils::VectorId, +}; use diskann_benchmark_runner::files::InputFile; use diskann_label_filter::{ kv_index::GenericIndex, @@ -81,6 +87,261 @@ where } } +/// A bitmap filter wrapper that implements the full 4-layer high selectivity handling: +/// +/// **Layer 1: Exploration Queue** - Uses `RejectAndNeedExpand` to enable continued +/// graph traversal through non-matching nodes when the primary queue is exhausted. +/// +/// **Layer 2: Match Rate Detection** - Automatically detects high selectivity after +/// 30 samples. If match rate < 2%, enables exploration mode. +/// +/// **Layer 3: Checkpoint-Based Timeout** - Checks timeout every 1000 visits instead +/// of every visit to reduce syscall overhead by ~99%. +/// +/// **Layer 4: Two-Tier Early Stop** - Soft timeout (default 10ms) triggers when +/// enough matches found; hard timeout (default 100ms) is unconditional. 
+pub struct HighSelectivityBitmapFilter { + bitmap: BitSet, + /// Number of nodes visited + node_visited: AtomicU64, + /// Number of nodes matched + node_matched: AtomicU64, + /// Cached mode: 0 = undetermined, 1 = high match rate, 2 = low match rate + need_expand_mode: AtomicU64, + /// Next checkpoint for timeout/mode checking (avoids checking every visit) + next_timeout_check: AtomicU64, + /// Start time for timeout calculations + start_instant: Instant, + /// Soft early stop threshold (default 10ms) + soft_early_stop: Duration, + /// Hard early stop threshold (default 100ms) + hard_early_stop: Duration, + /// Minimum matched count required for soft early stop + min_matched_count: u64, + /// Flag indicating if early stop was triggered + early_stopped: AtomicBool, +} + +impl HighSelectivityBitmapFilter { + /// Minimum samples needed before we can estimate match rate reliably + const MIN_SAMPLES_FOR_ESTIMATION: u64 = 30; + + /// Match rate threshold for enabling exploration queue (2%) + const LOW_MATCH_RATE_THRESHOLD: f64 = 0.02; + + /// Check interval for match rate calculation and timeout enforcement + const MATCH_RATE_CHECK_INTERVAL: u64 = 1000; + + /// Default soft early stop threshold in milliseconds + const DEFAULT_SOFT_EARLY_STOP_MS: u64 = 10; + + /// Default hard early stop threshold in milliseconds + const DEFAULT_HARD_EARLY_STOP_MS: u64 = 100; + + /// Default minimum matched count for soft early stop + const DEFAULT_MIN_MATCHED_COUNT: u64 = 10; + + /// Create a new filter with default timeout settings. + pub fn new(bitmap: BitSet) -> Self { + Self::with_config( + bitmap, + Duration::from_millis(Self::DEFAULT_SOFT_EARLY_STOP_MS), + Duration::from_millis(Self::DEFAULT_HARD_EARLY_STOP_MS), + Self::DEFAULT_MIN_MATCHED_COUNT, + ) + } + + /// Create a new filter with custom timeout configuration. 
+ /// + /// # Arguments + /// * `bitmap` - The bitmap for filtering + /// * `soft_early_stop` - Soft timeout (triggers when elapsed > soft AND matched >= min_matched) + /// * `hard_early_stop` - Hard timeout (unconditional termination) + /// * `min_matched_count` - Minimum matches required for soft early stop + pub fn with_config( + bitmap: BitSet, + soft_early_stop: Duration, + hard_early_stop: Duration, + min_matched_count: u64, + ) -> Self { + // Clamp: soft_early_stop should not exceed hard_early_stop + let soft_early_stop = soft_early_stop.min(hard_early_stop); + + Self { + bitmap, + node_visited: AtomicU64::new(0), + node_matched: AtomicU64::new(0), + need_expand_mode: AtomicU64::new(0), + next_timeout_check: AtomicU64::new(Self::MATCH_RATE_CHECK_INTERVAL), + start_instant: Instant::now(), + soft_early_stop, + hard_early_stop, + min_matched_count, + early_stopped: AtomicBool::new(false), + } + } + + /// Check if we need expansion mode based on current match rate. + /// Returns: 0 = not determined, 1 = high match rate (no expand), 2 = low match rate (need expand) + fn check_need_expand_mode(&self, visited: u64, matched: u64) -> u64 { + // Already determined + let cached = self.need_expand_mode.load(Ordering::Relaxed); + if cached != 0 { + return cached; + } + + // Need enough samples + if visited < Self::MIN_SAMPLES_FOR_ESTIMATION { + return 0; + } + + let match_rate = if visited > 0 { + matched as f64 / visited as f64 + } else { + 0.0 + }; + + let mode = if match_rate < Self::LOW_MATCH_RATE_THRESHOLD { + 2 // Low match rate, need expand + } else { + 1 // High match rate, no expand needed + }; + + // Cache the result (compare-and-swap to avoid race) + let _ = self.need_expand_mode.compare_exchange( + 0, + mode, + Ordering::Relaxed, + Ordering::Relaxed, + ); + mode + } + + /// Check whether early stop should trigger based on two-tier logic: + /// - Soft early stop: elapsed > soft_early_stop AND matched >= min_matched_count + /// - Hard early stop: elapsed > 
hard_early_stop (unconditional) + fn should_early_stop(&self, matched: u64) -> bool { + let elapsed = self.start_instant.elapsed(); + + // Hard early stop: unconditionally terminate to bound worst-case latency + if elapsed > self.hard_early_stop { + return true; + } + + // Soft early stop: time is past threshold AND enough matched results + // This ensures we don't terminate too early without enough results + if elapsed > self.soft_early_stop && matched >= self.min_matched_count { + return true; + } + + false + } + + /// Returns true if early stop was triggered during search. + #[allow(dead_code)] + pub fn was_early_stopped(&self) -> bool { + self.early_stopped.load(Ordering::Relaxed) + } + + /// Returns the number of nodes visited. + #[allow(dead_code)] + pub fn nodes_visited(&self) -> u64 { + self.node_visited.load(Ordering::Relaxed) + } + + /// Returns the number of nodes matched. + #[allow(dead_code)] + pub fn nodes_matched(&self) -> u64 { + self.node_matched.load(Ordering::Relaxed) + } +} + +impl Debug for HighSelectivityBitmapFilter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HighSelectivityBitmapFilter") + .field("bitmap_len", &self.bitmap.len()) + .field("node_visited", &self.node_visited.load(Ordering::Relaxed)) + .field("node_matched", &self.node_matched.load(Ordering::Relaxed)) + .field( + "need_expand_mode", + &self.need_expand_mode.load(Ordering::Relaxed), + ) + .field("soft_early_stop", &self.soft_early_stop) + .field("hard_early_stop", &self.hard_early_stop) + .field("min_matched_count", &self.min_matched_count) + .field( + "early_stopped", + &self.early_stopped.load(Ordering::Relaxed), + ) + .finish() + } +} + +impl QueryLabelProvider for HighSelectivityBitmapFilter +where + T: VectorId, +{ + fn is_match(&self, vec_id: T) -> bool { + self.bitmap.contains(vec_id.into_usize()) + } + + fn on_visit(&self, neighbor: Neighbor) -> QueryVisitDecision { + let visited = self.node_visited.fetch_add(1, 
Ordering::Relaxed) + 1; + let matched = self.node_matched.load(Ordering::Relaxed); + + // Layer 3: Checkpoint-based timeout checking + // Only check timeout at checkpoints (every MATCH_RATE_CHECK_INTERVAL visits) + // to reduce syscall overhead from Instant::now() calls + let threshold = self.next_timeout_check.load(Ordering::Relaxed); + if visited >= threshold { + // Update next checkpoint + self.next_timeout_check.store( + visited + Self::MATCH_RATE_CHECK_INTERVAL, + Ordering::Relaxed, + ); + + // Update need_expand_mode based on current match rate + self.check_need_expand_mode(visited, matched); + + // Layer 4: Two-tier early stop + // - Soft: elapsed > soft_early_stop AND matched >= min_matched_count + // - Hard: elapsed > hard_early_stop (unconditional) + if self.should_early_stop(matched) { + self.early_stopped.store(true, Ordering::Relaxed); + return QueryVisitDecision::Terminate; + } + } + + // Evaluate filter match + if self.is_match(neighbor.id) { + self.node_matched.fetch_add(1, Ordering::Relaxed); + QueryVisitDecision::Accept(neighbor) + } else { + // Layer 2: Match rate detection for exploration mode + let mode = self.need_expand_mode.load(Ordering::Relaxed); + if mode == 0 { + // Not yet determined: check every time until we have enough samples + let current_matched = self.node_matched.load(Ordering::Relaxed); + let updated_mode = self.check_need_expand_mode(visited, current_matched); + + if updated_mode == 2 { + // Just determined as low match rate: use exploration queue + QueryVisitDecision::RejectAndNeedExpand + } else { + // Still undetermined or high match rate: simple reject + QueryVisitDecision::Reject + } + } else if mode == 2 { + // Low match rate confirmed: use exploration queue + QueryVisitDecision::RejectAndNeedExpand + } else { + // High match rate (mode=1): simple reject + QueryVisitDecision::Reject + } + } + } +} + pub(crate) fn generate_bitmaps( query_predicates: &InputFile, data_labels: &InputFile, @@ -116,9 +377,21 @@ pub(crate) fn 
as_query_label_provider(set: BitSet) -> Arc Arc> { + Arc::new(HighSelectivityBitmapFilter::new(set)) +} + #[cfg(test)] mod tests { use super::*; + use std::thread; #[test] fn test_bitmap_filter_match() { @@ -151,4 +424,315 @@ mod tests { assert!(filter.is_match(1000u32)); assert!(!filter.is_match(999u32)); } + + // ----------------------------------------------------------------------- + // Layer 2: Match Rate Detection Tests + // ----------------------------------------------------------------------- + + #[test] + fn test_high_selectivity_filter_low_match_rate() { + // Create a filter where only 1 out of 100 IDs match (1% match rate) + let mut bitset = BitSet::new(); + bitset.insert(50); // Only ID 50 matches + let filter = HighSelectivityBitmapFilter::new(bitset); + + // Visit 40 non-matching nodes (more than MIN_SAMPLES_FOR_ESTIMATION=30) + for id in 0..40u32 { + if id == 50 { + continue; + } + let neighbor = Neighbor::new(id, id as f32); + let decision = filter.on_visit(neighbor); + + // Mode is determined after 30 samples + // Before that, we check mode each time but it stays 0 until we have enough samples + // After we have 30 samples with 0% match rate, mode becomes 2 + if filter.node_visited.load(Ordering::Relaxed) >= 30 { + // After 30 visits with 0% match rate: should return RejectAndNeedExpand + assert!( + matches!(decision, QueryVisitDecision::RejectAndNeedExpand), + "After 30 samples with 0% match rate, should return RejectAndNeedExpand, got {:?}", + decision + ); + } + } + + // Mode should now be 2 (low match rate) + assert_eq!( + filter.need_expand_mode.load(Ordering::Relaxed), + 2, + "Mode should be 2 (low match rate)" + ); + } + + #[test] + fn test_high_selectivity_filter_high_match_rate() { + // Create a filter where half the IDs match (50% match rate, well above 2% threshold) + let mut bitset = BitSet::new(); + for i in 0..50 { + bitset.insert(i * 2); // Even IDs match + } + let filter = HighSelectivityBitmapFilter::new(bitset); + + // Visit 60 
nodes alternating match/no-match + for id in 0..60u32 { + let neighbor = Neighbor::new(id, id as f32); + let decision = filter.on_visit(neighbor); + + if id % 2 == 0 { + // Even IDs match + assert!( + matches!(decision, QueryVisitDecision::Accept(_)), + "Matching nodes should be accepted" + ); + } else { + // Odd IDs don't match + // After 30 samples with 50% match rate: should return Reject (high match rate) + if filter.node_visited.load(Ordering::Relaxed) >= 30 { + assert!( + matches!(decision, QueryVisitDecision::Reject), + "High match rate should return Reject, got {:?}", + decision + ); + } + } + } + + // Mode should be 1 (high match rate) after enough samples + assert_eq!( + filter.need_expand_mode.load(Ordering::Relaxed), + 1, + "Mode should be 1 (high match rate)" + ); + } + + #[test] + fn test_mode_undetermined_with_insufficient_samples() { + let bitset = BitSet::new(); // No matches + let filter = HighSelectivityBitmapFilter::new(bitset); + + // Visit only 20 nodes (less than MIN_SAMPLES_FOR_ESTIMATION=30) + for id in 0..20u32 { + let neighbor = Neighbor::new(id, id as f32); + let _ = filter.on_visit(neighbor); + } + + // Mode should still be 0 (undetermined) + assert_eq!( + filter.need_expand_mode.load(Ordering::Relaxed), + 0, + "Mode should be 0 (undetermined) with < 30 samples" + ); + } + + // ----------------------------------------------------------------------- + // Layer 3: Checkpoint-Based Timeout Tests + // ----------------------------------------------------------------------- + + #[test] + fn test_checkpoint_interval() { + let bitset = BitSet::new(); + let filter = HighSelectivityBitmapFilter::new(bitset); + + // Initial checkpoint should be at MATCH_RATE_CHECK_INTERVAL (1000) + assert_eq!( + filter.next_timeout_check.load(Ordering::Relaxed), + HighSelectivityBitmapFilter::MATCH_RATE_CHECK_INTERVAL, + "Initial checkpoint should be at 1000" + ); + + // Visit 999 nodes - checkpoint should not be updated yet + for id in 0..999u32 { + let neighbor 
= Neighbor::new(id, id as f32); + let _ = filter.on_visit(neighbor); + } + + assert_eq!( + filter.next_timeout_check.load(Ordering::Relaxed), + HighSelectivityBitmapFilter::MATCH_RATE_CHECK_INTERVAL, + "Checkpoint should not be updated before reaching threshold" + ); + + // Visit one more node to reach checkpoint + let neighbor = Neighbor::new(999u32, 999.0); + let _ = filter.on_visit(neighbor); + + // Checkpoint should now be updated to 2000 + assert_eq!( + filter.next_timeout_check.load(Ordering::Relaxed), + 2 * HighSelectivityBitmapFilter::MATCH_RATE_CHECK_INTERVAL, + "Checkpoint should be updated to 2000 after reaching 1000" + ); + } + + // ----------------------------------------------------------------------- + // Layer 4: Two-Tier Early Stop Tests + // ----------------------------------------------------------------------- + + #[test] + fn test_hard_early_stop() { + // Create filter with very short hard timeout (1ms) + let bitset = BitSet::new(); + let filter = HighSelectivityBitmapFilter::with_config( + bitset, + Duration::from_millis(100), // soft (won't trigger - no matches) + Duration::from_millis(1), // hard (very short) + 10, + ); + + // Sleep to ensure we exceed hard timeout + thread::sleep(Duration::from_millis(5)); + + // Visit enough nodes to reach checkpoint (1000) + let mut terminated = false; + for id in 0..1500u32 { + let neighbor = Neighbor::new(id, id as f32); + let decision = filter.on_visit(neighbor); + if matches!(decision, QueryVisitDecision::Terminate) { + terminated = true; + break; + } + } + + assert!(terminated, "Should have terminated due to hard timeout"); + assert!( + filter.was_early_stopped(), + "early_stopped flag should be set" + ); + } + + #[test] + fn test_soft_early_stop_with_enough_matches() { + // Create filter with short soft timeout and matching IDs + let mut bitset = BitSet::new(); + for i in 0..1000 { + bitset.insert(i); // All IDs match + } + let filter = HighSelectivityBitmapFilter::with_config( + bitset, + 
Duration::from_millis(1), // soft (very short) + Duration::from_millis(1000), // hard (long) + 5, // min_matched_count + ); + + // Sleep to ensure we exceed soft timeout + thread::sleep(Duration::from_millis(5)); + + // Visit enough matching nodes to exceed min_matched_count and reach checkpoint + let mut terminated = false; + for id in 0..1500u32 { + let neighbor = Neighbor::new(id, id as f32); + let decision = filter.on_visit(neighbor); + if matches!(decision, QueryVisitDecision::Terminate) { + terminated = true; + break; + } + } + + assert!( + terminated, + "Should have terminated due to soft timeout (enough matches)" + ); + assert!( + filter.was_early_stopped(), + "early_stopped flag should be set" + ); + assert!( + filter.nodes_matched() >= 5, + "Should have at least min_matched_count matches" + ); + } + + #[test] + fn test_soft_early_stop_not_triggered_without_enough_matches() { + // Create filter with short soft timeout but NO matches + let bitset = BitSet::new(); // No matches + let filter = HighSelectivityBitmapFilter::with_config( + bitset, + Duration::from_millis(1), // soft (very short) + Duration::from_millis(1000), // hard (long) + 10, // min_matched_count (won't be met) + ); + + // Sleep to ensure we exceed soft timeout + thread::sleep(Duration::from_millis(5)); + + // Visit nodes - should NOT terminate because we don't have enough matches + let mut terminated = false; + for id in 0..1500u32 { + let neighbor = Neighbor::new(id, id as f32); + let decision = filter.on_visit(neighbor); + if matches!(decision, QueryVisitDecision::Terminate) { + terminated = true; + break; + } + } + + // Should NOT have terminated - soft timeout requires min_matched_count + assert!( + !terminated, + "Should NOT terminate without enough matches (soft timeout not met)" + ); + assert!( + !filter.was_early_stopped(), + "early_stopped flag should NOT be set" + ); + } + + #[test] + fn test_soft_timeout_clamped_to_hard_timeout() { + // Create filter where soft > hard (should be 
clamped) + let bitset = BitSet::new(); + let filter = HighSelectivityBitmapFilter::with_config( + bitset, + Duration::from_millis(200), // soft (greater than hard) + Duration::from_millis(50), // hard + 10, + ); + + // soft_early_stop should be clamped to hard_early_stop + assert!( + filter.soft_early_stop <= filter.hard_early_stop, + "soft_early_stop ({:?}) should not exceed hard_early_stop ({:?})", + filter.soft_early_stop, + filter.hard_early_stop + ); + } + + // ----------------------------------------------------------------------- + // Default Constants Tests + // ----------------------------------------------------------------------- + + #[test] + fn test_default_constants() { + assert_eq!( + HighSelectivityBitmapFilter::DEFAULT_SOFT_EARLY_STOP_MS, + 10, + "Default soft timeout should be 10ms" + ); + assert_eq!( + HighSelectivityBitmapFilter::DEFAULT_HARD_EARLY_STOP_MS, + 100, + "Default hard timeout should be 100ms" + ); + assert_eq!( + HighSelectivityBitmapFilter::DEFAULT_MIN_MATCHED_COUNT, + 10, + "Default min matched count should be 10" + ); + assert_eq!( + HighSelectivityBitmapFilter::MIN_SAMPLES_FOR_ESTIMATION, + 30, + "Min samples for estimation should be 30" + ); + assert!( + (HighSelectivityBitmapFilter::LOW_MATCH_RATE_THRESHOLD - 0.02).abs() < f64::EPSILON, + "Low match rate threshold should be 2%" + ); + assert_eq!( + HighSelectivityBitmapFilter::MATCH_RATE_CHECK_INTERVAL, + 1000, + "Match rate check interval should be 1000" + ); + } } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index cd81ba5f1..ea6bd4675 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -73,6 +73,9 @@ pub enum QueryVisitDecision { Accept(Neighbor), /// Reject this node; do not add it to the frontier. Reject, + /// Reject this node but signal that exploration queue should be enabled + /// (for low match rate scenarios where we need to traverse through non-matching nodes). 
+ RejectAndNeedExpand, /// Stop the search immediately without accepting this node. Terminate, } diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index aba0f44c5..5f0d800a9 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -215,21 +215,53 @@ where let mut two_hop_neighbors = Vec::with_capacity(max_degree_with_slack); let mut candidates_two_hop_expansion = Vec::with_capacity(max_degree_with_slack); - while scratch.best.has_notvisited_node() && !accessor.terminate_early() { + // Exploration queue: maintains rejected nodes for continued graph traversal + // when scratch.best has no more unvisited nodes (important for low match rate scenarios). + // Only enabled when we receive RejectAndNeedExpand signals. + let l_search = search_params.l_value().get(); + let exploration_queue_capacity = l_search; + let mut exploration_queue: Vec> = Vec::with_capacity(exploration_queue_capacity); + // Track which nodes are already in exploration_queue to avoid duplicates + let mut exploration_set: HashSet = HashSet::with_capacity(exploration_queue_capacity); + // Flag to track if exploration mode is enabled (set when we receive RejectAndNeedExpand) + let mut exploration_mode_enabled = false; + + loop { + // Check termination conditions + if accessor.terminate_early() { + break; + } + scratch.beam_nodes.clear(); one_hop_neighbors.clear(); candidates_two_hop_expansion.clear(); two_hop_neighbors.clear(); - // In this loop we are going to find the beam_width number of nodes that are closest to the query. - // Each of these nodes will be a frontier node. 
- while scratch.beam_nodes.len() < beam_width - && let Some(closest_node) = scratch.best.closest_notvisited() - { + // Fill beam from scratch.best first (matching nodes have priority) + while scratch.beam_nodes.len() < beam_width { + let Some(closest_node) = scratch.best.closest_notvisited() else { + break; + }; search_record.record(closest_node, scratch.hops, scratch.cmps); scratch.beam_nodes.push(closest_node.id); } + // If beam not full and exploration mode is enabled, use exploration nodes + // (These are non-matching nodes used for graph traversal only) + if exploration_mode_enabled { + while scratch.beam_nodes.len() < beam_width { + let Some(node) = exploration_queue.pop() else { + break; + }; + scratch.beam_nodes.push(node.id); + } + } + + // Exit if no nodes to process + if scratch.beam_nodes.is_empty() { + break; + } + // compute distances from query to one-hop neighbors, and mark them visited accessor .expand_beam( @@ -250,6 +282,11 @@ where // Rejected nodes: still add to two-hop expansion so we can traverse through them candidates_two_hop_expansion.push(neighbor); } + QueryVisitDecision::RejectAndNeedExpand => { + // Low match rate detected: enable exploration mode and add to expansion + exploration_mode_enabled = true; + candidates_two_hop_expansion.push(neighbor); + } QueryVisitDecision::Terminate => { scratch.cmps += one_hop_neighbors.len() as u32; scratch.hops += scratch.beam_nodes.len() as u32; @@ -295,6 +332,26 @@ where scratch.cmps += two_hop_neighbors.len() as u32; scratch.hops += two_hop_expansion_candidate_ids.len() as u32; + + // Only add to exploration queue if exploration mode is enabled. + // This enables the search to continue exploring the graph even when no matching + // nodes are in scratch.best (critical for low match rate scenarios). 
+            if exploration_mode_enabled {
+                for candidate in candidates_two_hop_expansion.iter() {
+                    if exploration_set.insert(candidate.id) {
+                        exploration_queue.push(*candidate);
+                    }
+                }
+
+                // Sort by distance descending (farthest first) so pop() yields the closest
+                // node; then drop any overflow from the *front*, i.e. the farthest entries.
+                exploration_queue.sort_unstable_by(|a, b| {
+                    b.distance
+                        .partial_cmp(&a.distance)
+                        .unwrap_or(std::cmp::Ordering::Equal)
+                });
+                exploration_queue.drain(..exploration_queue.len().saturating_sub(exploration_queue_capacity));
+            }
         }
 
         Ok(make_stats(scratch))