Skip to content

Commit 97aa041

Browse files
CNDB-15469: Implement NVQ for vector graphs built by compaction
1 parent 9de81c5 commit 97aa041

File tree

7 files changed

+250
-56
lines changed

7 files changed

+250
-56
lines changed

src/java/org/apache/cassandra/config/CassandraRelevantProperties.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,8 @@ public enum CassandraRelevantProperties
415415
SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
416416
// Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
417417
SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
418+
// Use nvq when building graphs in compaction
419+
SAI_VECTOR_ENABLE_NVQ("cassandra.sai.vector.enable_nvq", "true"),
418420
/**
419421
* Whether to disable auto-compaction
420422
*/

src/java/org/apache/cassandra/index/sai/disk/format/Version.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.cassandra.index.sai.disk.v5.V5OnDiskFormat;
4040
import org.apache.cassandra.index.sai.disk.v6.V6OnDiskFormat;
4141
import org.apache.cassandra.index.sai.disk.v7.V7OnDiskFormat;
42+
import org.apache.cassandra.index.sai.disk.v8.V8OnDiskFormat;
4243
import org.apache.cassandra.index.sai.utils.TypeUtil;
4344
import org.apache.cassandra.io.sstable.format.SSTableFormat;
4445
import org.apache.cassandra.schema.SchemaConstants;
@@ -75,6 +76,8 @@ public class Version implements Comparable<Version>
7576
public static final Version EC = new Version("ec", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ec"));
7677
// total terms count serialization in index metadata, enables ANN_USE_SYNTHETIC_SCORE by default
7778
public static final Version ED = new Version("ed", V7OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ed"));
79+
// Vector feature: NVQ TODO is this EE or FA?
80+
public static final Version EE = new Version("EE", V8OnDiskFormat.instance, (c, i, g) -> stargazerFileNameFormat(c, i, g, "ee"));
7881

7982
// These are in reverse-chronological order so that the latest version is first. Version matching tests
8083
// are more likely to match the latest version, so we want to test that one first.

src/java/org/apache/cassandra/index/sai/disk/vector/CassandraDiskAnn.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.Component
124124
// don't load full PQVectors, all we need is the metadata from the PQ at the start
125125
pq = ProductQuantization.load(reader);
126126
compression = new VectorCompression(VectorCompression.CompressionType.PRODUCT_QUANTIZATION,
127-
rawGraph.getDimension() * Float.BYTES,
127+
graph.getDimension() * Float.BYTES,
128128
pq.compressedVectorSize());
129129
}
130130
else

src/java/org/apache/cassandra/index/sai/disk/vector/CassandraOnHeapGraph.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
import io.github.jbellis.jvector.graph.disk.feature.Feature;
4949
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
5050
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
51+
import io.github.jbellis.jvector.graph.disk.feature.NVQ;
5152
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
5253
import io.github.jbellis.jvector.quantization.BinaryQuantization;
5354
import io.github.jbellis.jvector.quantization.CompressedVectors;
55+
import io.github.jbellis.jvector.quantization.NVQuantization;
5456
import io.github.jbellis.jvector.quantization.ProductQuantization;
5557
import io.github.jbellis.jvector.quantization.VectorCompressor;
5658
import io.github.jbellis.jvector.util.Accountable;
@@ -64,6 +66,7 @@
6466
import io.github.jbellis.jvector.vector.types.VectorFloat;
6567
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
6668
import org.agrona.collections.IntHashSet;
69+
import org.apache.cassandra.config.CassandraRelevantProperties;
6770
import org.apache.cassandra.db.compaction.CompactionSSTable;
6871
import org.apache.cassandra.db.marshal.VectorType;
6972
import org.apache.cassandra.db.memtable.Memtable;
@@ -101,6 +104,9 @@ public enum PQVersion {
101104
V1, // includes unit vector calculation
102105
}
103106

107+
/** Whether to use NVQ when writing indexes (assuming all other conditions are met) */
108+
private static final boolean ENABLE_NVQ = CassandraRelevantProperties.SAI_VECTOR_ENABLE_NVQ.getBoolean();
109+
104110
/** minimum number of rows to perform PQ codebook generation */
105111
public static final int MIN_PQ_ROWS = 1024;
106112

@@ -127,6 +133,8 @@ public enum PQVersion {
127133
// we don't need to explicitly close these since only on-heap resources are involved
128134
private final ThreadLocal<GraphSearcherAccessManager> searchers;
129135

136+
private final boolean writeNvq;
137+
130138
/**
131139
* @param forSearching if true, vectorsByKey will be initialized and populated with vectors as they are added
132140
*/
@@ -159,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
159167
allVectorsAreUnitLength = true;
160168

161169
int jvectorVersion = context.version().onDiskFormat().jvectorFileFormatVersion();
170+
// NVQ is only written during compaction to save on compute costs
171+
writeNvq = ENABLE_NVQ && jvectorVersion >= 6 && !forSearching;
172+
162173
// This is only a warning since it's not a fatal error to write without hierarchy
163174
if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
164175
logger.warn("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. " +
@@ -439,6 +450,9 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
439450

440451
OrdinalMapper ordinalMapper = remappedPostings.ordinalMapper;
441452

453+
// Write the NVQ feature, optimize when https://github.com/datastax/jvector/pull/549 is merged
454+
NVQuantization nvq = writeNvq ? NVQuantization.compute(vectorValues, 2) : null;
455+
442456
IndexComponent.ForWrite termsDataComponent = perIndexComponents.addOrGet(IndexComponentType.TERMS_DATA);
443457
var indexFile = termsDataComponent.file();
444458
long termsOffset = SAICodecUtils.headerSize();
@@ -450,7 +464,7 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
450464
.withStartOffset(termsOffset)
451465
.withVersion(perIndexComponents.version().onDiskFormat().jvectorFileFormatVersion())
452466
.withMapper(ordinalMapper)
453-
.with(new InlineVectors(vectorValues.dimension()))
467+
.with(nvq != null ? new NVQ(nvq) : new InlineVectors(vectorValues.dimension()))
454468
.build())
455469
{
456470
SAICodecUtils.writeHeader(pqOutput);
@@ -483,8 +497,10 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
483497

484498
// write the graph
485499
var start = System.nanoTime();
486-
var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
487-
indexWriter.write(suppliers);
500+
var supplier = nvq != null
501+
? Feature.singleStateFactory(FeatureId.NVQ_VECTORS, nodeId -> new NVQ.State(nvq.encode(vectorValues.getVector(nodeId))))
502+
: Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
503+
indexWriter.write(supplier);
488504
SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
489505
logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
490506
long termsLength = indexWriter.getOutput().position() - termsOffset;

0 commit comments

Comments
 (0)