From 52c763a466cfc25420d3df4921dc7ca34c815111 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Sat, 28 Feb 2026 11:25:45 +0800 Subject: [PATCH 1/3] [#2716] feat(spark): Introduce option of max segments decompression to control memory usage --- .../client/impl/DecompressionWorker.java | 42 ++++++++++++++++++- .../client/impl/ShuffleReadClientImpl.java | 8 +++- .../client/impl/DecompressionWorkerTest.java | 38 ++++++++++++++++- .../uniffle/common/config/RssClientConf.java | 7 ++++ 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java index 1208447186..4026b6c35b 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java @@ -19,12 +19,15 @@ import java.nio.ByteBuffer; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,7 +58,10 @@ public class DecompressionWorker { private AtomicLong peekMemoryUsed = new AtomicLong(0); private AtomicLong nowMemoryUsed = new AtomicLong(0); - public DecompressionWorker(Codec codec, int threads, int fetchSecondsThreshold) { + private final Optional segmentPermits; + + public DecompressionWorker( + Codec codec, int threads, int fetchSecondsThreshold, int maxConcurrentDecompressionSegments) { if (codec == null) { throw new IllegalArgumentException("Codec cannot be null"); } @@ -67,6 +73,16 @@ public DecompressionWorker(Codec codec, int threads, int fetchSecondsThreshold) Executors.newFixedThreadPool(threads, ThreadUtils.getThreadFactory("decompressionWorker")); this.codec = codec; this.fetchSecondsThreshold = fetchSecondsThreshold; + + if (maxConcurrentDecompressionSegments <= 0) { + this.segmentPermits = Optional.empty(); + } else if (threads != 1) { + LOG.info( + "Disable backpressure control since threads is {} to avoid potential deadlock", threads); + this.segmentPermits = Optional.empty(); + } else { + this.segmentPermits = Optional.of(new Semaphore(maxConcurrentDecompressionSegments)); + } } public void add(int batchIndex, ShuffleDataResult shuffleDataResult) { @@ -81,6 +97,16 @@ public void add(int batchIndex, ShuffleDataResult shuffleDataResult) { CompletableFuture f = CompletableFuture.supplyAsync( () -> { + segmentPermits.ifPresent( + x -> { + try { + x.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while acquiring segment permit", e); + } + }); + int offset = bufferSegment.getOffset(); int length = bufferSegment.getLength(); ByteBuffer buffer = sharedByteBuffer.duplicate(); @@ -132,6 +158,7 @@ public DecompressedShuffleBlock get(int batchIndex, int segmentIndex) { // block if (block != null) { nowMemoryUsed.addAndGet(-block.getUncompressLength()); + segmentPermits.ifPresent(x -> x.release()); } return block; } @@ -163,4 +190,17 @@ public void close() { public long decompressionMillis() { return decompressionMillis.get() + decompressionBufferAllocationMillis.get(); } + + @VisibleForTesting + protected long getPeekMemoryUsed() { + return peekMemoryUsed.get(); + } + + @VisibleForTesting + protected int getAvailablePermits() { + if (segmentPermits.isPresent()) { + return segmentPermits.get().availablePermits(); + } + return -1; + } } diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index 364e98526c..7cc386ad7b 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -62,6 +62,7 @@ import static org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT; import static org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED; import static org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD; +import static org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS; public class ShuffleReadClientImpl implements ShuffleReadClient { @@ -165,9 +166,14 @@ private void init(ShuffleClientFactory.ReadClientBuilder builder) { if (builder.isOverlappingDecompressionEnabled()) { int fetchThreshold = builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD); + int maxSegments = + builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS); this.decompressionWorker = new DecompressionWorker( - builder.getCodec(), builder.getOverlappingDecompressionThreadNum(), fetchThreshold); + builder.getCodec(), + builder.getOverlappingDecompressionThreadNum(), + fetchThreshold, + maxSegments); } this.shuffleId = builder.getShuffleId(); this.partitionId = builder.getPartitionId(); diff --git a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java index 5d41ac4815..92b33b2181 100644 --- a/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java +++ b/client/src/test/java/org/apache/uniffle/client/impl/DecompressionWorkerTest.java @@ -21,7 +21,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import org.apache.uniffle.client.response.DecompressedShuffleBlock; @@ -36,10 +38,42 @@ public class DecompressionWorkerTest { + @Test + public void testBackpressure() throws Exception { + RssConf rssConf = new RssConf(); + rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP); + Codec codec = Codec.newInstance(rssConf).get(); + + int threads = 1; + int maxSegments = 10; + int fetchSecondsThreshold = 2; + DecompressionWorker worker = + new DecompressionWorker(codec, threads, fetchSecondsThreshold, maxSegments); + + ShuffleDataResult shuffleDataResult = createShuffleDataResult(maxSegments + 1, codec, 1024); + worker.add(0, shuffleDataResult); + + // case1: check the peek memory used is correct when the decompression is in progress + Awaitility.await() + .timeout(200, TimeUnit.MILLISECONDS) + .until(() -> 1024 * maxSegments == worker.getPeekMemoryUsed()); + assertEquals(0, worker.getAvailablePermits()); + + // case2: after the previous segments are consumed, the blocked segments can be gotten after the + // decompression is done + for (int i = 0; i < maxSegments; i++) { + worker.get(0, i); + } + Thread.sleep(10); + worker.get(0, maxSegments).getByteBuffer(); + assertEquals(1024 * maxSegments, worker.getPeekMemoryUsed()); + assertEquals(maxSegments, worker.getAvailablePermits()); + } + @Test public void testEmptyGet() throws Exception { DecompressionWorker worker = - new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10); + new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10, 10000); assertNull(worker.get(1, 1)); } @@ -78,7 +112,7 @@ public void test() throws Exception { RssConf rssConf = new RssConf(); rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP); Codec codec = Codec.newInstance(rssConf).get(); - DecompressionWorker worker = new DecompressionWorker(codec, 1, 10); + DecompressionWorker worker = new DecompressionWorker(codec, 1, 10, 100000); // create some data ShuffleDataResult shuffleDataResult = createShuffleDataResult(10, codec, 100); diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java index 0ce7be6351..37a79ce0e6 100644 --- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java +++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java @@ -389,4 +389,11 @@ public class RssClientConf { .defaultValue(-1) .withDescription( "Fetch seconds threshold for overlapping decompress shuffle blocks."); + + public static final ConfigOption + RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS = + ConfigOptions.key("rss.client.read.overlappingDecompressionMaxConcurrentSegments") + .intType() + .defaultValue(-1) + .withDescription("Max concurrent segments number for overlapping decompression."); } From 59df817d58917a063b52803d0dbbf54a26e2b466 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Sat, 28 Feb 2026 13:38:58 +0800 Subject: [PATCH 2/3] checkstyle fix --- .../client/impl/DecompressionWorker.java | 79 ++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java index 4026b6c35b..66ff8d9cd8 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/DecompressionWorker.java @@ -96,44 +96,49 @@ public void add(int batchIndex, ShuffleDataResult shuffleDataResult) { for (BufferSegment bufferSegment : bufferSegments) { CompletableFuture f = CompletableFuture.supplyAsync( - () -> { - segmentPermits.ifPresent( - x -> { - try { - x.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while acquiring segment permit", e); + () -> { + try { + if (segmentPermits.isPresent()) { + segmentPermits.get().acquire(); } - }); - - int offset = bufferSegment.getOffset(); - int length = bufferSegment.getLength(); - ByteBuffer buffer = sharedByteBuffer.duplicate(); - buffer.position(offset); - buffer.limit(offset + length); - - int uncompressedLen = bufferSegment.getUncompressLength(); - - long startBufferAllocation = System.currentTimeMillis(); - ByteBuffer dst = - buffer.isDirect() - ? ByteBuffer.allocateDirect(uncompressedLen) - : ByteBuffer.allocate(uncompressedLen); - decompressionBufferAllocationMillis.addAndGet( - System.currentTimeMillis() - startBufferAllocation); - - long startDecompression = System.currentTimeMillis(); - codec.decompress(buffer, uncompressedLen, dst, 0); - decompressionMillis.addAndGet(System.currentTimeMillis() - startDecompression); - decompressionBytes.addAndGet(length); - - nowMemoryUsed.addAndGet(uncompressedLen); - resetPeekMemoryUsed(); - - return dst; - }, - executorService); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted while acquiring segment permit", e); + return null; + } + + int offset = bufferSegment.getOffset(); + int length = bufferSegment.getLength(); + ByteBuffer buffer = sharedByteBuffer.duplicate(); + buffer.position(offset); + buffer.limit(offset + length); + + int uncompressedLen = bufferSegment.getUncompressLength(); + + long startBufferAllocation = System.currentTimeMillis(); + ByteBuffer dst = + buffer.isDirect() + ? ByteBuffer.allocateDirect(uncompressedLen) + : ByteBuffer.allocate(uncompressedLen); + decompressionBufferAllocationMillis.addAndGet( + System.currentTimeMillis() - startBufferAllocation); + + long startDecompression = System.currentTimeMillis(); + codec.decompress(buffer, uncompressedLen, dst, 0); + decompressionMillis.addAndGet(System.currentTimeMillis() - startDecompression); + decompressionBytes.addAndGet(length); + + nowMemoryUsed.addAndGet(uncompressedLen); + resetPeekMemoryUsed(); + + return dst; + }, + executorService) + .exceptionally( + ex -> { + LOG.error("Errors on decompressing shuffle block", ex); + return null; + }); ConcurrentHashMap blocks = tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>()); blocks.put( From c16b4ee35e1eae0aaf0dfc55dd9311d0c46255c4 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Mon, 2 Mar 2026 10:48:11 +0800 Subject: [PATCH 3/3] adjust to 10 --- .../java/org/apache/uniffle/common/config/RssClientConf.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java index 37a79ce0e6..c1672fca80 100644 --- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java +++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java @@ -394,6 +394,6 @@ public class RssClientConf { RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS = ConfigOptions.key("rss.client.read.overlappingDecompressionMaxConcurrentSegments") .intType() - .defaultValue(-1) + .defaultValue(10) .withDescription("Max concurrent segments number for overlapping decompression."); }