CassandraRelevantProperties.java
@@ -415,6 +415,8 @@ public enum CassandraRelevantProperties
SAI_VECTOR_FLUSH_THRESHOLD_MAX_ROWS("cassandra.sai.vector_flush_threshold_max_rows", "-1"),
// Use non-positive value to disable it. Period in millis to trigger a flush for SAI vector memtable index.
SAI_VECTOR_FLUSH_PERIOD_IN_MILLIS("cassandra.sai.vector_flush_period_in_millis", "-1"),
+ // Use nvq when building graphs in compaction
+ SAI_VECTOR_ENABLE_NVQ("cassandra.sai.vector.enable_nvq", "true"),
Member Author:
Talked with Ted and Mariano; it is not yet time to enable this by default.
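
For reference, keeping the property but shipping it disabled would be a one-line change to the declaration above; a sketch of that default flip (illustrative only, not part of this patch):

    // Use nvq when building graphs in compaction (disabled by default)
    SAI_VECTOR_ENABLE_NVQ("cassandra.sai.vector.enable_nvq", "false"),

Operators who want NVQ could then still opt in at startup via the usual -D system-property override (-Dcassandra.sai.vector.enable_nvq=true), since CassandraRelevantProperties reads JVM system properties.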

Member Author:
We might want to push it through as an index option that is disabled by default.
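
If it does move to an index option, it would presumably sit alongside the existing SAI options in the index DDL; a hypothetical sketch, where the option name 'enable_nvq' (and the index, table, and column names) are illustrative and not defined by this patch:

    CREATE CUSTOM INDEX vec_idx ON ks.tbl (vec) USING 'StorageAttachedIndex'
    WITH OPTIONS = {'similarity_function': 'cosine', 'enable_nvq': 'true'};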

/**
* Whether to disable auto-compaction
*/
CassandraDiskAnn.java
@@ -124,7 +124,7 @@ public CassandraDiskAnn(SSTableContext sstableContext, SegmentMetadata.Component
// don't load full PQVectors, all we need is the metadata from the PQ at the start
pq = ProductQuantization.load(reader);
compression = new VectorCompression(VectorCompression.CompressionType.PRODUCT_QUANTIZATION,
- rawGraph.getDimension() * Float.BYTES,
+ graph.getDimension() * Float.BYTES,
pq.compressedVectorSize());
}
else
CassandraOnHeapGraph.java
@@ -48,9 +48,11 @@
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;
+ import io.github.jbellis.jvector.graph.disk.feature.NVQ;
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
import io.github.jbellis.jvector.quantization.BinaryQuantization;
import io.github.jbellis.jvector.quantization.CompressedVectors;
+ import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.quantization.VectorCompressor;
import io.github.jbellis.jvector.util.Accountable;
@@ -64,6 +66,7 @@
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.agrona.collections.IntHashSet;
+ import org.apache.cassandra.config.CassandraRelevantProperties;
import org.apache.cassandra.db.compaction.CompactionSSTable;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.db.memtable.Memtable;
@@ -101,6 +104,9 @@ public enum PQVersion {
V1, // includes unit vector calculation
}

+ /** Whether to use NVQ when writing indexes (assuming all other conditions are met) */
+ private static final boolean ENABLE_NVQ = CassandraRelevantProperties.SAI_VECTOR_ENABLE_NVQ.getBoolean();

/** minimum number of rows to perform PQ codebook generation */
public static final int MIN_PQ_ROWS = 1024;

@@ -127,6 +133,8 @@ public enum PQVersion {
// we don't need to explicitly close these since only on-heap resources are involved
private final ThreadLocal<GraphSearcherAccessManager> searchers;

+ private final boolean writeNvq;

/**
* @param forSearching if true, vectorsByKey will be initialized and populated with vectors as they are added
*/
@@ -159,6 +167,9 @@ public CassandraOnHeapGraph(IndexContext context, boolean forSearching, Memtable
allVectorsAreUnitLength = true;

int jvectorVersion = context.version().onDiskFormat().jvectorFileFormatVersion();
+ // NVQ is only written during compaction to save on compute costs
+ writeNvq = ENABLE_NVQ && jvectorVersion >= 4 && !forSearching;

// This is only a warning since it's not a fatal error to write without hierarchy
if (indexConfig.isHierarchyEnabled() && jvectorVersion < 4)
logger.warn("Hierarchical graphs configured but node configured with V3OnDiskFormat.JVECTOR_VERSION {}. " +
@@ -439,6 +450,9 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn

OrdinalMapper ordinalMapper = remappedPostings.ordinalMapper;

+ // Write the NVQ feature, optimize when https://github.com/datastax/jvector/pull/549 is merged
+ NVQuantization nvq = writeNvq ? NVQuantization.compute(vectorValues, 2) : null;

IndexComponent.ForWrite termsDataComponent = perIndexComponents.addOrGet(IndexComponentType.TERMS_DATA);
var indexFile = termsDataComponent.file();
long termsOffset = SAICodecUtils.headerSize();
@@ -450,7 +464,7 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn
.withStartOffset(termsOffset)
.withVersion(perIndexComponents.version().onDiskFormat().jvectorFileFormatVersion())
.withMapper(ordinalMapper)
- .with(new InlineVectors(vectorValues.dimension()))
+ .with(nvq != null ? new NVQ(nvq) : new InlineVectors(vectorValues.dimension()))
.build())
{
SAICodecUtils.writeHeader(pqOutput);
Expand Down Expand Up @@ -483,8 +497,10 @@ public SegmentMetadata.ComponentMetadataMap flush(IndexComponents.ForWrite perIn

// write the graph
var start = System.nanoTime();
- var suppliers = Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
- indexWriter.write(suppliers);
+ var supplier = nvq != null
+     ? Feature.singleStateFactory(FeatureId.NVQ_VECTORS, nodeId -> new NVQ.State(nvq.encode(vectorValues.getVector(nodeId))))
+     : Feature.singleStateFactory(FeatureId.INLINE_VECTORS, nodeId -> new InlineVectors.State(vectorValues.getVector(nodeId)));
+ indexWriter.write(supplier);
SAICodecUtils.writeFooter(indexWriter.getOutput(), indexWriter.checksum());
logger.info("Writing graph took {}ms", (System.nanoTime() - start) / 1_000_000);
long termsLength = indexWriter.getOutput().position() - termsOffset;