Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,23 @@
import java.io.Closeable;
import java.io.IOException;

import javax.annotation.concurrent.NotThreadSafe;

import org.apache.cassandra.index.sai.disk.io.IndexInputReader;
import org.apache.cassandra.index.sai.utils.IndexFileUtils;
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.lucene.codecs.CodecUtil;

@NotThreadSafe
public class DocLengthsReader implements Closeable
{
private final FileHandle fileHandle;
private final IndexInputReader input;
private final SegmentMetadata.ComponentMetadata componentMetadata;

public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata)
{
this.fileHandle = fileHandle;
this.input = IndexFileUtils.instance.openInput(fileHandle);
this.componentMetadata = componentMetadata;
}
Expand All @@ -53,7 +54,7 @@ public int get(int rowID) throws IOException
/**
 * Releases the input that this reader opened in its constructor.
 * <p>
 * The {@link FileHandle} passed to the constructor is deliberately NOT closed here:
 * {@code InvertedIndexSearcher} creates a new reader per query from a single shared
 * handle and closes that handle itself, so closing it from a reader would break
 * other queries using the same handle.
 *
 * @throws IOException if closing the input fails
 */
@Override
public void close() throws IOException
{
    // Close the input exactly once; the handle's lifecycle belongs to whoever created it.
    FileUtils.close(input);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.io.sstable.format.SSTableReadsListener;
import org.apache.cassandra.io.util.FileHandle;
import org.apache.cassandra.io.util.FileUtils;
import org.apache.cassandra.utils.AbstractIterator;
import org.apache.cassandra.utils.CloseableIterator;
Expand All @@ -80,7 +81,8 @@ public class InvertedIndexSearcher extends IndexSearcher
private final Version version;
private final boolean filterRangeResults;
private final SSTableReader sstable;
private final DocLengthsReader docLengthsReader;
private final SegmentMetadata.ComponentMetadata docLengthsMeta;
private final FileHandle docLengths;
private final long segmentRowIdOffset;

protected InvertedIndexSearcher(SSTableContext sstableContext,
Expand All @@ -99,9 +101,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext,
this.version = version;
this.filterRangeResults = filterRangeResults;
perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics();
var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS);
this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset;
this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta);
this.docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS);
this.docLengths = docLengthsMeta == null ? null : indexFiles.docLengths();

Map<String,String> map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes;
String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER);
Expand Down Expand Up @@ -176,7 +178,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending()));
return toMetaSortedIterator(iter, queryContext);
}
if (docLengthsReader == null)
if (docLengthsMeta == null)
throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");

// find documents that match each term
Expand All @@ -194,11 +196,12 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express

var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
var merged = IntersectingPostingList.intersect(postingLists);

var docLengthsReader = new DocLengthsReader(docLengths, docLengthsMeta);

// Wrap the iterator with resource management
var it = new AbstractIterator<DocTF>() { // Anonymous class extends AbstractIterator
private boolean closed;

@Override
protected DocTF computeNext()
{
Expand All @@ -222,7 +225,7 @@ public void close()
{
if (closed) return;
closed = true;
FileUtils.closeQuietly(pkm, merged);
FileUtils.closeQuietly(pkm, merged, docLengthsReader);
}
};
return bm25Internal(it, queryTerms, documentFrequencies);
Expand Down Expand Up @@ -250,7 +253,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
{
if (!orderer.isBM25())
return super.orderResultsBy(reader, queryContext, keys, orderer, limit);
if (docLengthsReader == null)
if (docLengthsMeta == null)
throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");

var queryTerms = orderer.getQueryTerms();
Expand Down Expand Up @@ -279,7 +282,7 @@ public String toString()
/**
 * Closes {@code reader} and the doc-lengths resources held by this searcher
 * through {@code FileUtils.closeQuietly}.
 */
@Override
public void close()
{
// NOTE(review): these two calls look like the removed and the added side of one diff hunk
// (closing a long-lived DocLengthsReader vs. closing the shared docLengths FileHandle).
// Confirm that only one of them is meant to remain; as written, reader is passed to
// closeQuietly twice.
FileUtils.closeQuietly(reader, docLengthsReader);
FileUtils.closeQuietly(reader, docLengths);
}

/**
Expand Down
44 changes: 44 additions & 0 deletions test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,28 @@

package org.apache.cassandra.index.sai.cql;

import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.cassandra.cql3.UntypedResultSet;
import org.apache.cassandra.index.sai.SAITester;
import org.apache.cassandra.index.sai.SAIUtil;
import org.apache.cassandra.index.sai.disk.format.Version;
import org.apache.cassandra.index.sai.disk.v1.DocLengthsReader;
import org.apache.cassandra.index.sai.disk.v1.SegmentBuilder;
import org.apache.cassandra.index.sai.plan.QueryController;
import org.apache.cassandra.inject.Injections;
import org.apache.cassandra.inject.InvokePointBuilder;

import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.junit.Assert.assertEquals;

public class BM25Test extends SAITester
{
Expand Down Expand Up @@ -507,4 +518,37 @@ public void testQueryEmptyTable()
var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1");
assertThat(result).hasSize(0);
}

/**
 * Runs many identical BM25 queries concurrently against a flushed sstable and checks that
 * every one of them returns the same top-3 rows.
 * <p>
 * Rows 1-3 contain the query term three times and all other rows only twice, so rows 1-3
 * must always be the top results no matter how the queries interleave.
 */
@Test
public void testBM25RaceConditionConcurrentQueriesInInvertedIndexSearcher() throws Throwable
{
    createTable("CREATE TABLE %s (pk int, v text, PRIMARY KEY (pk))");
    analyzeIndex();

    // Create 3 docs that have the same BM25 score and will be our top docs
    execute("INSERT INTO %s (pk, v) VALUES (1, 'apple apple apple')");
    execute("INSERT INTO %s (pk, v) VALUES (2, 'apple apple apple')");
    execute("INSERT INTO %s (pk, v) VALUES (3, 'apple apple apple')");

    // Now insert a lot of docs that will hit the query, but will be lower in frequency and therefore in score
    for (int i = 4; i < 1000; i++)
        execute("INSERT INTO %s (pk, v) VALUES (?, 'apple apple')", i);

    // Bug only present in sstable
    flush();

    // Trigger many concurrent queries
    final ExecutorService executor = Executors.newFixedThreadPool(10);
    try
    {
        String select = "SELECT pk FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3";
        var futures = new ArrayList<Future<UntypedResultSet>>();
        for (int i = 0; i < 100; i++)
            futures.add(executor.submit(() -> execute(select)));

        // The top results are always the same rows
        for (Future<UntypedResultSet> future : futures)
            assertRowsIgnoringOrder(future.get(), row(1), row(2), row(3));

        // Every future has completed, so no task may still be waiting in the queue
        assertEquals(0, executor.shutdownNow().size());
    }
    finally
    {
        // No-op after the successful shutdown above; on a failed assertion or query it stops
        // the pool so its worker threads do not leak into later tests.
        executor.shutdownNow();
    }
}
}
Loading