4848import io .github .jbellis .jvector .graph .disk .feature .Feature ;
4949import io .github .jbellis .jvector .graph .disk .feature .FeatureId ;
5050import io .github .jbellis .jvector .graph .disk .feature .InlineVectors ;
51+ import io .github .jbellis .jvector .graph .disk .feature .NVQ ;
5152import io .github .jbellis .jvector .graph .similarity .DefaultSearchScoreProvider ;
5253import io .github .jbellis .jvector .quantization .BinaryQuantization ;
5354import io .github .jbellis .jvector .quantization .CompressedVectors ;
55+ import io .github .jbellis .jvector .quantization .NVQuantization ;
5456import io .github .jbellis .jvector .quantization .ProductQuantization ;
5557import io .github .jbellis .jvector .quantization .VectorCompressor ;
5658import io .github .jbellis .jvector .util .Accountable ;
6466import io .github .jbellis .jvector .vector .types .VectorFloat ;
6567import io .github .jbellis .jvector .vector .types .VectorTypeSupport ;
6668import org .agrona .collections .IntHashSet ;
69+ import org .apache .cassandra .config .CassandraRelevantProperties ;
6770import org .apache .cassandra .db .compaction .CompactionSSTable ;
6871import org .apache .cassandra .db .marshal .VectorType ;
6972import org .apache .cassandra .db .memtable .Memtable ;
@@ -101,6 +104,9 @@ public enum PQVersion {
101104 V1 , // includes unit vector calculation
102105 }
103106
107+ /** Whether to use NVQ when writing indexes (assuming all other conditions are met) */
108+ private static final boolean ENABLE_NVQ = CassandraRelevantProperties .SAI_VECTOR_ENABLE_NVQ .getBoolean ();
109+
104110 /** minimum number of rows to perform PQ codebook generation */
105111 public static final int MIN_PQ_ROWS = 1024 ;
106112
@@ -127,6 +133,8 @@ public enum PQVersion {
127133 // we don't need to explicitly close these since only on-heap resources are involved
128134 private final ThreadLocal <GraphSearcherAccessManager > searchers ;
129135
136+ private final boolean writeNvq ;
137+
130138 /**
131139 * @param forSearching if true, vectorsByKey will be initialized and populated with vectors as they are added
132140 */
@@ -159,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
159167 allVectorsAreUnitLength = true ;
160168
161169 int jvectorVersion = context .version ().onDiskFormat ().jvectorFileFormatVersion ();
170+ // NVQ is only written during compaction to save on compute costs
171+ writeNvq = ENABLE_NVQ && jvectorVersion >= 6 && !forSearching ;
172+
162173 // This is only a warning since it's not a fatal error to write without hierarchy
163174 if (indexConfig .isHierarchyEnabled () && jvectorVersion < 4 )
164175 logger .warn ("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. " +
@@ -439,6 +450,9 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
439450
440451 OrdinalMapper ordinalMapper = remappedPostings .ordinalMapper ;
441452
453+ // Compute the NVQ quantization used to write the NVQ feature (null when NVQ is not being written). TODO: optimize once https://github.com/datastax/jvector/pull/549 is merged
454+ NVQuantization nvq = writeNvq ? NVQuantization .compute (vectorValues , 2 ) : null ;
455+
442456 IndexComponent .ForWrite termsDataComponent = perIndexComponents .addOrGet (IndexComponentType .TERMS_DATA );
443457 var indexFile = termsDataComponent .file ();
444458 long termsOffset = SAICodecUtils .headerSize ();
@@ -450,7 +464,7 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
450464 .withStartOffset (termsOffset )
451465 .withVersion (perIndexComponents .version ().onDiskFormat ().jvectorFileFormatVersion ())
452466 .withMapper (ordinalMapper )
453- .with (new InlineVectors (vectorValues .dimension ()))
467+ .with (nvq != null ? new NVQ ( nvq ) : new InlineVectors (vectorValues .dimension ()))
454468 .build ())
455469 {
456470 SAICodecUtils .writeHeader (pqOutput );
@@ -483,8 +497,10 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
483497
484498 // write the graph
485499 var start = System .nanoTime ();
486- var suppliers = Feature .singleStateFactory (FeatureId .INLINE_VECTORS , nodeId -> new InlineVectors .State (vectorValues .getVector (nodeId )));
487- indexWriter .write (suppliers );
500+ var supplier = nvq != null
501+ ? Feature .singleStateFactory (FeatureId .NVQ_VECTORS , nodeId -> new NVQ .State (nvq .encode (vectorValues .getVector (nodeId ))))
502+ : Feature .singleStateFactory (FeatureId .INLINE_VECTORS , nodeId -> new InlineVectors .State (vectorValues .getVector (nodeId )));
503+ indexWriter .write (supplier );
488504 SAICodecUtils .writeFooter (indexWriter .getOutput (), indexWriter .checksum ());
489505 logger .info ("Writing graph took {}ms" , (System .nanoTime () - start ) / 1_000_000 );
490506 long termsLength = indexWriter .getOutput ().position () - termsOffset ;
0 commit comments