Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.stream.ChunkedInput;
import io.netty.handler.stream.ChunkedStream;
import io.netty.util.ReferenceCountUtil;

import org.apache.celeborn.common.network.buffer.ManagedBuffer;

/**
* A wrapper message that holds two separate pieces (a header and a body).
*
* <p>The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream Based on
* common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
* <p>The header must be a ByteBuf, while the body can be a ByteBuf, InputStream, or ChunkedStream.
* Based on common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
*/
public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {

Expand All @@ -61,8 +63,8 @@ public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
public EncryptedMessageWithHeader(
@Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
Preconditions.checkArgument(
body instanceof InputStream || body instanceof ChunkedStream,
"Body must be an InputStream or a ChunkedStream.");
body instanceof ByteBuf || body instanceof InputStream || body instanceof ChunkedStream,
"Body must be a ByteBuf, an InputStream, or a ChunkedStream.");
this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
Expand All @@ -82,40 +84,47 @@ public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
return null;
}

if (totalBytesTransferred < headerLength) {
totalBytesTransferred += headerLength;
return header.retain();
} else if (body instanceof InputStream) {
InputStream stream = (InputStream) body;
int available = stream.available();
if (available <= 0) {
available = (int) (length() - totalBytesTransferred);
} else {
available = (int) Math.min(available, length() - totalBytesTransferred);
}
ByteBuf buffer = allocator.buffer(available);
int toRead = Math.min(available, buffer.writableBytes());
int read = buffer.writeBytes(stream, toRead);
if (read >= 0) {
totalBytesTransferred += read;
return buffer;
} else {
throw new EOFException("Unable to read bytes from InputStream");
}
} else if (body instanceof ChunkedStream) {
ChunkedStream stream = (ChunkedStream) body;
long old = stream.transferredBytes();
ByteBuf buffer = stream.readChunk(allocator);
long read = stream.transferredBytes() - old;
if (read >= 0) {
totalBytesTransferred += read;
assert (totalBytesTransferred <= length());
return buffer;
if (body instanceof ByteBuf) {
// For ByteBuf bodies, return header + body as a single composite buffer.
ByteBuf bodyBuf = (ByteBuf) body;
totalBytesTransferred = headerLength + bodyLength;
return Unpooled.wrappedBuffer(header.retain(), bodyBuf.retain());
} else {
if (totalBytesTransferred < headerLength) {
totalBytesTransferred += headerLength;
return header.retain();
} else if (body instanceof InputStream) {
InputStream stream = (InputStream) body;
int available = stream.available();
if (available <= 0) {
available = (int) (length() - totalBytesTransferred);
} else {
available = (int) Math.min(available, length() - totalBytesTransferred);
}
ByteBuf buffer = allocator.buffer(available);
int toRead = Math.min(available, buffer.writableBytes());
int read = buffer.writeBytes(stream, toRead);
if (read >= 0) {
totalBytesTransferred += read;
return buffer;
} else {
throw new EOFException("Unable to read bytes from InputStream");
}
} else if (body instanceof ChunkedStream) {
ChunkedStream stream = (ChunkedStream) body;
long old = stream.transferredBytes();
ByteBuf buffer = stream.readChunk(allocator);
long read = stream.transferredBytes() - old;
if (read >= 0) {
totalBytesTransferred += read;
assert (totalBytesTransferred <= length());
return buffer;
} else {
throw new EOFException("Unable to read bytes from ChunkedStream");
}
} else {
throw new EOFException("Unable to read bytes from ChunkedStream");
return null;
}
} else {
return null;
}
}

Expand All @@ -137,6 +146,7 @@ public boolean isEndOfInput() throws Exception {
@Override
public void close() throws Exception {
header.release();
ReferenceCountUtil.release(body);
if (managedBuffer != null) {
managedBuffer.release();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

package org.apache.celeborn.common.network.protocol;

import java.io.InputStream;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.handler.stream.ChunkedStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -89,15 +86,9 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) thro
assert header.writableBytes() == 0;

if (body != null && bodyLength > 0) {
if (body instanceof ByteBuf) {
out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
} else if (body instanceof InputStream || body instanceof ChunkedStream) {
// For now, assume the InputStream is doing proper chunking.
out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
throw new IllegalArgumentException(
"Body must be a ByteBuf, ChunkedStream or an InputStream");
}
// We transfer ownership of the reference on in.body() to EncryptedMessageWithHeader.
// This reference will be freed when EncryptedMessageWithHeader.close() is called.
out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -146,15 +145,43 @@ public void testChunkedStream() throws Exception {
assertEquals(0, header.refCnt());
}

// Tests the case where the body is a ByteBuf and that we manage the refcounts of the
// header, body, and managed buffer properly
@Test
public void testByteBufIsNotSupported() throws Exception {
// Validate that ByteBufs are not supported. This test can be updated
// when we add support for them
public void testByteBufBodyFromManagedBuffer() throws Exception {
byte[] randomData = new byte[128];
new Random().nextBytes(randomData);
ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
// convertToNettyForSsl() returns buf.duplicate().retain(), simulate that here
ByteBuf body = sourceBuffer.duplicate().retain();
ByteBuf header = Unpooled.copyLong(42);
assertThrows(
IllegalArgumentException.class,
() -> {
EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, header, 4);
});

long expectedHeaderValue = header.getLong(header.readerIndex());
assertEquals(1, header.refCnt());
assertEquals(2, sourceBuffer.refCnt()); // original + duplicate retain
ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);

EncryptedMessageWithHeader msg =
new EncryptedMessageWithHeader(managedBuf, header, body, managedBuf.size());
ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;

assertFalse(msg.isEndOfInput());

// Single read should return header + body as a composite buffer
ByteBuf result = msg.readChunk(allocator);
assertEquals(header.capacity() + randomData.length, result.readableBytes());
assertEquals(expectedHeaderValue, result.readLong());
for (int i = 0; i < randomData.length; i++) {
assertEquals(randomData[i], result.readByte());
}
assertTrue(msg.isEndOfInput());

// Release the chunk (simulates Netty writing it out)
result.release();

// Closing the message should release the source buffer via managedBuffer.release()
msg.close();
assertEquals(0, sourceBuffer.refCnt());
assertEquals(0, header.refCnt());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.network.protocol;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;

import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;

/**
* Verifies reference counting correctness in SslMessageEncoder.encode() for messages with a
* NettyManagedBuffer body.
*
* <p>When convertToNettyForSsl() returns a ByteBuf, the encoder wraps it in an
* EncryptedMessageWithHeader whose close() releases the ManagedBuffer. This mirrors the non-SSL
* MessageEncoder which uses MessageWithHeader.deallocate().
*/
public class SslMessageEncoderSuiteJ {

/**
* Core regression test: encoding a PushData with a NettyManagedBuffer body must leave the
* underlying ByteBuf at refCnt=0 after Netty reads and closes the EncryptedMessageWithHeader.
*/
@Test
public void testNettyManagedBufferBodyIsReleasedAfterEncoding() throws Exception {
ByteBuf bodyBuf = Unpooled.copyLong(1L);
assertEquals(1, bodyBuf.refCnt());

PushData pushData =
new PushData((byte) 0, "shuffleKey", "partitionId", new NettyManagedBuffer(bodyBuf));

ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);

List<Object> out = new ArrayList<>();
SslMessageEncoder.INSTANCE.encode(ctx, pushData, out);

assertEquals(1, out.size());
assertTrue(out.get(0) instanceof EncryptedMessageWithHeader);

EncryptedMessageWithHeader msg = (EncryptedMessageWithHeader) out.get(0);

// convertToNettyForSsl() called retain on a duplicate, so refCnt is 2
// (original + duplicate). The ManagedBuffer has not been released yet — that
// happens when close() is called.
assertEquals(2, bodyBuf.refCnt());

// Simulate Netty's ChunkedWriteHandler: read the chunk, then release it.
ByteBuf chunk = msg.readChunk(UnpooledByteBufAllocator.DEFAULT);
assertNotNull(chunk);
assertTrue(msg.isEndOfInput());
ReferenceCountUtil.release(chunk);

// Simulate Netty closing the ChunkedInput after transfer completes.
msg.close();

// After close(), the ManagedBuffer is released, bringing refCnt to 0.
assertEquals(0, bodyBuf.refCnt());
}
}
Loading
Loading