Skip to content

Commit a1c064f

Browse files
blambovszymon-miezal
authored andcommitted
CNDB-12899: Add a cached version of the estimatedPartitionCount metric (#1559)
[CNDB-12899](riptano/cndb#12899) `CompactionRealm.estimatedPartitionCount()` is very expensive Adds a cached version of the metric and removes the memtable partitions from the calculation to make it more precise for the compaction use case. Also makes sure that the `estimatedPartitionCount` metric is not recalculated if the table's data view (i.e. sstable and memtable set) has not changed. --------- Co-authored-by: Szymon Miężał <[email protected]>
1 parent 72eeb37 commit a1c064f

File tree

10 files changed

+111
-17
lines changed

10 files changed

+111
-17
lines changed

src/java/org/apache/cassandra/db/compaction/CompactionRealm.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ default IPartitioner getPartitioner()
135135
* Return the estimated partition count, used when the number of partitions in an sstable is not sufficient to give
136136
* a sensible range estimation.
137137
*/
138-
default long estimatedPartitionCount()
138+
default long estimatedPartitionCountInSSTables()
139139
{
140140
final long INITIAL_ESTIMATED_PARTITION_COUNT = 1 << 16; // If we don't yet have a count, use a sensible default.
141141
if (metrics() == null)
142142
return INITIAL_ESTIMATED_PARTITION_COUNT;
143-
final Long estimation = metrics().estimatedPartitionCount.getValue();
143+
final Long estimation = metrics().estimatedPartitionCountInSSTablesCached.getValue();
144144
if (estimation == null || estimation == 0)
145145
return INITIAL_ESTIMATED_PARTITION_COUNT;
146146
return estimation;

src/java/org/apache/cassandra/db/compaction/DelegatingShardManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public double shardSetCoverage()
6060
@Override
6161
public double minimumPerPartitionSpan()
6262
{
63-
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCount());
63+
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCountInSSTables());
6464
}
6565

6666
@Override

src/java/org/apache/cassandra/db/compaction/ShardManagerNoDisks.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public double shardSetCoverage()
8787

8888
public double minimumPerPartitionSpan()
8989
{
90-
return localSpaceCoverage() / Math.max(1, localRanges.getRealm().estimatedPartitionCount());
90+
return localSpaceCoverage() / Math.max(1, localRanges.getRealm().estimatedPartitionCountInSSTables());
9191
}
9292

9393
@Override

src/java/org/apache/cassandra/db/compaction/ShardManagerReplicaAware.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public double shardSetCoverage()
8888
@Override
8989
public double minimumPerPartitionSpan()
9090
{
91-
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCount());
91+
return localSpaceCoverage() / Math.max(1, realm.estimatedPartitionCountInSSTables());
9292
}
9393

9494
@Override

src/java/org/apache/cassandra/io/sstable/format/SSTableReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ public static long getApproximateKeyCount(Iterable<? extends SSTableReader> ssta
316316
long count = -1;
317317

318318
if (Iterables.isEmpty(sstables))
319-
return count;
319+
return 0;
320320

321321
boolean failed = false;
322322
ICardinality cardinality = null;

src/java/org/apache/cassandra/metrics/TableMetrics.java

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.cassandra.metrics;
1919

20+
import java.lang.ref.WeakReference;
2021
import java.nio.BufferUnderflowException;
2122
import java.nio.ByteBuffer;
2223
import java.util.ArrayList;
@@ -30,6 +31,8 @@
3031
import java.util.concurrent.ConcurrentHashMap;
3132
import java.util.concurrent.ConcurrentMap;
3233
import java.util.concurrent.TimeUnit;
34+
import java.util.concurrent.atomic.AtomicReference;
35+
import java.util.function.LongSupplier;
3336
import java.util.function.Predicate;
3437
import java.util.stream.Stream;
3538
import javax.annotation.Nullable;
@@ -44,6 +47,7 @@
4447
import org.slf4j.Logger;
4548
import org.slf4j.LoggerFactory;
4649

50+
import com.codahale.metrics.CachedGauge;
4751
import com.codahale.metrics.Counter;
4852
import com.codahale.metrics.Gauge;
4953
import com.codahale.metrics.Histogram;
@@ -76,6 +80,7 @@
7680
import org.apache.cassandra.utils.Hex;
7781
import org.apache.cassandra.utils.MovingAverage;
7882
import org.apache.cassandra.utils.Pair;
83+
import org.apache.cassandra.utils.concurrent.Refs;
7984

8085
import static java.util.concurrent.TimeUnit.MICROSECONDS;
8186
import static org.apache.cassandra.metrics.CassandraMetricsRegistry.Metrics;
@@ -177,6 +182,12 @@ public String asCQLString()
177182
public final Gauge<long[]> estimatedPartitionSizeHistogram;
178183
/** Approximate number of keys in table. */
179184
public final Gauge<Long> estimatedPartitionCount;
185+
/** This function is used to calculate estimated partition count in sstables and store the calculated value for the
186+
* current set of sstables. */
187+
public final LongSupplier estimatedPartitionCountInSSTables;
188+
/** A cached version of the estimated partition count in sstables, used by compaction. This value will be more
189+
* precise when the table has a small number of partitions that keep getting written to. */
190+
public final Gauge<Long> estimatedPartitionCountInSSTablesCached;
180191
/** Histogram of estimated number of columns. */
181192
public final Gauge<long[]> estimatedColumnCountHistogram;
182193
/** Approximate number of rows in SSTable*/
@@ -649,20 +660,53 @@ public Long getValue()
649660
estimatedPartitionSizeHistogram = createTableGauge("EstimatedPartitionSizeHistogram", "EstimatedRowSizeHistogram",
650661
() -> combineHistograms(cfs.getSSTables(SSTableSet.CANONICAL),
651662
SSTableReader::getEstimatedPartitionSize), null);
652-
663+
664+
estimatedPartitionCountInSSTables = new LongSupplier()
665+
{
666+
// Since the sstables only change when the tracker view changes, we can cache the value.
667+
AtomicReference<Pair<WeakReference<View>, Long>> collected = new AtomicReference<>(Pair.create(new WeakReference<>(null), 0L));
668+
669+
public long getAsLong()
670+
{
671+
final View currentView = cfs.getTracker().getView();
672+
final Pair<WeakReference<View>, Long> currentCollected = collected.get();
673+
if (currentView != currentCollected.left.get())
674+
{
675+
Refs<SSTableReader> refs = Refs.tryRef(currentView.select(SSTableSet.CANONICAL));
676+
if (refs != null)
677+
{
678+
try (refs)
679+
{
680+
long collectedValue = SSTableReader.getApproximateKeyCount(refs);
681+
final Pair<WeakReference<View>, Long> newCollected = Pair.create(new WeakReference<>(currentView), collectedValue);
682+
collected.compareAndSet(currentCollected, newCollected); // okay if failed, a different thread did it
683+
return collectedValue;
684+
}
685+
}
686+
// If we can't reference, simply return the previous collected value; it can only result in a delay
687+
// in reporting the correct key count.
688+
}
689+
return currentCollected.right;
690+
}
691+
};
653692
estimatedPartitionCount = createTableGauge("EstimatedPartitionCount", "EstimatedRowCount", new Gauge<Long>()
654693
{
655694
public Long getValue()
656695
{
657-
long memtablePartitions = 0;
696+
long estimatedPartitions = estimatedPartitionCountInSSTables.getAsLong();
658697
for (Memtable memtable : cfs.getTracker().getView().getAllMemtables())
659-
memtablePartitions += memtable.partitionCount();
660-
try(ColumnFamilyStore.RefViewFragment refViewFragment = cfs.selectAndReference(View.selectFunction(SSTableSet.CANONICAL)))
661-
{
662-
return SSTableReader.getApproximateKeyCount(refViewFragment.sstables) + memtablePartitions;
663-
}
698+
estimatedPartitions += memtable.partitionCount();
699+
return estimatedPartitions;
664700
}
665701
}, null);
702+
estimatedPartitionCountInSSTablesCached = new CachedGauge<Long>(1, TimeUnit.SECONDS)
703+
{
704+
public Long loadValue()
705+
{
706+
return estimatedPartitionCountInSSTables.getAsLong();
707+
}
708+
};
709+
666710
estimatedColumnCountHistogram = createTableGauge("EstimatedColumnCountHistogram", "EstimatedColumnCountHistogram",
667711
() -> combineHistograms(cfs.getSSTables(SSTableSet.CANONICAL),
668712
SSTableReader::getEstimatedCellPerPartitionCount), null);

test/unit/org/apache/cassandra/db/compaction/DelegatingShardManagerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void testWrappingShardManagerNoDisks()
4242
{
4343
CompactionRealm realm = Mockito.mock(CompactionRealm.class);
4444
when(realm.getPartitioner()).thenReturn(partitioner);
45-
when(realm.estimatedPartitionCount()).thenReturn(1L << 16);
45+
when(realm.estimatedPartitionCountInSSTables()).thenReturn(1L << 16);
4646
SortedLocalRanges localRanges = SortedLocalRanges.forTestingFull(realm);
4747
ShardManager delegate = new ShardManagerNoDisks(localRanges);
4848

test/unit/org/apache/cassandra/db/compaction/ShardManagerReplicaAwareTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class ShardManagerReplicaAwareTest
4949
public void testRangeEndsForShardCountEqualtToNumTokensPlusOne() throws UnknownHostException
5050
{
5151
var mockCompationRealm = mock(CompactionRealm.class);
52-
when(mockCompationRealm.estimatedPartitionCount()).thenReturn(1L<<16);
52+
when(mockCompationRealm.estimatedPartitionCountInSSTables()).thenReturn(1L<<16);
5353

5454
for (int numTokens = 1; numTokens < 32; numTokens++)
5555
{
@@ -75,7 +75,7 @@ public void testRangeEndsForShardCountEqualtToNumTokensPlusOne() throws UnknownH
7575
public void testRangeEndsAreFromTokenListAndContainLowerRangeEnds() throws UnknownHostException
7676
{
7777
var mockCompationRealm = mock(CompactionRealm.class);
78-
when(mockCompationRealm.estimatedPartitionCount()).thenReturn(1L<<16);
78+
when(mockCompationRealm.estimatedPartitionCountInSSTables()).thenReturn(1L<<16);
7979

8080
for (int nodeCount = 1; nodeCount <= 6; nodeCount++)
8181
{

test/unit/org/apache/cassandra/db/compaction/ShardManagerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public void setUp()
6868
localRanges = Mockito.mock(SortedLocalRanges.class, Mockito.withSettings().defaultAnswer(Mockito.CALLS_REAL_METHODS));
6969
Mockito.when(localRanges.getRanges()).thenAnswer(invocation -> weightedRanges);
7070
Mockito.when(localRanges.getRealm()).thenReturn(realm);
71-
Mockito.when(realm.estimatedPartitionCount()).thenReturn(10000L);
71+
Mockito.when(realm.estimatedPartitionCountInSSTables()).thenReturn(10000L);
7272
}
7373

7474
@Test

test/unit/org/apache/cassandra/metrics/TableMetricsTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.apache.cassandra.exceptions.ConfigurationException;
4141
import org.apache.cassandra.service.EmbeddedCassandraService;
4242
import org.apache.cassandra.service.StorageService;
43+
import org.awaitility.Awaitility;
4344

4445
import static org.junit.Assert.assertEquals;
4546
import static org.junit.Assert.assertTrue;
@@ -407,6 +408,55 @@ public void testViewMetricsCleanupOnDrop()
407408
assertEquals(metrics.get().collect(Collectors.joining(",")), 0, metrics.get().count());
408409
}
409410

411+
@Test
412+
public void testEstimatedPartitionCount() throws InterruptedException
413+
{
414+
ColumnFamilyStore cfs = recreateTable();
415+
assertEquals(0L, cfs.metric.estimatedPartitionCount.getValue().longValue());
416+
assertEquals(0L, cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue());
417+
long startTime = System.currentTimeMillis();
418+
419+
int partitionCount = 10;
420+
int numRows = 100;
421+
422+
for (int i = 0; i < numRows; i++)
423+
session.execute(String.format("INSERT INTO %s.%s (id, val1, val2) VALUES (%d, '%s', '%s')", KEYSPACE, TABLE, i % partitionCount, "val" + i, "val" + i));
424+
425+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
426+
cfs.forceBlockingFlush(ColumnFamilyStore.FlushReason.UNIT_TESTS);
427+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
428+
429+
long estimatedPartitionCountInSSTables = cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue();
430+
long elapsedTime = System.currentTimeMillis() - startTime;
431+
// the caching time is one second; avoid flakiness by only checking if a long time has not passed
432+
// (Because we take the time after calling the method, elapsedTime < 1000 should also be stable, but let's also
433+
// accommodate the possibility that the cache uses a different timer with different tick times.)
434+
if (elapsedTime < 980)
435+
assertEquals(0, estimatedPartitionCountInSSTables);
436+
437+
for (int i = 0; i < numRows; i++)
438+
session.execute(String.format("INSERT INTO %s.%s (id, val1, val2) VALUES (%d, '%s', '%s')", KEYSPACE, TABLE, i % partitionCount, "val" + i, "val" + i));
439+
440+
estimatedPartitionCountInSSTables = cfs.metric.estimatedPartitionCountInSSTablesCached.getValue().longValue();
441+
elapsedTime = System.currentTimeMillis() - startTime;
442+
if (elapsedTime < 980)
443+
assertEquals(0, estimatedPartitionCountInSSTables);
444+
else if (elapsedTime >= 1020)
445+
assertEquals(partitionCount, estimatedPartitionCountInSSTables);
446+
447+
// The answer below is incorrect but what the metric currently returns.
448+
assertEquals(partitionCount * 2, cfs.metric.estimatedPartitionCount.getValue().longValue());
449+
cfs.forceBlockingFlush(ColumnFamilyStore.FlushReason.UNIT_TESTS);
450+
// Recalculation for the new sstable set will correct it.
451+
assertEquals(partitionCount, cfs.metric.estimatedPartitionCount.getValue().longValue());
452+
453+
// The cached estimatedPartitionCountInSSTables lags one second, check that.
454+
// Assert that the metric will return a correct value after at least a second passes
455+
Awaitility.await()
456+
.atMost(2, TimeUnit.SECONDS)
457+
.untilAsserted(() -> assertEquals(partitionCount, (long) cfs.metric.estimatedPartitionCountInSSTablesCached.getValue()));
458+
}
459+
410460

411461
@AfterClass
412462
public static void tearDown()

0 commit comments

Comments
 (0)