diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
index cb0d40551d89..bee00e842d39 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java
@@ -33,7 +33,6 @@
 import java.util.function.Function;
 import java.util.function.IntUnaryOperator;
 import java.util.function.ToIntFunction;
-import java.util.stream.IntStream;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.cliffc.high_scale_lib.NonBlockingHashMap;
@@ -50,7 +49,6 @@
 import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
 import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
 import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
-import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
 import io.github.jbellis.jvector.quantization.BinaryQuantization;
 import io.github.jbellis.jvector.quantization.CompressedVectors;
 import io.github.jbellis.jvector.quantization.ProductQuantization;
@@ -81,7 +79,6 @@
 import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata;
 import org.apache.cassandra.index.sai.disk.v2.V2VectorIndexSearcher;
 import org.apache.cassandra.index.sai.disk.v2.V2VectorPostingsWriter;
-import org.apache.cassandra.index.sai.disk.v3.V3OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
 import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter;
 import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter.Structure;
@@ -125,6 +122,7 @@ public enum PQVersion {
     private final InvalidVectorBehavior invalidVectorBehavior;
     private final IntHashSet deletedOrdinals;
     private volatile boolean hasDeletions;
+    private volatile boolean allVectorsAreUnitLength;
 
     // we don't need to explicitly close these since only on-heap resources are involved
     private final ThreadLocal searchers;
@@ -157,6 +155,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
         vectorsByKey = forSearching ? new NonBlockingHashMap<>() : null;
         invalidVectorBehavior = forSearching ? InvalidVectorBehavior.FAIL : InvalidVectorBehavior.IGNORE;
 
+        // Assume true until we observe otherwise.
+        allVectorsAreUnitLength = true;
+
         int jvectorVersion = context.version().onDiskFormat().jvectorFileFormatVersion();
         // This is only a warning since it's not a fatal error to write without hierarchy
         if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
@@ -269,6 +270,12 @@ public long add(ByteBuffer term, T key)
             var success = postingsByOrdinal.compareAndPut(ordinal, null, postings);
             assert success : "postingsByOrdinal already contains an entry for ordinal " + ordinal;
             bytesUsed += builder.addGraphNode(ordinal, vector);
+
+            // If necessary, check if the vector is unit length.
+            if (!sourceModel.hasKnownUnitLengthVectors() && allVectorsAreUnitLength)
+                if (!(Math.abs(VectorUtil.dotProduct(vector, vector) - 1.0f) < 0.01))
+                    allVectorsAreUnitLength = false;
+
             return bytesUsed;
         }
         else
@@ -560,7 +567,6 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
         // Build encoder and compress vectors
         VectorCompressor<?> compressor; // will be null if we can't compress
         CompressedVectors cv = null;
-        boolean containsUnitVectors;
         // limit the PQ computation and encoding to one index at a time -- goal during flush is to
         // evict from memory ASAP so better to do the PQ build (in parallel) one at a time
         synchronized (CassandraOnHeapGraph.class)
@@ -580,15 +586,10 @@
             // encode (compress) the vectors to save
             if (compressor != null)
                 cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues));
-
-            containsUnitVectors = IntStream.range(0, vectorValues.size())
-                                           .parallel()
-                                           .mapToObj(vectorValues::getVector)
-                                           .allMatch(v -> Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01);
        }
 
         var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type;
-        writePqHeader(writer, containsUnitVectors, actualType, indexContext.version());
+        writePqHeader(writer, allVectorsAreUnitLength, actualType, indexContext.version());
 
         if (actualType == CompressionType.NONE)
             return writer.position();
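
Note: the hunks above replace writePQ's flush-time IntStream scan with a check performed once per vector at add time, short-circuited entirely when the source model is documented to emit unit-length vectors. A minimal, self-contained sketch of that technique follows. UnitLengthTracker, the float[] signature, and the dot helper are hypothetical stand-ins for illustration (the patch folds this logic directly into CassandraOnHeapGraph.add() and uses jvector's VectorUtil.dotProduct), but the 0.01 tolerance and the monotonic true-to-false flag match the diff.

// Sketch only: incremental unit-length detection under the assumptions
// stated above; not part of the patch.
public final class UnitLengthTracker
{
    private static final float EPSILON = 0.01f; // same tolerance as the patch

    private final boolean knownUnitLength;         // per-model short-circuit
    private volatile boolean allUnitLength = true; // optimistic until disproven

    public UnitLengthTracker(boolean knownUnitLength)
    {
        this.knownUnitLength = knownUnitLength;
    }

    /** Call once per vector as it is added to the graph. */
    public void observe(float[] v)
    {
        // Models known to emit unit vectors never pay for the dot product, and
        // once one non-unit vector has been seen there is nothing left to learn.
        if (knownUnitLength || !allUnitLength)
            return;
        if (Math.abs(dot(v, v) - 1.0f) >= EPSILON)
            allUnitLength = false;
    }

    /** True iff every observed vector was (approximately) unit length. */
    public boolean allVectorsAreUnitLength()
    {
        return knownUnitLength || allUnitLength;
    }

    private static float dot(float[] a, float[] b)
    {
        float sum = 0;
        for (int i = 0; i < a.length; i++)
            sum += a[i] * b[i];
        return sum;
    }
}

Because the flag only ever transitions from true to false, a plain volatile field is sufficient: concurrent adders may race, but every losing write stores the same value, so no CAS is needed.
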
diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
index bf08896591dd..bccfd8e1352a 100644
--- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
+++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorSourceModel.java
@@ -31,18 +31,19 @@
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.BINARY_QUANTIZATION;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.NONE;
 import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.PRODUCT_QUANTIZATION;
-
 public enum VectorSourceModel
 {
-    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5),
-    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0),
-    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
-    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
-
-    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery);
+    ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5, true),
+    OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, true),
+    // BERT is not known to have unit length vectors in all cases
+    BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0, false),
+    GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
+    NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, false),
+    // Cohere does not officially say they have unit length vectors, but some users report that they do
+    COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, false),
+
+    OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery, false);
 
     /**
      * Default similarity function for this model.
@@ -58,18 +59,33 @@ public enum VectorSourceModel
      */
     public final Function<Integer, Double> overqueryProvider;
 
-    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider, double overqueryFactor)
+    /**
+     * Indicates that the model is known to have unit length vectors. When false, the runtime checks per graph
+     * until a non-unit length vector is found.
+     */
+    private final boolean knownUnitLength;
+
+    VectorSourceModel(Function<Integer, VectorCompression> compressionProvider,
+                      double overqueryFactor,
+                      boolean knownUnitLength)
     {
-        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor);
+        this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor, knownUnitLength);
     }
 
     VectorSourceModel(VectorSimilarityFunction defaultSimilarityFunction,
                       Function<Integer, VectorCompression> compressionProvider,
-                      Function<Integer, Double> overqueryProvider)
+                      Function<Integer, Double> overqueryProvider,
+                      boolean knownUnitLength)
     {
         this.defaultSimilarityFunction = defaultSimilarityFunction;
         this.compressionProvider = compressionProvider;
         this.overqueryProvider = overqueryProvider;
+        this.knownUnitLength = knownUnitLength;
+    }
+
+    public boolean hasKnownUnitLengthVectors()
+    {
+        return knownUnitLength;
     }
 
     public static VectorSourceModel fromString(String value)
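
A gloss on the motivation, not stated in the patch itself: when every vector in a segment is unit length, cosine similarity and dot product rank identically (cosine divides the dot product by the two norms, both 1), so recording the flag in the PQ header lets readers substitute the cheaper DOT_PRODUCT comparison. The snippet below only demonstrates that equivalence with plain float[] arithmetic; it uses no Cassandra or jvector APIs.

// Illustrative only: dot product equals cosine for unit-length vectors.
public final class UnitVectorEquivalence
{
    public static void main(String[] args)
    {
        float[] a = normalize(new float[] { 0.3f, -1.2f, 0.7f, 2.0f });
        float[] b = normalize(new float[] { 1.5f, 0.1f, -0.4f, 0.9f });

        double dot = dot(a, b);
        double cosine = dot / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b)));

        // Prints two (nearly) identical numbers, since |a| = |b| = 1.
        System.out.printf("dot=%.6f cosine=%.6f%n", dot, cosine);
    }

    // Scale v to unit length.
    static float[] normalize(float[] v)
    {
        double norm = Math.sqrt(dot(v, v));
        float[] unit = new float[v.length];
        for (int i = 0; i < v.length; i++)
            unit[i] = (float) (v[i] / norm);
        return unit;
    }

    static double dot(float[] a, float[] b)
    {
        double sum = 0;
        for (int i = 0; i < a.length; i++)
            sum += a[i] * b[i];
        return sum;
    }
}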