Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -55,7 +58,10 @@ public class DecompressionWorker {
private AtomicLong peekMemoryUsed = new AtomicLong(0);
private AtomicLong nowMemoryUsed = new AtomicLong(0);

public DecompressionWorker(Codec codec, int threads, int fetchSecondsThreshold) {
private final Optional<Semaphore> segmentPermits;

public DecompressionWorker(
Codec codec, int threads, int fetchSecondsThreshold, int maxConcurrentDecompressionSegments) {
if (codec == null) {
throw new IllegalArgumentException("Codec cannot be null");
}
Expand All @@ -67,6 +73,16 @@ public DecompressionWorker(Codec codec, int threads, int fetchSecondsThreshold)
Executors.newFixedThreadPool(threads, ThreadUtils.getThreadFactory("decompressionWorker"));
this.codec = codec;
this.fetchSecondsThreshold = fetchSecondsThreshold;

if (maxConcurrentDecompressionSegments <= 0) {
this.segmentPermits = Optional.empty();
} else if (threads != 1) {
LOG.info(
"Disable backpressure control since threads is {} to avoid potential deadlock", threads);
this.segmentPermits = Optional.empty();
} else {
this.segmentPermits = Optional.of(new Semaphore(maxConcurrentDecompressionSegments));
}
}

public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
Expand All @@ -80,34 +96,49 @@ public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
for (BufferSegment bufferSegment : bufferSegments) {
CompletableFuture<ByteBuffer> f =
CompletableFuture.supplyAsync(
() -> {
int offset = bufferSegment.getOffset();
int length = bufferSegment.getLength();
ByteBuffer buffer = sharedByteBuffer.duplicate();
buffer.position(offset);
buffer.limit(offset + length);

int uncompressedLen = bufferSegment.getUncompressLength();

long startBufferAllocation = System.currentTimeMillis();
ByteBuffer dst =
buffer.isDirect()
? ByteBuffer.allocateDirect(uncompressedLen)
: ByteBuffer.allocate(uncompressedLen);
decompressionBufferAllocationMillis.addAndGet(
System.currentTimeMillis() - startBufferAllocation);

long startDecompression = System.currentTimeMillis();
codec.decompress(buffer, uncompressedLen, dst, 0);
decompressionMillis.addAndGet(System.currentTimeMillis() - startDecompression);
decompressionBytes.addAndGet(length);

nowMemoryUsed.addAndGet(uncompressedLen);
resetPeekMemoryUsed();

return dst;
},
executorService);
() -> {
try {
if (segmentPermits.isPresent()) {
segmentPermits.get().acquire();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOG.warn("Interrupted while acquiring segment permit", e);
return null;
}

int offset = bufferSegment.getOffset();
int length = bufferSegment.getLength();
ByteBuffer buffer = sharedByteBuffer.duplicate();
buffer.position(offset);
buffer.limit(offset + length);

int uncompressedLen = bufferSegment.getUncompressLength();

long startBufferAllocation = System.currentTimeMillis();
ByteBuffer dst =
buffer.isDirect()
? ByteBuffer.allocateDirect(uncompressedLen)
: ByteBuffer.allocate(uncompressedLen);
decompressionBufferAllocationMillis.addAndGet(
System.currentTimeMillis() - startBufferAllocation);

long startDecompression = System.currentTimeMillis();
codec.decompress(buffer, uncompressedLen, dst, 0);
decompressionMillis.addAndGet(System.currentTimeMillis() - startDecompression);
decompressionBytes.addAndGet(length);

nowMemoryUsed.addAndGet(uncompressedLen);
resetPeekMemoryUsed();

return dst;
},
executorService)
.exceptionally(
ex -> {
LOG.error("Errors on decompressing shuffle block", ex);
return null;
});
ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks =
tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>());
blocks.put(
Expand All @@ -132,6 +163,7 @@ public DecompressedShuffleBlock get(int batchIndex, int segmentIndex) {
// block
if (block != null) {
nowMemoryUsed.addAndGet(-block.getUncompressLength());
segmentPermits.ifPresent(x -> x.release());
Copy link
Contributor

@roryqi roryqi Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When block is null, we don't release the permit. Could that cause a deadlock? Should we use try/finally to guarantee the release of the lock?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have ensured the block is never null here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

}
return block;
}
Expand Down Expand Up @@ -163,4 +195,17 @@ public void close() {
public long decompressionMillis() {
return decompressionMillis.get() + decompressionBufferAllocationMillis.get();
}

/**
 * Returns the high-water mark of uncompressed bytes held by in-flight decompressed blocks,
 * as tracked by {@code resetPeekMemoryUsed()} after each decompression.
 *
 * <p>NOTE(review): "Peek" here presumably means "Peak" (high-water mark) — the field and
 * method name carry the same misspelling; renaming would touch callers, so it is kept as-is.
 *
 * @return the peak value of {@code nowMemoryUsed} observed so far, in bytes
 */
@VisibleForTesting
protected long getPeekMemoryUsed() {
return peekMemoryUsed.get();
}

/**
 * Returns the number of segment permits currently available for backpressure control.
 *
 * @return available permits of the segment semaphore, or {@code -1} when backpressure
 *     control is disabled (no semaphore was configured)
 */
@VisibleForTesting
protected int getAvailablePermits() {
  return segmentPermits.map(Semaphore::availablePermits).orElse(-1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import static org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT;
import static org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED;
import static org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD;
import static org.apache.uniffle.common.config.RssClientConf.RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS;

public class ShuffleReadClientImpl implements ShuffleReadClient {

Expand Down Expand Up @@ -165,9 +166,14 @@ private void init(ShuffleClientFactory.ReadClientBuilder builder) {
if (builder.isOverlappingDecompressionEnabled()) {
int fetchThreshold =
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_FETCH_SECONDS_THRESHOLD);
int maxSegments =
builder.getRssConf().get(RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS);
this.decompressionWorker =
new DecompressionWorker(
builder.getCodec(), builder.getOverlappingDecompressionThreadNum(), fetchThreshold);
builder.getCodec(),
builder.getOverlappingDecompressionThreadNum(),
fetchThreshold,
maxSegments);
}
this.shuffleId = builder.getShuffleId();
this.partitionId = builder.getPartitionId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

import org.apache.uniffle.client.response.DecompressedShuffleBlock;
Expand All @@ -36,10 +38,42 @@

public class DecompressionWorkerTest {

@Test
public void testBackpressure() throws Exception {
  // Use the NOOP codec so segment bytes pass through decompression unchanged.
  RssConf conf = new RssConf();
  conf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
  Codec codec = Codec.newInstance(conf).get();

  final int threadNum = 1;
  final int permitCount = 10;
  final int fetchThresholdSeconds = 2;
  DecompressionWorker worker =
      new DecompressionWorker(codec, threadNum, fetchThresholdSeconds, permitCount);

  // One segment more than the permit count, so the last segment must wait for a release.
  ShuffleDataResult dataResult = createShuffleDataResult(permitCount + 1, codec, 1024);
  worker.add(0, dataResult);

  // case1: while decompression is in progress, only `permitCount` segments may be in
  // flight, so peak memory equals permitCount * segmentSize and no permits remain.
  Awaitility.await()
      .timeout(200, TimeUnit.MILLISECONDS)
      .until(() -> 1024 * permitCount == worker.getPeekMemoryUsed());
  assertEquals(0, worker.getAvailablePermits());

  // case2: consuming the first `permitCount` segments releases their permits, which
  // unblocks the final segment; it then completes and all permits become available again.
  for (int segmentIndex = 0; segmentIndex < permitCount; segmentIndex++) {
    worker.get(0, segmentIndex);
  }
  Thread.sleep(10);
  worker.get(0, permitCount).getByteBuffer();
  assertEquals(1024 * permitCount, worker.getPeekMemoryUsed());
  assertEquals(permitCount, worker.getAvailablePermits());
}

/**
 * Getting a block for a batch/segment that was never added must return null rather than
 * block or throw.
 */
@Test
public void testEmptyGet() throws Exception {
  // Defect fixed: the pasted diff left both the old 3-arg and new 4-arg constructor
  // calls in the method body; keep only the current 4-arg form.
  DecompressionWorker worker =
      new DecompressionWorker(Codec.newInstance(new RssConf()).get(), 1, 10, 10000);
  assertNull(worker.get(1, 1));
}

Expand Down Expand Up @@ -78,7 +112,7 @@ public void test() throws Exception {
RssConf rssConf = new RssConf();
rssConf.set(COMPRESSION_TYPE, Codec.Type.NOOP);
Codec codec = Codec.newInstance(rssConf).get();
DecompressionWorker worker = new DecompressionWorker(codec, 1, 10);
DecompressionWorker worker = new DecompressionWorker(codec, 1, 10, 100000);

// create some data
ShuffleDataResult shuffleDataResult = createShuffleDataResult(10, codec, 100);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,4 +389,11 @@ public class RssClientConf {
.defaultValue(-1)
.withDescription(
"Fetch seconds threshold for overlapping decompress shuffle blocks.");

// Upper bound on concurrently in-flight decompression segments (backpressure).
// A value <= 0 disables the limit; it is also ignored (with a log message) unless the
// decompression thread count is exactly 1, to avoid a potential deadlock — see the
// DecompressionWorker constructor.
public static final ConfigOption<Integer>
RSS_READ_OVERLAPPING_DECOMPRESSION_MAX_CONCURRENT_SEGMENTS =
ConfigOptions.key("rss.client.read.overlappingDecompressionMaxConcurrentSegments")
.intType()
.defaultValue(10)
.withDescription("Max concurrent segments number for overlapping decompression.");
}
Loading