Skip to content

Commit f520f0b

Browse files
michaeljmarshalldjatnieks
authored andcommitted
CNDB-13196: Fix BM25 file access race condition (#1622)
Fixes: riptano/cndb#13196 We were incorrectly sharing a file reader across threads in BM25 queries, which can lead to invalid results, as I reproduced in the test. The PR fixes the issue by creating a `DocLengthsReader` object per query.
1 parent 83ba915 commit f520f0b

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

src/java/org/apache/cassandra/index/sai/disk/v1/DocLengthsReader.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,20 @@
2121
import java.io.IOException;
2222

2323
import org.apache.cassandra.index.sai.disk.io.IndexFileUtils;
24+
import javax.annotation.concurrent.NotThreadSafe;
25+
2426
import org.apache.cassandra.index.sai.disk.io.IndexInputReader;
2527
import org.apache.cassandra.io.util.FileHandle;
2628
import org.apache.cassandra.io.util.FileUtils;
2729

30+
@NotThreadSafe
2831
public class DocLengthsReader implements Closeable
2932
{
30-
private final FileHandle fileHandle;
3133
private final IndexInputReader input;
3234
private final SegmentMetadata.ComponentMetadata componentMetadata;
3335

3436
public DocLengthsReader(FileHandle fileHandle, SegmentMetadata.ComponentMetadata componentMetadata)
3537
{
36-
this.fileHandle = fileHandle;
3738
this.input = IndexFileUtils.instance().openInput(fileHandle);
3839
this.componentMetadata = componentMetadata;
3940
}
@@ -51,7 +52,7 @@ public int get(int rowID) throws IOException
5152
@Override
5253
public void close() throws IOException
5354
{
54-
FileUtils.close(fileHandle, input);
55+
FileUtils.close(input);
5556
}
5657
}
5758

src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
6262
import org.apache.cassandra.io.sstable.SSTableReadsListener;
6363
import org.apache.cassandra.io.sstable.format.SSTableReader;
64+
import org.apache.cassandra.io.util.FileHandle;
6465
import org.apache.cassandra.io.util.FileUtils;
6566
import org.apache.cassandra.utils.AbstractIterator;
6667
import org.apache.cassandra.utils.CloseableIterator;
@@ -80,7 +81,8 @@ public class InvertedIndexSearcher extends IndexSearcher
8081
private final Version version;
8182
private final boolean filterRangeResults;
8283
private final SSTableReader sstable;
83-
private final DocLengthsReader docLengthsReader;
84+
private final SegmentMetadata.ComponentMetadata docLengthsMeta;
85+
private final FileHandle docLengths;
8486
private final long segmentRowIdOffset;
8587

8688
protected InvertedIndexSearcher(SSTableContext sstableContext,
@@ -99,9 +101,9 @@ protected InvertedIndexSearcher(SSTableContext sstableContext,
99101
this.version = version;
100102
this.filterRangeResults = filterRangeResults;
101103
perColumnEventListener = (QueryEventListener.TrieIndexEventListener)indexContext.getColumnQueryMetrics();
102-
var docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS);
103104
this.segmentRowIdOffset = segmentMetadata.segmentRowIdOffset;
104-
this.docLengthsReader = docLengthsMeta == null ? null : new DocLengthsReader(indexFiles.docLengths(), docLengthsMeta);
105+
this.docLengthsMeta = segmentMetadata.componentMetadatas.getOptional(IndexComponentType.DOC_LENGTHS);
106+
this.docLengths = docLengthsMeta == null ? null : indexFiles.docLengths();
105107

106108
Map<String,String> map = metadata.componentMetadatas.get(IndexComponentType.TERMS_DATA).attributes;
107109
String footerPointerString = map.get(SAICodecUtils.FOOTER_POINTER);
@@ -176,7 +178,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
176178
var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending()));
177179
return toMetaSortedIterator(iter, queryContext);
178180
}
179-
if (docLengthsReader == null)
181+
if (docLengthsMeta == null)
180182
throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");
181183

182184
// find documents that match each term
@@ -194,11 +196,12 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
194196

195197
var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
196198
var merged = IntersectingPostingList.intersect(postingLists);
197-
199+
var docLengthsReader = new DocLengthsReader(docLengths, docLengthsMeta);
200+
198201
// Wrap the iterator with resource management
199202
var it = new AbstractIterator<DocTF>() { // Anonymous class extends AbstractIterator
200203
private boolean closed;
201-
204+
202205
@Override
203206
protected DocTF computeNext()
204207
{
@@ -222,7 +225,7 @@ public void close()
222225
{
223226
if (closed) return;
224227
closed = true;
225-
FileUtils.closeQuietly(pkm, merged);
228+
FileUtils.closeQuietly(pkm, merged, docLengthsReader);
226229
}
227230
};
228231
return bm25Internal(it, queryTerms, documentFrequencies);
@@ -250,7 +253,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
250253
{
251254
if (!orderer.isBM25())
252255
return super.orderResultsBy(reader, queryContext, keys, orderer, limit);
253-
if (docLengthsReader == null)
256+
if (docLengthsMeta == null)
254257
throw new InvalidRequestException(indexContext.getIndexName() + " does not support BM25 scoring until it is rebuilt");
255258

256259
var queryTerms = orderer.getQueryTerms();
@@ -279,7 +282,7 @@ public String toString()
279282
@Override
280283
public void close()
281284
{
282-
FileUtils.closeQuietly(reader, docLengthsReader);
285+
FileUtils.closeQuietly(reader, docLengths);
283286
}
284287

285288
/**

test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@
1818

1919
package org.apache.cassandra.index.sai.cql;
2020

21+
import java.util.ArrayList;
22+
import java.util.concurrent.ExecutorService;
23+
import java.util.concurrent.Executors;
24+
import java.util.concurrent.Future;
25+
2126
import org.junit.Before;
2227
import org.junit.Test;
2328

29+
import org.apache.cassandra.cql3.UntypedResultSet;
2430
import org.apache.cassandra.index.sai.SAITester;
2531
import org.apache.cassandra.index.sai.SAIUtil;
2632
import org.apache.cassandra.index.sai.disk.format.Version;
@@ -29,6 +35,7 @@
2935

3036
import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR;
3137
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
38+
import static org.junit.Assert.assertEquals;
3239

3340
public class BM25Test extends SAITester
3441
{
@@ -507,4 +514,37 @@ public void testQueryEmptyTable()
507514
var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1");
508515
assertThat(result).hasSize(0);
509516
}
517+
518+
@Test
519+
public void testBM25RaceConditionConcurrentQueriesInInvertedIndexSearcher() throws Throwable
520+
{
521+
createTable("CREATE TABLE %s (pk int, v text, PRIMARY KEY (pk))");
522+
analyzeIndex();
523+
524+
// Create 3 docs that have the same BM25 score and will be our top docs
525+
execute("INSERT INTO %s (pk, v) VALUES (1, 'apple apple apple')");
526+
execute("INSERT INTO %s (pk, v) VALUES (2, 'apple apple apple')");
527+
execute("INSERT INTO %s (pk, v) VALUES (3, 'apple apple apple')");
528+
529+
// Now insert a lot of docs that will hit the query, but will be lower in frequency and therefore in score
530+
for (int i = 4; i < 10000; i++)
531+
execute("INSERT INTO %s (pk, v) VALUES (?, 'apple apple')", i);
532+
533+
// Bug only present in sstable
534+
flush();
535+
536+
// Trigger many concurrent queries
537+
final ExecutorService executor = Executors.newFixedThreadPool(10);
538+
String select = "SELECT pk FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3";
539+
var futures = new ArrayList<Future<UntypedResultSet>>();
540+
for (int i = 0; i < 1000; i++)
541+
futures.add(executor.submit(() -> execute(select)));
542+
543+
// The top results are always the same rows
544+
for (Future<UntypedResultSet> future : futures)
545+
assertRowsIgnoringOrder(future.get(), row(1), row(2), row(3));
546+
547+
// Shutdown executor
548+
assertEquals(0, executor.shutdownNow().size());
549+
}
510550
}

0 commit comments

Comments
 (0)