Skip to content

Commit 59faf8a

Browse files
CNDB-12922: Implement rerank_k in SAI ANN queries (#1562)
### What is the issue Fixes riptano/cndb#12922 ### What does this PR fix and why was it fixed This PR follows up on #1525 by integrating the CQL configured `rerank_k` in to SAI so that the ANN query can take the user input and let it influence query execution. Key changes: * Bump `MessagingService` to `VERSION_DS_11` * Fixes implementation of `MessagingService.instance().endpointsWithVersionBelow` by using the `keyspace` variant, which is necessary for CNDB. * When configured, the `rerank_k` is used as the graph's `rerankK` parameter. The only detail worth mentioning here is that we ignore the segment proportionality computation if `rerank_k` is provided. This diverges from the smart default design, but not too greatly. * Added a guardrail for `rerank_k` to prevent it from exceeding 4 times the `cassandra.sai.vector_search.max_top_k`. This is debatable and easily changed, so please let me know if we want something different. * Made the in memory brute force cost comparison logic use the `rerankK` value instead of the `limit`. I am pretty sure this is correct, but please review closely. Instead of validating `rerank_k`'s value within the query, I added a test that ensures that as we increase the rerank k, we increase the recall. CNDB pr: riptano/cndb#13095
1 parent cc92316 commit 59faf8a

File tree

20 files changed

+357
-81
lines changed

20 files changed

+357
-81
lines changed

src/java/org/apache/cassandra/config/CassandraRelevantProperties.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ public enum CassandraRelevantProperties
588588
* The current messaging version. This is used when we add new messaging versions without adopting them immediately,
589589
* or to force the node to use a specific version for testing purposes.
590590
*/
591-
DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_10)),
591+
DS_CURRENT_MESSAGING_VERSION("ds.current_messaging_version", Integer.toString(MessagingService.VERSION_DS_11)),
592592

593593
/**
594594
* Which compression algorithm to use for SSTable compression when not specified explicitly in the sstable options.
@@ -602,7 +602,6 @@ public enum CassandraRelevantProperties
602602
*/
603603
SKIP_OPTIMAL_STREAMING_CANDIDATES_CALCULATION("cassandra.skip_optimal_streaming_candidates_calculation", "false");
604604

605-
606605
CassandraRelevantProperties(String key, String defaultVal)
607606
{
608607
this.key = key;

src/java/org/apache/cassandra/cql3/statements/SelectOptions.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.cassandra.db.filter.ANNOptions;
2323
import org.apache.cassandra.exceptions.InvalidRequestException;
2424
import org.apache.cassandra.exceptions.RequestValidationException;
25+
import org.apache.cassandra.service.QueryState;
2526

2627
/**
2728
* {@code WITH option1=... AND option2=...} options for SELECT statements.
@@ -36,16 +37,19 @@ public class SelectOptions extends PropertyDefinitions
3637
/**
3738
* Validates all the {@code SELECT} options.
3839
*
40+
* @param state the query state
3941
* @param limit the {@code SELECT} query user-provided limit
4042
* @throws InvalidRequestException if any of the options are invalid
4143
*/
42-
public void validate(int limit) throws RequestValidationException
44+
public void validate(QueryState state, String keyspace, int limit) throws RequestValidationException
4345
{
4446
validate(keywords, Collections.emptySet());
45-
parseANNOptions().validate(limit);
47+
parseANNOptions().validate(state, keyspace, limit);
4648
}
4749

4850
/**
51+
* Parse the ANN Options. Does not validate values of the options or whether peers will be able to process them.
52+
*
4953
* @return the ANN options within these options, or {@link ANNOptions#NONE} if no options are present
5054
* @throws InvalidRequestException if the ANN options are invalid
5155
*/

src/java/org/apache/cassandra/cql3/statements/SelectStatement.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ public ReadQuery getQuery(QueryState queryState,
433433
checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset));
434434
}
435435

436-
selectOptions.validate(userLimit);
436+
selectOptions.validate(queryState, table.keyspace, userLimit);
437437

438438
return query;
439439
}

src/java/org/apache/cassandra/db/filter/ANNOptions.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424

2525
import org.apache.cassandra.db.TypeSizes;
2626
import org.apache.cassandra.exceptions.InvalidRequestException;
27+
import org.apache.cassandra.guardrails.Guardrails;
2728
import org.apache.cassandra.io.util.DataInputPlus;
2829
import org.apache.cassandra.io.util.DataOutputPlus;
2930
import org.apache.cassandra.locator.InetAddressAndPort;
3031
import org.apache.cassandra.net.MessagingService;
32+
import org.apache.cassandra.service.QueryState;
3133
import org.apache.cassandra.utils.FBUtilities;
3234

3335
/**
@@ -60,10 +62,26 @@ public static ANNOptions create(@Nullable Integer rerankK)
6062
return rerankK == null ? NONE : new ANNOptions(rerankK);
6163
}
6264

63-
public void validate(int limit)
65+
/**
66+
* Validates the ANN options by checking that they are within the guardrails and that peers support the options.
67+
*/
68+
public void validate(QueryState state, String keyspace, int limit)
6469
{
65-
if (rerankK != null && rerankK > 0 && rerankK < limit)
70+
if (rerankK == null)
71+
return;
72+
73+
if (rerankK < limit)
6674
throw new InvalidRequestException(String.format("Invalid rerank_k value %d lesser than limit %d", rerankK, limit));
75+
76+
Guardrails.annRerankKMaxValue.guard(rerankK, "ANN options", false, state);
77+
78+
// Ensure that all nodes in the cluster are in a version that supports ANN options, including this one
79+
assert keyspace != null;
80+
Set<InetAddressAndPort> badNodes = MessagingService.instance().endpointsWithVersionBelow(keyspace, MessagingService.VERSION_DS_11);
81+
if (MessagingService.current_version < MessagingService.VERSION_DS_11)
82+
badNodes.add(FBUtilities.getBroadcastAddressAndPort());
83+
if (!badNodes.isEmpty())
84+
throw new InvalidRequestException("ANN options are not supported in clusters below DS 11.");
6785
}
6886

6987
/**
@@ -74,13 +92,6 @@ public void validate(int limit)
7492
*/
7593
public static ANNOptions fromMap(Map<String, String> map)
7694
{
77-
// ensure that all nodes in the cluster are in a version that supports ANN options, including this one
78-
Set<InetAddressAndPort> badNodes = MessagingService.instance().endpointsWithVersionBelow(MessagingService.VERSION_DS_11);
79-
if (MessagingService.current_version < MessagingService.VERSION_DS_11)
80-
badNodes.add(FBUtilities.getBroadcastAddressAndPort());
81-
if (!badNodes.isEmpty())
82-
throw new InvalidRequestException("ANN options are not supported in clusters below DS 11.");
83-
8495
Integer rerankK = null;
8596

8697
for (Map.Entry<String, String> entry : map.entrySet())

src/java/org/apache/cassandra/guardrails/Guardrails.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ what, formatSize(v), formatSize(t)))
120120
format("%s has a vector of %s dimensions, this exceeds the %s threshold of %s.",
121121
what, value, isWarning ? "warning" : "failure", threshold));
122122

123+
/**
124+
* Guardrail on the maximum value for the rerank_k parameter, an ANN query option.
125+
*/
126+
public static final Threshold annRerankKMaxValue =
127+
factory.threshold("sai_ann_rerank_k_max_value",
128+
() -> config.sai_ann_rerank_k_warn_threshold,
129+
() -> config.sai_ann_rerank_k_fail_threshold,
130+
(isWarning, what, value, threshold) ->
131+
format("%s specifies rerank_k=%s, this exceeds the %s threshold of %s.",
132+
what, value, isWarning ? "warning" : "failure", threshold));
133+
123134
public static final DisableFlag readBeforeWriteListOperationsEnabled =
124135
factory.disableFlag("read_before_write_list_operations",
125136
() -> !config.read_before_write_list_operations_enabled,

src/java/org/apache/cassandra/guardrails/GuardrailsConfig.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.google.common.collect.ImmutableSet;
3030
import com.google.common.collect.Sets;
3131

32+
import org.apache.cassandra.config.CassandraRelevantProperties;
3233
import org.apache.cassandra.config.Config;
3334
import org.apache.cassandra.config.DatabaseDescriptor;
3435
import org.apache.cassandra.cql3.statements.schema.TableAttributes;
@@ -81,6 +82,8 @@ public class GuardrailsConfig
8182
public volatile Boolean read_before_write_list_operations_enabled;
8283
public volatile Integer vector_dimensions_warn_threshold;
8384
public volatile Integer vector_dimensions_failure_threshold;
85+
public volatile Integer sai_ann_rerank_k_warn_threshold;
86+
public volatile Integer sai_ann_rerank_k_fail_threshold;
8487

8588
// Legacy 2i guardrail
8689
public volatile Integer secondary_index_per_table_failure_threshold;
@@ -165,6 +168,11 @@ public void applyConfig()
165168
enforceDefault(tombstone_warn_threshold, v -> tombstone_warn_threshold = v, 1000, 1000);
166169
enforceDefault(tombstone_failure_threshold, v -> tombstone_failure_threshold = v, 100000, 100000);
167170

171+
// Default to no warning and failure at 4 times the maxTopK value
172+
int maxTopK = CassandraRelevantProperties.SAI_VECTOR_SEARCH_MAX_TOP_K.getInt();
173+
enforceDefault(sai_ann_rerank_k_warn_threshold, v -> sai_ann_rerank_k_warn_threshold = v, -1, -1);
174+
enforceDefault(sai_ann_rerank_k_fail_threshold, v -> sai_ann_rerank_k_fail_threshold = v, 4 * maxTopK, 4 * maxTopK);
175+
168176
// for write requests
169177
enforceDefault(logged_batch_enabled, v -> logged_batch_enabled = v, true, true);
170178
enforceDefault(batch_size_warn_threshold_in_kb, v -> batch_size_warn_threshold_in_kb = v, 64, 64);
@@ -269,6 +277,10 @@ public void validate()
269277
validateStrictlyPositiveInteger(vector_dimensions_failure_threshold, "vector_dimensions_failure_threshold");
270278
validateWarnLowerThanFail(vector_dimensions_warn_threshold, vector_dimensions_failure_threshold, "vector_dimensions");
271279

280+
validateStrictlyPositiveInteger(sai_ann_rerank_k_warn_threshold, "sai_ann_rerank_k_warn_threshold");
281+
validateStrictlyPositiveInteger(sai_ann_rerank_k_fail_threshold, "sai_ann_rerank_k_fail_threshold");
282+
validateWarnLowerThanFail(sai_ann_rerank_k_warn_threshold, sai_ann_rerank_k_fail_threshold, "sai_ann_rerank_k");
283+
272284
validateStrictlyPositiveInteger(tables_warn_threshold, "tables_warn_threshold");
273285
validateStrictlyPositiveInteger(tables_failure_threshold, "tables_failure_threshold");
274286
validateWarnLowerThanFail(tables_warn_threshold, tables_failure_threshold, "tables");

src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -717,10 +717,6 @@ public void validate(ReadCommand command) throws InvalidRequestException
717717
throw new InvalidRequestException(String.format("SAI based ORDER BY clause requires a LIMIT that is not greater than %s. LIMIT was %s",
718718
MAX_TOP_K, command.limits().isUnlimited() ? "NO LIMIT" : command.limits().count()));
719719

720-
ANNOptions annOptions = command.rowFilter().annOptions();
721-
if (annOptions != ANNOptions.NONE)
722-
throw new InvalidRequestException("SAI doesn't support ANN options yet.");
723-
724720
indexContext.validate(command.rowFilter());
725721
}
726722

src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,11 @@ public String toString()
205205
* the number of candidates, the more nodes we expect to visit just to find
206206
* results that are in that set.)
207207
*/
208-
public double estimateAnnSearchCost(int limit, int candidates)
208+
public double estimateAnnSearchCost(Orderer orderer, int limit, int candidates)
209209
{
210-
IndexSearcher searcher = getIndexSearcher();
211-
return ((V2VectorIndexSearcher) searcher).estimateAnnSearchCost(limit, candidates);
210+
V2VectorIndexSearcher searcher = (V2VectorIndexSearcher) getIndexSearcher();
211+
int rerankK = orderer.rerankKFor(limit, searcher.getCompression());
212+
return searcher.estimateAnnSearchCost(rerankK, candidates);
212213
}
213214

214215
/**

src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer, E
207207
{
208208
if (segment.intersects(keyRange))
209209
{
210+
// Note that the proportionality is not used when the user supplies a rerank_k value in the
211+
// ANN_OPTIONS map.
210212
var segmentLimit = segment.proportionalAnnLimit(limit, totalRows);
211213
iterators.add(segment.orderBy(orderer, slice, keyRange, context, segmentLimit));
212214
}

src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
163163
if (!orderer.isANN())
164164
throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer));
165165

166-
int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
166+
int rerankK = orderer.rerankKFor(limit, graph.getCompression());
167167
var queryVector = vts.createFloatVector(orderer.getVectorTerm());
168168

169169
var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0);
@@ -428,9 +428,8 @@ public double cost()
428428
}
429429
}
430430

431-
public double estimateAnnSearchCost(int limit, int candidates)
431+
public double estimateAnnSearchCost(int rerankK, int candidates)
432432
{
433-
int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
434433
var estimate = estimateCost(rerankK, candidates);
435434
return estimate.cost();
436435
}
@@ -472,7 +471,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
472471
if (keys.isEmpty())
473472
return CloseableIterator.emptyIterator();
474473

475-
int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression());
474+
int rerankK = orderer.rerankKFor(limit, graph.getCompression());
476475
// Convert PKs to segment row ids and map to ordinals, skipping any that don't exist in this segment
477476
var segmentOrdinalPairs = flatmapPrimaryKeysToBitsAndRows(keys);
478477
var numRows = segmentOrdinalPairs.size();
@@ -611,9 +610,9 @@ public static double logBase2(double number) {
611610
return Math.log(number) / Math.log(2);
612611
}
613612

614-
private int getRawExpectedNodes(int limit, int nPermittedOrdinals)
613+
private int getRawExpectedNodes(int rerankK, int nPermittedOrdinals)
615614
{
616-
return VectorMemtableIndex.expectedNodesVisited(limit, nPermittedOrdinals, graph.size());
615+
return VectorMemtableIndex.expectedNodesVisited(rerankK, nPermittedOrdinals, graph.size());
617616
}
618617

619618
@Override

0 commit comments

Comments
 (0)