4848import  io .github .jbellis .jvector .graph .disk .feature .Feature ;
4949import  io .github .jbellis .jvector .graph .disk .feature .FeatureId ;
5050import  io .github .jbellis .jvector .graph .disk .feature .InlineVectors ;
51+ import  io .github .jbellis .jvector .graph .disk .feature .NVQ ;
5152import  io .github .jbellis .jvector .graph .similarity .DefaultSearchScoreProvider ;
5253import  io .github .jbellis .jvector .quantization .BinaryQuantization ;
5354import  io .github .jbellis .jvector .quantization .CompressedVectors ;
55+ import  io .github .jbellis .jvector .quantization .NVQuantization ;
5456import  io .github .jbellis .jvector .quantization .ProductQuantization ;
5557import  io .github .jbellis .jvector .quantization .VectorCompressor ;
5658import  io .github .jbellis .jvector .util .Accountable ;
6466import  io .github .jbellis .jvector .vector .types .VectorFloat ;
6567import  io .github .jbellis .jvector .vector .types .VectorTypeSupport ;
6668import  org .agrona .collections .IntHashSet ;
69+ import  org .apache .cassandra .config .CassandraRelevantProperties ;
6770import  org .apache .cassandra .db .compaction .CompactionSSTable ;
6871import  org .apache .cassandra .db .marshal .VectorType ;
6972import  org .apache .cassandra .db .memtable .Memtable ;
@@ -101,6 +104,9 @@ public enum PQVersion {
101104        V1 , // includes unit vector calculation 
102105    }
103106
107+     /** Global switch for writing NVQ-encoded vectors, read once from SAI_VECTOR_ENABLE_NVQ; NVQ is still only used when the per-graph conditions checked in the constructor also hold */ 
108+     private  static  final  boolean  ENABLE_NVQ  = CassandraRelevantProperties .SAI_VECTOR_ENABLE_NVQ .getBoolean ();
109+ 
104110    /** minimum number of rows to perform PQ codebook generation */ 
105111    public  static  final  int  MIN_PQ_ROWS  = 1024 ;
106112
@@ -127,6 +133,8 @@ public enum PQVersion {
127133    // we don't need to explicitly close these since only on-heap resources are involved 
128134    private  final  ThreadLocal <GraphSearcherAccessManager > searchers ;
129135
136+     private  final  boolean  writeNvq ;
137+ 
130138    /** 
131139     * @param forSearching if true, vectorsByKey will be initialized and populated with vectors as they are added 
132140     */ 
@@ -159,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
159167        allVectorsAreUnitLength  = true ;
160168
161169        int  jvectorVersion  = context .version ().onDiskFormat ().jvectorFileFormatVersion ();
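170+         // NVQ needs jvector file format version 6 or later, and is only written for graphs that are not built for searching (i.e. during compaction) to save on compute costs 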
171+         writeNvq  = ENABLE_NVQ  && jvectorVersion  >= 6  && !forSearching ;
172+ 
162173        // This is only a warning since it's not a fatal error to write without hierarchy 
163174        if  (indexConfig .isHierarchyEnabled () && jvectorVersion  < 4 )
164175            logger .warn ("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. "  +
@@ -439,6 +450,9 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
439450
440451        OrdinalMapper  ordinalMapper  = remappedPostings .ordinalMapper ;
441452
453+         // Compute the NVQ quantizer up front when NVQ is enabled (null means inline vectors are written instead); optimize when https://github.com/datastax/jvector/pull/549 is merged 
454+         NVQuantization  nvq  = writeNvq  ? NVQuantization .compute (vectorValues , 2 ) : null ;
455+ 
442456        IndexComponent .ForWrite  termsDataComponent  = perIndexComponents .addOrGet (IndexComponentType .TERMS_DATA );
443457        var  indexFile  = termsDataComponent .file ();
444458        long  termsOffset  = SAICodecUtils .headerSize ();
@@ -450,7 +464,7 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
450464                               .withStartOffset (termsOffset )
451465                               .withVersion (perIndexComponents .version ().onDiskFormat ().jvectorFileFormatVersion ())
452466                               .withMapper (ordinalMapper )
453-                                .with (new  InlineVectors (vectorValues .dimension ()))
467+                                .with (nvq  !=  null  ?  new   NVQ ( nvq ) :  new  InlineVectors (vectorValues .dimension ()))
454468                               .build ())
455469        {
456470            SAICodecUtils .writeHeader (pqOutput );
@@ -483,8 +497,10 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
483497
484498            // write the graph 
485499            var  start  = System .nanoTime ();
486-             var  suppliers  = Feature .singleStateFactory (FeatureId .INLINE_VECTORS , nodeId  -> new  InlineVectors .State (vectorValues .getVector (nodeId )));
487-             indexWriter .write (suppliers );
500+             var  supplier  = nvq  != null 
501+                             ? Feature .singleStateFactory (FeatureId .NVQ_VECTORS , nodeId  -> new  NVQ .State (nvq .encode (vectorValues .getVector (nodeId ))))
502+                             : Feature .singleStateFactory (FeatureId .INLINE_VECTORS , nodeId  -> new  InlineVectors .State (vectorValues .getVector (nodeId )));
503+             indexWriter .write (supplier );
488504            SAICodecUtils .writeFooter (indexWriter .getOutput (), indexWriter .checksum ());
489505            logger .info ("Writing graph took {}ms" , (System .nanoTime () - start ) / 1_000_000 );
490506            long  termsLength  = indexWriter .getOutput ().position () - termsOffset ;
0 commit comments