Skip to content

Commit c93158a

Browse files
CNDB-12899: Add a cached version of the estimatedPartitionCount metric (#1559)
### What is the issue [CNDB-12899](riptano/cndb#12899) `CompactionRealm.estimatedPartitionCount()` is very expensive ### What does this PR fix and why was it fixed Adds a cached version of the metric and removes the memtable partitions from the calculation to make it more precise for the compaction use case. Also makes sure that the `estimatedPartitionCount` metric is not recalculated if the table's data view (i.e. sstable and memtable set) has not changed. --------- Co-authored-by: Szymon Miężał <[email protected]>
1 parent 60bfb7e commit c93158a

File tree

10 files changed

+112
-25
lines changed

10 files changed

+112
-25
lines changed

src/java/org/apache/cassandra/db/compaction/CompactionRealm.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import org.apache.cassandra.db.Directories;
3232
import org.apache.cassandra.db.DiskBoundaries;
3333
import org.apache.cassandra.db.compaction.unified.Environment;
34-
import org.apache.cassandra.db.compaction.unified.RealEnvironment;
3534
import org.apache.cassandra.db.lifecycle.LifecycleTransaction;
3635
import org.apache.cassandra.db.lifecycle.SSTableSet;
3736
import org.apache.cassandra.db.memtable.Memtable;
@@ -126,12 +125,12 @@ default IPartitioner getPartitioner()
126125
* Return the estimated partition count, used when the number of partitions in an sstable is not sufficient to give
127126
* a sensible range estimation.
128127
*/
129-
default long estimatedPartitionCount()
128+
default long estimatedPartitionCountInSSTables()
130129
{
131130
final long INITIAL_ESTIMATED_PARTITION_COUNT = 1 << 16; // If we don't yet have a count, use a sensible default.
132131
if (metrics() == null)
133132
return INITIAL_ESTIMATED_PARTITION_COUNT;
134-
final Long estimation = metrics().estimatedPartitionCount.getValue();
133+
final Long estimation = metrics().estimatedPartitionCountInSSTablesCached.getValue();
135134
if (estimation == null || estimation == 0)
136135
return INITIAL_ESTIMATED_PARTITION_COUNT;
137136
return estimation;

src/java/org/apache/cassandra/db/compaction/DelegatingShardManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public double shardSetCoverage()
6060
@Override
6161
public double minimumPerPartitionSpan()
6262
{
63-
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCount());
63+
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCountInSSTables());
6464
}
6565

6666
@Override

src/java/org/apache/cassandra/db/compaction/ShardManagerNoDisks.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public double shardSetCoverage()
8787

8888
public double minimumPerPartitionSpan()
8989
{
90-
return localSpaceCoverage() / Math.max(1, localRanges.getRealm().estimatedPartitionCount());
90+
return localSpaceCoverage() / Math.max(1, localRanges.getRealm().estimatedPartitionCountInSSTables());
9191
}
9292

9393
@Override

src/java/org/apache/cassandra/db/compaction/ShardManagerReplicaAware.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,15 @@
2020

2121
import java.util.ArrayList;
2222
import java.util.Arrays;
23-
import java.util.Set;
2423
import java.util.concurrent.ConcurrentHashMap;
2524

26-
import javax.annotation.Nullable;
27-
2825
import org.slf4j.Logger;
2926
import org.slf4j.LoggerFactory;
3027

31-
import org.apache.cassandra.db.PartitionPosition;
3228
import org.apache.cassandra.dht.IPartitioner;
3329
import org.apache.cassandra.dht.Range;
3430
import org.apache.cassandra.dht.Token;
3531
import org.apache.cassandra.dht.tokenallocator.IsolatedTokenAllocator;
36-
import org.apache.cassandra.io.sstable.format.SSTableReader;
37-
import org.apache.cassandra.io.sstable.format.SSTableWriter;
3832
import org.apache.cassandra.locator.AbstractReplicationStrategy;
3933
import org.apache.cassandra.locator.TokenMetadata;
4034

@@ -94,7 +88,7 @@ public double shardSetCoverage()
9488
@Override
9589
public double minimumPerPartitionSpan()
9690
{
97-
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCount());
91+
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCountInSSTables());
9892
}
9993

10094
@Override

src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ public static long getApproximateKeyCount(Iterable<? extends SSTableReader> ssta
332332
long count = -1;
333333

334334
if (Iterables.isEmpty(sstables))
335-
return count;
335+
return 0;
336336

337337
boolean failed = false;
338338
ICardinality cardinality = null;

src/java/org/apache/cassandra/metrics/TableMetrics.java

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.cassandra.metrics;
1919

20+
import java.lang.ref.WeakReference;
2021
import java.nio.BufferUnderflowException;
2122
import java.nio.ByteBuffer;
2223
import java.util.ArrayList;
@@ -30,6 +31,8 @@
3031
import java.util.concurrent.ConcurrentHashMap;
3132
import java.util.concurrent.ConcurrentMap;
3233
import java.util.concurrent.TimeUnit;
34+
import java.util.concurrent.atomic.AtomicReference;
35+
import java.util.function.LongSupplier;
3336
import java.util.function.Predicate;
3437
import java.util.stream.Stream;
3538

@@ -44,6 +47,7 @@
4447
import org.slf4j.Logger;
4548
import org.slf4j.LoggerFactory;
4649

50+
import com.codahale.metrics.CachedGauge;
4751
import com.codahale.metrics.Counter;
4852
import com.codahale.metrics.Gauge;
4953
import com.codahale.metrics.Histogram;
@@ -74,6 +78,7 @@
7478
import org.apache.cassandra.utils.Hex;
7579
import org.apache.cassandra.utils.MovingAverage;
7680
import org.apache.cassandra.utils.Pair;
81+
import org.apache.cassandra.utils.concurrent.Refs;
7782

7883
import static org.apache.cassandra.io.sstable.format.SSTableReader.selectOnlyBigTableReaders;
7984
import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics;
@@ -174,6 +179,12 @@ public String asCQLString()
174179
public final Gauge<long[]> estimatedPartitionSizeHistogram;
175180
/** Approximate number of keys in table. */
176181
public final Gauge<Long> estimatedPartitionCount;
182+
/** This function is used to calculate estimated partition count in sstables and store the calculated value for the
183+
* current set of sstables. */
184+
public final LongSupplier estimatedPartitionCountInSSTables;
185+
/** A cached version of the estimated partition count in sstables, used by compaction. This value will be more
186+
* precise when the table has a small number of partitions that keep getting written to. */
187+
public final Gauge<Long> estimatedPartitionCountInSSTablesCached;
177188
/** Histogram of estimated number of columns. */
178189
public final Gauge<long[]> estimatedColumnCountHistogram;
179190

@@ -605,20 +616,53 @@ public Long getValue()
605616
estimatedPartitionSizeHistogram = createTableGauge("EstimatedPartitionSizeHistogram", "EstimatedRowSizeHistogram",
606617
() -> combineHistograms(cfs.getSSTables(SSTableSet.CANONICAL),
607618
SSTableReader::getEstimatedPartitionSize), null);
608-
619+
620+
estimatedPartitionCountInSSTables = new LongSupplier()
621+
{
622+
// Since the sstables only change when the tracker view changes, we can cache the value.
623+
AtomicReference<Pair<WeakReference<View>, Long>> collected = new AtomicReference<>(Pair.create(new WeakReference<>(null), 0L));
624+
625+
public long getAsLong()
626+
{
627+
final View currentView = cfs.getTracker().getView();
628+
final Pair<WeakReference<View>, Long> currentCollected = collected.get();
629+
if (currentView != currentCollected.left.get())
630+
{
631+
Refs<SSTableReader> refs = Refs.tryRef(currentView.select(SSTableSet.CANONICAL));
632+
if (refs != null)
633+
{
634+
try (refs)
635+
{
636+
long collectedValue = SSTableReader.getApproximateKeyCount(refs);
637+
final Pair<WeakReference<View>, Long> newCollected = Pair.create(new WeakReference<>(currentView), collectedValue);
638+
collected.compareAndSet(currentCollected, newCollected); // okay if failed, a different thread did it
639+
return collectedValue;
640+
}
641+
}
642+
// If we can't reference, simply return the previous collected value; it can only result in a delay
643+
// in reporting the correct key count.
644+
}
645+
return currentCollected.right;
646+
}
647+
};
609648
estimatedPartitionCount = createTableGauge("EstimatedPartitionCount", "EstimatedRowCount", new Gauge<Long>()
610649
{
611650
public Long getValue()
612651
{
613-
long memtablePartitions = 0;
652+
long estimatedPartitions = estimatedPartitionCountInSSTables.getAsLong();
614653
for (Memtable memtable : cfs.getTracker().getView().getAllMemtables())
615-
memtablePartitions += memtable.partitionCount();
616-
try(ColumnFamilyStore.RefViewFragment refViewFragment = cfs.selectAndReference(View.selectFunction(SSTableSet.CANONICAL)))
617-
{
618-
return SSTableReader.getApproximateKeyCount(refViewFragment.sstables) + memtablePartitions;
619-
}
654+
estimatedPartitions += memtable.partitionCount();
655+
return estimatedPartitions;
620656
}
621657
}, null);
658+
estimatedPartitionCountInSSTablesCached = new CachedGauge<Long>(1, TimeUnit.SECONDS)
659+
{
660+
public Long loadValue()
661+
{
662+
return estimatedPartitionCountInSSTables.getAsLong();
663+
}
664+
};
665+
622666
estimatedColumnCountHistogram = createTableGauge("EstimatedColumnCountHistogram", "EstimatedColumnCountHistogram",
623667
() -> combineHistograms(cfs.getSSTables(SSTableSet.CANONICAL),
624668
SSTableReader::getEstimatedCellPerPartitionCount), null);

test/unit/org/apache/cassandra/db/compaction/DelegatingShardManagerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void testWrappingShardManagerNoDisks()
4242
{
4343
CompactionRealm realm = Mockito.mock(CompactionRealm.class);
4444
when(realm.getPartitioner()).thenReturn(partitioner);
45-
when(realm.estimatedPartitionCount()).thenReturn(1L << 16);
45+
when(realm.estimatedPartitionCountInSSTables()).thenReturn(1L << 16);
4646
SortedLocalRanges localRanges = SortedLocalRanges.forTestingFull(realm);
4747
ShardManager delegate = new ShardManagerNoDisks(localRanges);
4848

test/unit/org/apache/cassandra/db/compaction/ShardManagerReplicaAwareTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class ShardManagerReplicaAwareTest
4949
public void testRangeEndsForShardCountEqualtToNumTokensPlusOne() throws UnknownHostException
5050
{
5151
var mockCompationRealm = mock(CompactionRealm.class);
52-
when(mockCompationRealm.estimatedPartitionCount()).thenReturn(1L<<16);
52+
when(mockCompationRealm.estimatedPartitionCountInSSTables()).thenReturn(1L<<16);
5353

5454
for (int numTokens = 1; numTokens < 32; numTokens++)
5555
{
@@ -75,7 +75,7 @@ public void testRangeEndsForShardCountEqualtToNumTokensPlusOne() throws UnknownH
7575
public void testRangeEndsAreFromTokenListAndContainLowerRangeEnds() throws UnknownHostException
7676
{
7777
var mockCompationRealm = mock(CompactionRealm.class);
78-
when(mockCompationRealm.estimatedPartitionCount()).thenReturn(1L<<16);
78+
when(mockCompationRealm.estimatedPartitionCountInSSTables()).thenReturn(1L<<16);
7979

8080
for (int nodeCount = 1; nodeCount <= 6; nodeCount++)
8181
{

test/unit/org/apache/cassandra/db/compaction/ShardManagerTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import java.util.stream.Collectors;
2828

2929
import org.junit.Assert;
30-
import org.junit.Assume;
3130
import org.junit.Before;
3231
import org.junit.Test;
3332

@@ -69,7 +68,7 @@ public void setUp()
6968
localRanges = Mockito.mock(SortedLocalRanges.class, Mockito.withSettings().defaultAnswer(Mockito.CALLS_REAL_METHODS));
7069
Mockito.when(localRanges.getRanges()).thenAnswer(invocation -> weightedRanges);
7170
Mockito.when(localRanges.getRealm()).thenReturn(realm);
72-
Mockito.when(realm.estimatedPartitionCount()).thenReturn(10000L);
71+
Mockito.when(realm.estimatedPartitionCountInSSTables()).thenReturn(10000L);
7372
}
7473

7574
@Test

test/unit/org/apache/cassandra/metrics/TableMetricsTest.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.cassandra.metrics;
2020

2121
import java.io.IOException;
22+
import java.util.concurrent.TimeUnit;
2223
import java.util.function.Supplier;
2324
import java.util.stream.Collectors;
2425
import java.util.stream.Stream;
@@ -37,6 +38,7 @@
3738
import org.apache.cassandra.db.Keyspace;
3839
import org.apache.cassandra.exceptions.ConfigurationException;
3940
import org.apache.cassandra.service.EmbeddedCassandraService;
41+
import org.awaitility.Awaitility;
4042

4143
import static org.junit.Assert.assertEquals;
4244
import static org.junit.Assert.assertTrue;
@@ -363,6 +365,55 @@ public void testViewMetricsCleanupOnDrop()
363365
assertEquals(metrics.get().collect(Collectors.joining(",")), 0, metrics.get().count());
364366
}
365367

368+
@Test
369+
public void testEstimatedPartitionCount() throws InterruptedException
370+
{
371+
ColumnFamilyStore cfs = recreateTable();
372+
assertEquals(0L, cfs.metric.estimatedPartitionCount.getValue().longValue());
373+
assertEquals(0L, cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue());
374+
long startTime = System.currentTimeMillis();
375+
376+
int partitionCount = 10;
377+
int numRows = 100;
378+
379+
for (int i = 0; i < numRows; i++)
380+
session.execute(String.format("INSERT INTO %s.%s (id, val1, val2) VALUES (%d, '%s', '%s')", KEYSPACE, TABLE, i % partitionCount, "val" + i, "val" + i));
381+
382+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
383+
cfs.forceBlockingFlush(ColumnFamilyStore.FlushReason.UNIT_TESTS);
384+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
385+
386+
long estimatedPartitionCountInSSTables = cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue();
387+
long elapsedTime = System.currentTimeMillis() - startTime;
388+
// the caching time is one second; avoid flakiness by only checking if a long time has not passed
389+
// (Because we take the time after calling the method, elapsedTime < 1000 should also be stable, but let's also
390+
// accommodate the possibility that the cache uses a different timer with different tick times.)
391+
if (elapsedTime < 980)
392+
assertEquals(0, estimatedPartitionCountInSSTables);
393+
394+
for (int i = 0; i < numRows; i++)
395+
session.execute(String.format("INSERT INTO %s.%s (id, val1, val2) VALUES (%d, '%s', '%s')", KEYSPACE, TABLE, i % partitionCount, "val" + i, "val" + i));
396+
397+
estimatedPartitionCountInSSTables = cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue();
398+
elapsedTime = System.currentTimeMillis() - startTime;
399+
if (elapsedTime < 980)
400+
assertEquals(0, estimatedPartitionCountInSSTables);
401+
else if (elapsedTime >= 1020)
402+
assertEquals(partitionCount, estimatedPartitionCountInSSTables);
403+
404+
// The answer below is incorrect but what the metric currently returns.
405+
assertEquals(partitionCount * 2, cfs.metric.estimatedPartitionCount.getValue().longValue());
406+
cfs.forceBlockingFlush(ColumnFamilyStore.FlushReason.UNIT_TESTS);
407+
// Recalculation for the new sstable set will correct it.
408+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
409+
410+
// The cached estimatedPartitionCountInSSTables lags one second, check that.
411+
// Assert that the metric will return a correct value after at least a second passes
412+
Awaitility.await()
413+
.atMost(2, TimeUnit.SECONDS)
414+
.untilAsserted(() -> assertEquals(partitionCount, (long) cfs.metric.estimatedPartitionCountInSSTablesCached.getValue()));
415+
}
416+
366417

367418
@AfterClass
368419
public static void tearDown()

0 commit comments

Comments
 (0)