diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java index 45769f9337d8..3686eef3d811 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java @@ -20,6 +20,8 @@ import java.io.Closeable; import java.io.IOException; +import javax.annotation.concurrent.NotThreadSafe; + import org.apache.cassandra.index.sai.disk.io.IndexInputReader; import org.apache.cassandra.index.sai.utils.IndexFileUtils; import org.apache.cassandra.index.sai.utils.SAICodecUtils; @@ -27,15 +29,14 @@ import org.apache.cassandra.io.util.FileUtils; import org.apache.lucene.codecs.CodecUtil; +@NotThreadSafe public class DocLengthsReader implements Closeable { - private final FileHandle fileHandle; private final IndexInputReader input; private final SegmentMetadata.ComponentMetadata componentMetadata; public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata) { - this.fileHandle = fileHandle; this.input = IndexFileUtils.instance.openInput(fileHandle); this.componentMetadata = componentMetadata; } @@ -53,7 +54,7 @@ public int get(int rowID) throws IOException @Override public void close() throws IOException { - FileUtils.close(fileHandle, input); + FileUtils.close(input); } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 08a63818c709..59d04289c2c7 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -61,6 +61,7 @@ import org.apache.cassandra.index.sai.utils.SAICodecUtils; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.io.sstable.format.SSTableReadsListener; +import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; @@ -80,7 +81,8 @@ public class InvertedIndexSearcher extends IndexSearcher private final Version version; private final boolean filterRangeResults; private final SSTableReader sstable; - private final DocLengthsReader docLengthsReader; + private final SegmentMetadata.ComponentMetadata docLengthsMeta; + private final FileHandle docLengths; private final long segmentRowIdOffset; protected InvertedIndexSearcher(SSTableContext sstableContext, @@ -99,9 +101,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, this.version = version; this.filterRangeResults = filterRangeResults; perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics(); - var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS); this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset; - this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta); + this.docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS); + this.docLengths = docLengthsMeta == null ? null : indexFiles.docLengths(); Map map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes; String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER); @@ -176,7 +178,7 @@ public CloseableIterator orderBy(Orderer orderer, Express var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); return toMetaSortedIterator(iter, queryContext); } - if (docLengthsReader == null) + if (docLengthsMeta == null) throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); // find documents that match each term @@ -194,11 +196,12 @@ public CloseableIterator orderBy(Orderer orderer, Express var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); var merged = IntersectingPostingList.intersect(postingLists); - + var docLengthsReader = new DocLengthsReader(docLengths, docLengthsMeta); + // Wrap the iterator with resource management var it = new AbstractIterator() { // Anonymous class extends AbstractIterator private boolean closed; - + @Override protected DocTF computeNext() { @@ -222,7 +225,7 @@ public void close() { if (closed) return; closed = true; - FileUtils.closeQuietly(pkm, merged); + FileUtils.closeQuietly(pkm, merged, docLengthsReader); } }; return bm25Internal(it, queryTerms, documentFrequencies); @@ -250,7 +253,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea { if (!orderer.isBM25()) return super.orderResultsBy(reader, queryContext, keys, orderer, limit); - if (docLengthsReader == null) + if (docLengthsMeta == null) throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt"); var queryTerms = orderer.getQueryTerms(); @@ -279,7 +282,7 @@ public String toString() @Override public void close() { - FileUtils.closeQuietly(reader, docLengthsReader); + FileUtils.closeQuietly(reader, docLengths); } /** diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java index 3a40e075114a..530daa0d66ea 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -18,9 +18,15 @@ package org.apache.cassandra.index.sai.cql; +import java.util.ArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + import org.junit.Before; import org.junit.Test; +import org.apache.cassandra.cql3.UntypedResultSet; import org.apache.cassandra.index.sai.SAITester; import org.apache.cassandra.index.sai.SAIUtil; import org.apache.cassandra.index.sai.disk.format.Version; @@ -29,6 +35,7 @@ import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.junit.Assert.assertEquals; public class BM25Test extends SAITester { @@ -507,4 +514,37 @@ public void testQueryEmptyTable() var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); assertThat(result).hasSize(0); } + + @Test + public void testBM25RaceConditionConcurrentQueriesInInvertedIndexSearcher() throws Throwable + { + createTable("CREATE TABLE %s (pk int, v text, PRIMARY KEY (pk))"); + analyzeIndex(); + + // Create 3 docs that have the same BM25 score and will be our top docs + execute("INSERT INTO %s (pk, v) VALUES (1, 'apple apple apple')"); + execute("INSERT INTO %s (pk, v) VALUES (2, 'apple apple apple')"); + execute("INSERT INTO %s (pk, v) VALUES (3, 'apple apple apple')"); + + // Now insert a lot of docs that will hit the query, but will be lower in frequency and therefore in score + for (int i = 4; i < 10000; i++) + execute("INSERT INTO %s (pk, v) VALUES (?, 'apple apple')", i); + + // Bug only present in sstable + flush(); + + // Trigger many concurrent queries + final ExecutorService executor = Executors.newFixedThreadPool(10); + String select = "SELECT pk FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"; + var futures = new ArrayList>(); + for (int i = 0; i < 1000; i++) + futures.add(executor.submit(() -> execute(select))); + + // The top results are always the same rows + for (Future future : futures) + assertRowsIgnoringOrder(future.get(), row(1), row(2), row(3)); + + // Shutdown executor + assertEquals(0, executor.shutdownNow().size()); + } }