Skip to content

Commit b62d2c4

Browse files
committed
Fixed potential bytebuf leaks in edge cases
1 parent 0def69f commit b62d2c4

File tree

4 files changed

+50
-15
lines changed

4 files changed

+50
-15
lines changed

bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandler.java

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,25 @@
2121

2222
import static org.apache.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal;
2323

24-
import org.apache.bifromq.plugin.eventcollector.IEventCollector;
25-
import org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
2624
import com.google.common.util.concurrent.RateLimiter;
2725
import io.netty.channel.ChannelDuplexHandler;
2826
import io.netty.channel.ChannelHandler;
2927
import io.netty.channel.ChannelHandlerContext;
3028
import io.netty.channel.ChannelPipeline;
29+
import io.netty.util.ReferenceCountUtil;
3130
import java.util.concurrent.ThreadLocalRandom;
3231
import java.util.concurrent.TimeUnit;
3332
import lombok.extern.slf4j.Slf4j;
33+
import org.apache.bifromq.plugin.eventcollector.IEventCollector;
34+
import org.apache.bifromq.plugin.eventcollector.mqttbroker.channelclosed.ChannelError;
3435

3536
@Slf4j
3637
@ChannelHandler.Sharable
3738
public class ConnectionRateLimitHandler extends ChannelDuplexHandler {
38-
/**
39-
* Initialize the pipeline when the connection is accepted.
40-
*/
41-
public interface ChannelPipelineInitializer {
42-
void initialize(ChannelPipeline pipeline);
43-
}
44-
4539
private final RateLimiter rateLimiter;
4640
private final IEventCollector eventCollector;
4741
private final ChannelPipelineInitializer initializer;
42+
private boolean accepted = false;
4843

4944
public ConnectionRateLimitHandler(RateLimiter limiter,
5045
IEventCollector eventCollector,
@@ -57,9 +52,13 @@ public ConnectionRateLimitHandler(RateLimiter limiter,
5752
@Override
5853
public void channelActive(ChannelHandlerContext ctx) {
5954
if (rateLimiter.tryAcquire()) {
55+
accepted = true;
6056
initializer.initialize(ctx.pipeline());
6157
ctx.fireChannelActive();
58+
// Remove this handler after the connection is accepted
59+
ctx.pipeline().remove(this);
6260
} else {
61+
accepted = false;
6362
log.debug("Connection dropped due to exceed limit");
6463
eventCollector.report(getLocal(ChannelError.class)
6564
.peerAddress(ChannelAttrs.socketAddress(ctx.channel()))
@@ -73,4 +72,20 @@ public void channelActive(ChannelHandlerContext ctx) {
7372
}, ThreadLocalRandom.current().nextLong(100, 3000), TimeUnit.MILLISECONDS);
7473
}
7574
}
75+
76+
@Override
77+
public void channelRead(ChannelHandlerContext ctx, Object msg) {
78+
if (!accepted) {
79+
ReferenceCountUtil.release(msg);
80+
return;
81+
}
82+
ctx.fireChannelRead(msg);
83+
}
84+
85+
/**
86+
* Initialize the pipeline when the connection is accepted.
87+
*/
88+
public interface ChannelPipelineInitializer {
89+
void initialize(ChannelPipeline pipeline);
90+
}
7691
}

bifromq-mqtt/bifromq-mqtt-server/src/main/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io.netty.handler.codec.http.FullHttpResponse;
2828
import io.netty.handler.codec.http.HttpHeaderNames;
2929
import io.netty.handler.codec.http.HttpResponseStatus;
30+
import io.netty.util.ReferenceCountUtil;
3031

3132
/**
3233
* A simple handler that rejects all requests that are not WebSocket upgrade requests.
@@ -46,6 +47,7 @@ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
4647
!req.headers().get(HttpHeaderNames.UPGRADE, "").equalsIgnoreCase("websocket")) {
4748
FullHttpResponse response =
4849
new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.BAD_REQUEST);
50+
ReferenceCountUtil.release(req);
4951
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
5052
} else {
5153
// Proceed with the pipeline setup for WebSocket.

bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ConnectionRateLimitHandlerTest.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
1515
* KIND, either express or implied. See the License for the
1616
* specific language governing permissions and limitations
17-
* under the License.
17+
* under the License.
1818
*/
1919

2020
package org.apache.bifromq.mqtt.handler;
@@ -24,15 +24,18 @@
2424
import static org.mockito.Mockito.never;
2525
import static org.mockito.Mockito.verify;
2626
import static org.mockito.Mockito.when;
27+
import static org.testng.Assert.assertEquals;
2728
import static org.testng.Assert.assertFalse;
2829
import static org.testng.Assert.assertTrue;
2930

30-
import org.apache.bifromq.plugin.eventcollector.EventType;
31-
import org.apache.bifromq.plugin.eventcollector.IEventCollector;
3231
import com.google.common.util.concurrent.RateLimiter;
32+
import io.netty.buffer.ByteBuf;
33+
import io.netty.buffer.Unpooled;
3334
import io.netty.channel.ChannelPipeline;
3435
import io.netty.channel.embedded.EmbeddedChannel;
3536
import java.util.concurrent.TimeUnit;
37+
import org.apache.bifromq.plugin.eventcollector.EventType;
38+
import org.apache.bifromq.plugin.eventcollector.IEventCollector;
3639
import org.mockito.Mock;
3740
import org.mockito.MockitoAnnotations;
3841
import org.testng.annotations.BeforeMethod;
@@ -64,6 +67,8 @@ public void testChannelActiveWhenRateLimiterAllows() {
6467

6568
verify(initializer).initialize(channel.pipeline());
6669
assertTrue(channel.isActive());
70+
// After initialization, the handler should be removed
71+
assertFalse(channel.pipeline().toMap().containsValue(handler));
6772
}
6873

6974
@Test
@@ -78,4 +83,15 @@ public void testChannelActiveWhenRateLimiterDenies() {
7883
assertFalse(channel.isActive());
7984
verify(eventCollector).report(argThat(e -> e.type() == EventType.CHANNEL_ERROR));
8085
}
86+
87+
@Test
88+
public void testRejectedConnectionReleasesInboundByteBuf() {
89+
when(rateLimiter.tryAcquire()).thenReturn(false);
90+
EmbeddedChannel channel = new EmbeddedChannel(handler);
91+
92+
ByteBuf buf = Unpooled.buffer();
93+
assertTrue(buf.refCnt() > 0);
94+
channel.writeInbound(buf);
95+
assertEquals(buf.refCnt(), 0);
96+
}
8197
}

bifromq-mqtt/bifromq-mqtt-server/src/test/java/org/apache/bifromq/mqtt/handler/ws/WebSocketOnlyHandlerTest.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
import org.testng.annotations.Test;
3838

3939
public class WebSocketOnlyHandlerTest {
40-
private EmbeddedChannel channel;
4140
private final String websocketPath = "/mqtt";
41+
private EmbeddedChannel channel;
4242

4343
@BeforeMethod
4444
public void setUp() {
@@ -67,10 +67,12 @@ public void testInvalidRequestPathMismatch() {
6767
HttpVersion.HTTP_1_1, HttpMethod.GET, "/wrongpath");
6868
request.headers().set(HttpHeaderNames.UPGRADE, "websocket");
6969

70+
assertTrue(request.refCnt() > 0);
7071
assertFalse(channel.writeInbound(request));
7172
FullHttpResponse response = channel.readOutbound();
73+
assertEquals(request.refCnt(), 0);
7274
assertNotNull(response);
73-
assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
75+
assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
7476
}
7577

7678
@Test
@@ -82,7 +84,7 @@ public void testInvalidRequestNoUpgradeHeader() {
8284
assertFalse(channel.writeInbound(request));
8385
FullHttpResponse response = channel.readOutbound();
8486
assertNotNull(response);
85-
assertEquals(HttpResponseStatus.BAD_REQUEST, response.status());
87+
assertEquals(response.status(), HttpResponseStatus.BAD_REQUEST);
8688
}
8789

8890
@Test

0 commit comments

Comments
 (0)