Merged
@@ -125,6 +125,7 @@ public enum PQVersion {
private final InvalidVectorBehavior invalidVectorBehavior;
private final IntHashSet deletedOrdinals;
private volatile boolean hasDeletions;
private volatile boolean unitVectors;

// we don't need to explicitly close these since only on-heap resources are involved
private final ThreadLocal<GraphSearcherAccessManager> searchers;
@@ -157,6 +158,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
vectorsByKey = forSearching ? new NonBlockingHashMap<>() : null;
invalidVectorBehavior = forSearching ? InvalidVectorBehavior.FAIL : InvalidVectorBehavior.IGNORE;

// Start by assuming all vectors are unit length; add() flips this to false as soon as a non-unit-length vector is encountered.
unitVectors = true;

int jvectorVersion = Version.current().onDiskFormat().jvectorFileFormatVersion();
// This is only a warning since it's not a fatal error to write without hierarchy
if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
@@ -269,6 +273,12 @@ public long add(ByteBuffer term, T key)
var success = postingsByOrdinal.compareAndPut(ordinal, null, postings);
assert success : "postingsByOrdinal already contains an entry for ordinal " + ordinal;
bytesUsed += builder.addGraphNode(ordinal, vector);

// The vector was added to the graph successfully; verify unit length while the model claims it or while no counterexample has been seen yet
if (sourceModel.hasKnownUnitLengthVectors() || unitVectors)
if (!(Math.abs(VectorUtil.dotProduct(vector, vector) - 1.0f) < 0.01))
unitVectors = false;

return bytesUsed;
}
else
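For reference, here is a minimal standalone sketch of the unit-length test the added lines perform. It is illustrative only and not part of the patch: the class and method names are made up, and the dot product is inlined instead of calling VectorUtil.dotProduct. The idea is that dot(v, v) equals the squared L2 norm, so a vector counts as unit length when that value is within 0.01 of 1.0.

// Illustrative sketch only (not part of the patch).
final class UnitLengthCheck
{
    private static final float EPSILON = 0.01f;

    static boolean isUnitLength(float[] vector)
    {
        float dot = 0f;
        for (float x : vector)
            dot += x * x;                       // dot(v, v) == squared L2 norm of v
        return Math.abs(dot - 1.0f) < EPSILON;  // within tolerance of exactly 1.0
    }

    public static void main(String[] args)
    {
        System.out.println(isUnitLength(new float[] { 0.6f, 0.8f })); // true:  0.36 + 0.64 = 1.0
        System.out.println(isUnitLength(new float[] { 1.0f, 1.0f })); // false: 1.0  + 1.0  = 2.0
    }
}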
@@ -560,7 +570,6 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
// Build encoder and compress vectors
VectorCompressor<?> compressor; // will be null if we can't compress
CompressedVectors cv = null;
boolean containsUnitVectors;
// limit the PQ computation and encoding to one index at a time -- goal during flush is to
// evict from memory ASAP so better to do the PQ build (in parallel) one at a time
synchronized (CassandraOnHeapGraph.class)
@@ -580,15 +589,10 @@ private long writePQ(SequentialWriter writer, V5VectorPostingsWriter.RemappedPos
// encode (compress) the vectors to save
if (compressor != null)
cv = compressor.encodeAll(new RemappedVectorValues(remapped, remapped.maxNewOrdinal, vectorValues));

containsUnitVectors = IntStream.range(0, vectorValues.size())
.parallel()
.mapToObj(vectorValues::getVector)
.allMatch(v -> Math.abs(VectorUtil.dotProduct(v, v) - 1.0f) < 0.01);
}

var actualType = compressor == null ? CompressionType.NONE : preferredCompression.type;
writePqHeader(writer, containsUnitVectors, actualType);
writePqHeader(writer, unitVectors, actualType);
if (actualType == CompressionType.NONE)
return writer.position();

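The parallel IntStream scan removed above recomputed unit-length status over every vector at flush time; with the flag maintained incrementally in add(), that rescan is unnecessary and writePqHeader can read the tracked value directly. Below is a rough sketch of that pattern, under the assumption of a single optimistic flag that only ever flips from true to false; the names are illustrative, not the patch's.

// Sketch of incremental unit-length tracking (names are illustrative).
class UnitVectorTracker
{
    // Starts optimistic and can only transition true -> false, so a plain volatile is
    // sufficient even with concurrent adds (mirroring the patch's volatile unitVectors field).
    private volatile boolean unitVectors = true;

    void onVectorAdded(float[] vector, boolean modelClaimsUnitLength)
    {
        // Same shape as the condition in add(): keep checking while the model claims unit
        // length, or while no counterexample has been seen yet.
        if (modelClaimsUnitLength || unitVectors)
        {
            float dot = 0f;
            for (float x : vector)
                dot += x * x;
            if (Math.abs(dot - 1.0f) >= 0.01f)
                unitVectors = false;
        }
    }

    boolean allUnitVectors()
    {
        return unitVectors; // what the header write would consume instead of a full rescan
    }
}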
@@ -31,18 +31,19 @@
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.BINARY_QUANTIZATION;
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.NONE;
import static org.apache.cassandra.index.sai.disk.vector.VectorCompression.CompressionType.PRODUCT_QUANTIZATION;

public enum VectorSourceModel
{
ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5),
OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),
BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0),
GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25),
COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25),

OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery);
ADA002((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
OPENAI_V3_SMALL((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.5, true),
OPENAI_V3_LARGE((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, true),
// BERT is not known to have unit length vectors in all cases
BERT(COSINE, (dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.25), __ -> 1.0, false),
GECKO((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, true),
NV_QA_4((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.125), 1.25, false),
// Cohere does not officially say they have unit length vectors, but some users report that they do
COHERE_V3((dimension) -> new VectorCompression(PRODUCT_QUANTIZATION, dimension, 0.0625), 1.25, false),

OTHER(COSINE, VectorSourceModel::genericCompressionFor, VectorSourceModel::genericOverquery, false);

/**
* Default similarity function for this model.
@@ -58,18 +59,33 @@ public enum VectorSourceModel
*/
public final Function<VectorCompression, Double> overqueryProvider;

VectorSourceModel(Function<Integer, VectorCompression> compressionProvider, double overqueryFactor)
/**
* Indicates that the model is known to produce unit-length vectors. When false, the runtime checks each
* vector added to a graph until a non-unit-length vector is found.
*/
private final boolean knownUnitLength;

VectorSourceModel(Function<Integer, VectorCompression> compressionProvider,
double overqueryFactor,
boolean knownUnitLength)
{
this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor);
this(DOT_PRODUCT, compressionProvider, __ -> overqueryFactor, knownUnitLength);
}

VectorSourceModel(VectorSimilarityFunction defaultSimilarityFunction,
Function<Integer, VectorCompression> compressionProvider,
Function<VectorCompression, Double> overqueryProvider)
Function<VectorCompression, Double> overqueryProvider,
boolean knownUnitLength)
{
this.defaultSimilarityFunction = defaultSimilarityFunction;
this.compressionProvider = compressionProvider;
this.overqueryProvider = overqueryProvider;
this.knownUnitLength = knownUnitLength;
}

public boolean hasKnownUnitLengthVectors()
{
return knownUnitLength;
}

public static VectorSourceModel fromString(String value)
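A small usage illustration of the new accessor (hypothetical snippet, based only on the constants declared above):

// Reading the new flag off two of the constants declared above.
boolean adaUnit  = VectorSourceModel.ADA002.hasKnownUnitLengthVectors(); // true: declared with knownUnitLength = true
boolean bertUnit = VectorSourceModel.BERT.hasKnownUnitLengthVectors();   // false: BERT output is not guaranteed to be unit length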