
Commit 45b1309

michaeljmarshall authored and michaelsembwever committed
CNDB-15640: Determine if vectors are unit length at insert (#2059)
Fixes: riptano/cndb#15640 In order to lay the ground work for Fused ADC, I want to refactor some of the PQ/BQ logic. The unit length computation needs to move, so I decided to move it out to its own PR. The core idea is that: * some models are documented to provide unit length vectors, and in those cases, we should skip the computational check * otherwise, we should check at runtime until we hit a non-unit length vector, and then we can skip the check and configure the `writePQ` method as needed (I asked chat gpt to provide proof for the config changes proposed in this PR. Here is it's generated description.) Quick rundown of which models spit out normalized vectors (so cosine == dot product, etc.): * **OpenAI (ada-002, v3-small, v3-large)** → already normalized. [OpenAI FAQ](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) literally says embeddings are unit-length. * **BERT** → depends. The SBERT “-cos-” models add a [`Normalize` layer](https://www.sbert.net/docs/package_reference/layers.html#normalize) so they’re fine; vanilla BERT doesn’t. * **Google Gecko** → normalized out of the box per [Vertex AI docs](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings). * **NVIDIA QA-4** → nothing in the [NVIDIA NIM model card](https://docs.api.nvidia.com/nim/reference/nvidia-embed-qa-4) about normalization, so assume *not* normalized and handle it yourself. * **Cohere v3** → not explicitly in their [API docs](https://docs.cohere.com/docs/cohere-embed) TL;DR: OpenAI + Gecko are definitely safe, Cohere/BERT/NV need manual normalization due to lack of documentation.
1 parent 623483d commit 45b1309
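For concreteness, here is a minimal, standalone sketch of the two operations the commit message refers to: the unit-length check (same `|v·v - 1| < 0.01` tolerance as the diff below, but over a plain `float[]` rather than jvector's `VectorUtil`) and the manual normalization suggested for models that don't document unit-length output. The class and method names are illustrative, not part of this commit:

```java
// Illustrative helper only; not part of the committed code.
final class VectorNorms
{
    // Same tolerance the committed check uses.
    static final float TOLERANCE = 0.01f;

    // v has unit length exactly when v·v == 1, so comparing the squared
    // norm to 1 avoids an unnecessary sqrt.
    static boolean isUnitLength(float[] v)
    {
        return Math.abs(dotSelf(v) - 1.0f) < TOLERANCE;
    }

    // Manual normalization for models that don't document unit-length output.
    static float[] normalize(float[] v)
    {
        float norm = (float) Math.sqrt(dotSelf(v));
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++)
            out[i] = v[i] / norm;
        return out;
    }

    private static float dotSelf(float[] v)
    {
        float dot = 0f;
        for (float x : v)
            dot += x * x;
        return dot;
    }

    public static void main(String[] args)
    {
        float[] raw = { 3f, 4f };                          // |raw| == 5
        System.out.println(isUnitLength(raw));             // false
        System.out.println(isUnitLength(normalize(raw)));  // true: (0.6, 0.8)
    }
}
```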

File tree

2 files changed: +39 -21 lines changed


src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java

Lines changed: 10 additions & 8 deletions

@@ -33,7 +33,6 @@
 import java.util.function.Function;
 import java.util.function.IntUnaryOperator;
 import java.util.function.ToIntFunction;
-import java.util.stream.IntStream;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.cliffc.high_scale_lib.NonBlockingHashMap;
@@ -125,6 +124,7 @@ public enum PQVersion {
     private final InvalidVectorBehavior invalidVectorBehavior;
     private final IntHashSet deletedOrdinals;
     private volatile boolean hasDeletions;
+    private volatile boolean allVectorsAreUnitLength;
 
     // we don't need to explicitly close these since only on-heap resources are involved
     private final ThreadLocal<GraphSearcherAccessManager> searchers;
@@ -158,6 +158,8 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
         invalidVectorBehavior = forSearching ? InvalidVectorBehavior.FAIL : InvalidVectorBehavior.IGNORE;
 
         int jvectorVersion = Version.current().onDiskFormat().jvectorFileFormatVersion();
+        // Assume true until we observe otherwise.
+        allVectorsAreUnitLength = true;
         // This is only a warning since it's not a fatal error to write without hierarchy
         if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
             logger.warn("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. " +
@@ -269,6 +271,12 @@ public long add(ByteBuffer term, T key)
             var success = postingsByOrdinal.compareAndPut(ordinal, null, postings);
             assert success : "postingsByOrdinal already contains an entry for ordinal " + ordinal;
             bytesUsed += builder.addGraphNode(ordinal, vector);
+
+            // If necessary, check if the vector is unit length.
+            if (!sourceModel.hasKnownUnitLengthVectors() && allVectorsAreUnitLength)
+                if (!(Math.abs(VectorUtil.dotProduct(vector, vector) - 1.0f) < 0.01))
+                    allVectorsAreUnitLength = false;
+
             return bytesUsed;
         }
         else
@@ -560,7 +568,6 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
         // Build encoder and compress vectors
         VectorCompressor<?> compressor; // will be null if we can't compress
         CompressedVectors cv = null;
-        boolean containsUnitVectors;
         // limit the PQ computation and encoding to one index at a time -- goal during flush is to
         // evict from memory ASAP so better to do the PQ build (in parallel) one at a time
         synchronized (CassandraOnHeapGraph.class)
@@ -580,15 +587,10 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
             // encode (compress) the vectors to save
             if (compressor != null)
                 cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues));
-
-            containsUnitVectors = IntStream.range(0, vectorValues.size())
-                                           .parallel()
-                                           .mapToObj(vectorValues::getVector)
-                                           .allMatch(v -> Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01);
         }
 
         var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type;
-        writePqHeader(writer, containsUnitVectors, actualType);
+        writePqHeader(writer, allVectorsAreUnitLength, actualType);
         if (actualType == CompressionType.NONE)
            return writer.position();
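A hedged aside on why the boolean passed to `writePqHeader` matters (assumed semantics, not project code): for unit-length vectors, cosine similarity reduces to a plain dot product, so a reader of the PQ file can safely take the cheaper similarity path when the flag is set. The diff also trades the flush-time parallel `IntStream` scan for an incremental per-insert check that short-circuits after the first counterexample.

```java
// Illustration only: cosine(u, v) = u·v / (|u| |v|), which equals
// dot(u, v) whenever |u| == |v| == 1.
final class CosineVsDot
{
    static float dot(float[] a, float[] b)
    {
        float s = 0f;
        for (int i = 0; i < a.length; i++)
            s += a[i] * b[i];
        return s;
    }

    static float norm(float[] a)
    {
        return (float) Math.sqrt(dot(a, a));
    }

    public static void main(String[] args)
    {
        float[] u = { 0.6f, 0.8f };  // unit length: 0.36 + 0.64 == 1
        float[] v = { 0.0f, 1.0f };  // unit length
        float d = dot(u, v);                     // 0.8
        float cosine = d / (norm(u) * norm(v));  // also 0.8
        System.out.println(d + " == " + cosine);
    }
}
```

With this equivalence, the per-insert check costs at most one dot product per vector, and nothing at all once `allVectorsAreUnitLength` flips to false or when the source model is documented to normalize.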
src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java

Lines changed: 29 additions & 13 deletions

@@ -31,18 +31,19 @@
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.BINARY_QUANTIZATION;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.NONE;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
-
 public enum VectorSourceModel
 {
-    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5),
-    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0),
-    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-
-    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery);
+    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5, true),
+    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, true),
+    // BERT is not known to have unit length vectors in all cases
+    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0, false),
+    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, false),
+    // Cohere does not officially say they have unit length vectors, but some users report that they do
+    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, false),
+
+    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery, false);
 
     /**
      * Default similarity function for this model.
@@ -58,18 +59,33 @@ public enum VectorSourceModel
      */
     public final Function<VectorCompression, Double> overqueryProvider;
 
-    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider, double overqueryFactor)
+    /**
+     * Indicates that the model is known to have unit length vectors. When false, the runtime checks per graph
+     * until a non-unit length vector is found.
+     */
+    private final boolean knownUnitLength;
+
+    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider,
+                      double overqueryFactor,
+                      boolean knownUnitLength)
     {
-        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor);
+        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor, knownUnitLength);
     }
 
     VectorSourceModel(VectorSimilarityFunction defaultSimilarityFunction,
                       Function<Integer, VectorCompression> compressionProvider,
-                      Function<VectorCompression, Double> overqueryProvider)
+                      Function<VectorCompression, Double> overqueryProvider,
+                      boolean knownUnitLength)
     {
         this.defaultSimilarityFunction = defaultSimilarityFunction;
         this.compressionProvider = compressionProvider;
         this.overqueryProvider = overqueryProvider;
+        this.knownUnitLength = knownUnitLength;
+    }
+
+    public boolean hasKnownUnitLengthVectors()
+    {
+        return knownUnitLength;
+    }
 
     public static VectorSourceModel fromString(String value)
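A hedged usage sketch of the new accessor (the caller below is hypothetical; only the enum constants and `hasKnownUnitLengthVectors()` come from this diff):

```java
import org.apache.cassandra.index.sai.disk.vector.VectorSourceModel;

// Hypothetical caller mirroring the per-insert pattern in CassandraOnHeapGraph.add.
final class UnitLengthPolicyDemo
{
    public static void main(String[] args)
    {
        for (VectorSourceModel model : VectorSourceModel.values())
        {
            // Models flagged true (ADA002, OPENAI_V3_*, GECKO) skip the per-vector
            // check; the rest are verified until the first non-unit-length vector.
            System.out.println(model + " -> check needed: " + !model.hasKnownUnitLengthVectors());
        }
    }
}
```

Flagging `COHERE_V3` as false is the conservative choice here: some users report normalized output, but absent official documentation the per-graph runtime check decides.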
