diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTTransientSessionHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTTransientSessionHandler.java index 388900161..84f7ffb1a 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTTransientSessionHandler.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTTransientSessionHandler.java @@ -173,7 +173,7 @@ protected final CompletableFuture subTopicFilter( return CompletableFuture.completedFuture(EXCEED_LIMIT); } tenantMeter.recordCount(MqttTransientSubCount); - int maxTopicFiltersPerInbox = settings.maxTopicFiltersPerSub; + int maxTopicFiltersPerInbox = settings.maxTopicFiltersPerInbox; if (topicFilters.size() >= maxTopicFiltersPerInbox) { return CompletableFuture.completedFuture(EXCEED_LIMIT); } @@ -265,7 +265,7 @@ protected CompletableFuture unsubTopicFilter(lo } @Override - public boolean isSubscribing(String topicFilter) { + public boolean hasSubscribed(String topicFilter) { return topicFilters.containsKey(topicFilter); } diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TenantSettings.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TenantSettings.java index 58dd05bf6..a14106cd0 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TenantSettings.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/TenantSettings.java @@ -13,6 +13,9 @@ package com.baidu.bifromq.mqtt.handler; +import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; +import com.baidu.bifromq.type.QoS; + import static com.baidu.bifromq.plugin.settingprovider.Setting.DebugModeEnabled; import static com.baidu.bifromq.plugin.settingprovider.Setting.ForceTransient; import static com.baidu.bifromq.plugin.settingprovider.Setting.InBoundBandWidth; @@ -22,6 +25,7 @@ import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxResendTimes; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxSessionExpirySeconds; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicAlias; +import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerInbox; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerSub; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLength; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLevelLength; @@ -41,9 +45,6 @@ import static com.baidu.bifromq.plugin.settingprovider.Setting.SubscriptionIdentifierEnabled; import static com.baidu.bifromq.plugin.settingprovider.Setting.WildcardSubscriptionEnabled; -import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; -import com.baidu.bifromq.type.QoS; - public class TenantSettings { public final boolean mqtt3Enabled; public final boolean mqtt4Enabled; @@ -72,6 +73,7 @@ public class TenantSettings { public final int inboxQueueLength; public final boolean inboxDropOldest; public final int retainMatchLimit; + public final int maxTopicFiltersPerInbox; public TenantSettings(String tenantId, ISettingProvider provider) { mqtt3Enabled = provider.provide(MQTT3Enabled, tenantId); @@ -101,5 +103,6 @@ public TenantSettings(String tenantId, ISettingProvider provider) { inboxQueueLength = provider.provide(SessionInboxSize, tenantId); inboxDropOldest = provider.provide(QoS0DropOldest, tenantId); retainMatchLimit = provider.provide(RetainMessageMatchLimit, tenantId); + maxTopicFiltersPerInbox = provider.provide(MaxTopicFiltersPerInbox, tenantId); } } diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/service/LocalDistService.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/service/LocalDistService.java index e0c83d64c..26afea4b8 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/service/LocalDistService.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/service/LocalDistService.java @@ -196,7 +196,7 @@ public CheckReply.Code checkMatchInfo(String tenantId, MatchInfo matchInfo) { return CheckReply.Code.NO_RECEIVER; } if (session instanceof IMQTTTransientSession transientSession) { - return transientSession.isSubscribing(matchInfo.getMatcher().getMqttTopicFilter()) + return transientSession.hasSubscribed(matchInfo.getMatcher().getMqttTopicFilter()) ? CheckReply.Code.OK : CheckReply.Code.NO_SUB; } else { // should not be here diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/session/IMQTTTransientSession.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/session/IMQTTTransientSession.java index a8446164e..b35421747 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/session/IMQTTTransientSession.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/session/IMQTTTransientSession.java @@ -31,7 +31,7 @@ public interface IMQTTTransientSession extends IMQTTSession { */ Set publish(TopicMessagePack messagePack, Set matchedTopicFilters); - boolean isSubscribing(String topicFilter); + boolean hasSubscribed(String topicFilter); /** * The matched topic filter. diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/BaseSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/BaseSessionHandlerTest.java index e7fd0ef6f..921ded592 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/BaseSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/BaseSessionHandlerTest.java @@ -17,6 +17,7 @@ import static com.baidu.bifromq.plugin.settingprovider.Setting.DebugModeEnabled; import static com.baidu.bifromq.plugin.settingprovider.Setting.ForceTransient; import static com.baidu.bifromq.plugin.settingprovider.Setting.InBoundBandWidth; +import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerInbox; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerSub; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLength; import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLevelLength; @@ -99,6 +100,7 @@ import com.google.common.collect.Lists; import com.google.protobuf.ByteString; import io.micrometer.core.instrument.Timer; +import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.mqtt.MqttSubAckMessage; import io.netty.handler.codec.mqtt.MqttUnsubAckMessage; @@ -118,7 +120,7 @@ import org.mockito.stubbing.Answer; import org.mockito.stubbing.OngoingStubbing; -public class BaseSessionHandlerTest extends MockableTest { +public abstract class BaseSessionHandlerTest extends MockableTest { protected final String tenantId = "tenantId"; protected final String serverId = "serverId"; @@ -178,8 +180,26 @@ public void setup(Method method) { when(tenantMeter.timer(any())).thenReturn(mock(Timer.class)); when(oomCondition.meet()).thenReturn(false); when(clientBalancer.needRedirect(any())).thenReturn(Optional.empty()); + sessionContext = MQTTSessionContext.builder() + .serverId(serverId) + .ticker(testTicker) + .defaultKeepAliveTimeSeconds(2) + .distClient(distClient) + .retainClient(retainClient) + .authProvider(authProvider) + .localDistService(localDistService) + .localSessionRegistry(localSessionRegistry) + .sessionDictClient(sessionDictClient) + .clientBalancer(clientBalancer) + .inboxClient(inboxClient) + .eventCollector(eventCollector) + .resourceThrottler(resourceThrottler) + .settingProvider(settingProvider) + .build(); + mockSettings(); } + protected abstract ChannelDuplexHandler buildChannelHandler(); protected void verifySubAck(MqttSubAckMessage subAckMessage, int[] expectedReasonCodes) { assertEquals(subAckMessage.payload().reasonCodes().size(), expectedReasonCodes.length); @@ -220,6 +240,7 @@ protected void mockSettings() { Mockito.lenient().when(settingProvider.provide(eq(RetainEnabled), anyString())).thenReturn(true); Mockito.lenient().when(settingProvider.provide(eq(RetainMessageMatchLimit), anyString())).thenReturn(10); Mockito.lenient().when(settingProvider.provide(eq(MaxTopicFiltersPerSub), anyString())).thenReturn(10); + Mockito.lenient().when(settingProvider.provide(eq(MaxTopicFiltersPerInbox), anyString())).thenReturn(10); } protected void mockCheckPermission(boolean allow) { diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTPacketFilterTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTPacketFilterTest.java index f61b2516f..44b0c591b 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTPacketFilterTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTPacketFilterTest.java @@ -72,7 +72,7 @@ public class MQTTPacketFilterTest extends MockableTest { @Test public void mqtt3DropPacket() { try (MockedStatic mockedStatic = mockStatic(ITenantMeter.class)) { - // 模拟MyUtility.staticMethod()方法 + // simulate MyUtility.staticMethod() mockedStatic.when(() -> ITenantMeter.get(tenantId)).thenReturn(tenantMeter); when(tenantMeter.timer(any())).thenReturn(timer); MQTTPacketFilter testFilter = diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3PersistentSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3PersistentSessionHandlerTest.java index 68a774d0e..3debeac4f 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3PersistentSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3PersistentSessionHandlerTest.java @@ -93,46 +93,10 @@ @Slf4j public class MQTT3PersistentSessionHandlerTest extends BaseSessionHandlerTest { - private MQTT3PersistentSessionHandler persistentSessionHandler; - @BeforeMethod(alwaysRun = true) public void setup(Method method) { super.setup(method); - sessionContext = MQTTSessionContext.builder() - .serverId(serverId) - .ticker(testTicker) - .defaultKeepAliveTimeSeconds(2) - .distClient(distClient) - .retainClient(retainClient) - .localDistService(localDistService) - .authProvider(authProvider) - .localSessionRegistry(localSessionRegistry) - .sessionDictClient(sessionDictClient) - .clientBalancer(clientBalancer) - .inboxClient(inboxClient) - .eventCollector(eventCollector) - .resourceThrottler(resourceThrottler) - .settingProvider(settingProvider) - .build(); - // common mocks - mockSettings(); - ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelActive(ctx); - ctx.pipeline().addLast(MQTT3PersistentSessionHandler.builder() - .settings(new TenantSettings(tenantId, settingProvider)) - .tenantMeter(tenantMeter) - .oomCondition(oomCondition) - .userSessionId(userSessionId(clientInfo)) - .keepAliveTimeSeconds(120) - .clientInfo(clientInfo) - .willMessage(null) - .ctx(ctx) - .build()); - ctx.pipeline().remove(this); - } - }; + ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler(); mockInboxCreate(CreateReply.Code.OK); mockInboxReader(); channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() { @@ -146,7 +110,6 @@ protected void initChannel(Channel ch) { pipeline.addLast(sessionHandlerAdder); } }); - persistentSessionHandler = (MQTT3PersistentSessionHandler) channel.pipeline().last(); } @@ -157,7 +120,28 @@ public void tearDown(Method method) { fetchHints.clear(); } -// =============================================== sub & unSub ======================================================= + @Override + protected ChannelDuplexHandler buildChannelHandler() { + return new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.pipeline().addLast(MQTT3PersistentSessionHandler.builder() + .settings(new TenantSettings(tenantId, settingProvider)) + .tenantMeter(tenantMeter) + .oomCondition(oomCondition) + .userSessionId(userSessionId(clientInfo)) + .keepAliveTimeSeconds(120) + .clientInfo(clientInfo) + .willMessage(null) + .ctx(ctx) + .build()); + ctx.pipeline().remove(this); + } + }; + } + + // =============================================== sub & unSub ======================================================= @Test public void persistentQoS0Sub() { diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java index cd398d1fc..683857601 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3TransientSessionHandlerTest.java @@ -80,7 +80,6 @@ import com.baidu.bifromq.mqtt.handler.ChannelAttrs; import com.baidu.bifromq.mqtt.handler.TenantSettings; import com.baidu.bifromq.mqtt.session.IMQTTTransientSession; -import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.mqtt.utils.MQTTMessageUtils; import com.baidu.bifromq.plugin.authprovider.type.CheckResult; import com.baidu.bifromq.plugin.authprovider.type.Granted; @@ -133,26 +132,7 @@ public class MQTT3TransientSessionHandlerTest extends BaseSessionHandlerTest { @BeforeMethod(alwaysRun = true) public void setup(Method method) { super.setup(method); - int keepAlive = 2; - sessionContext = - MQTTSessionContext.builder().serverId(serverId).ticker(testTicker).defaultKeepAliveTimeSeconds(keepAlive) - .distClient(distClient).retainClient(retainClient).authProvider(authProvider) - .localDistService(localDistService).localSessionRegistry(localSessionRegistry) - .sessionDictClient(sessionDictClient).clientBalancer(clientBalancer).eventCollector(eventCollector) - .resourceThrottler(resourceThrottler).settingProvider(settingProvider).build(); - // common mocks - mockSettings(); - ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelActive(ctx); - ctx.pipeline().addLast( - MQTT3TransientSessionHandler.builder().settings(new TenantSettings(tenantId, settingProvider)) - .tenantMeter(tenantMeter).oomCondition(oomCondition).userSessionId(userSessionId(clientInfo)) - .keepAliveTimeSeconds(120).clientInfo(clientInfo).willMessage(null).ctx(ctx).build()); - ctx.pipeline().remove(this); - } - }; + ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler(); mockSessionReg(); channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() { @Override @@ -184,6 +164,26 @@ public void tearDown(Method method) { super.tearDown(method); } + @Override + protected ChannelDuplexHandler buildChannelHandler() { + return new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.pipeline().addLast(MQTT3TransientSessionHandler.builder() + .settings(new TenantSettings(tenantId, settingProvider)) + .tenantMeter(tenantMeter) + .oomCondition(oomCondition) + .userSessionId(userSessionId(clientInfo)) + .keepAliveTimeSeconds(120) + .clientInfo(clientInfo) + .willMessage(null) + .ctx(ctx) + .build()); + ctx.pipeline().remove(this); + } + }; + } @Test public void handleConnect() { @@ -210,9 +210,9 @@ public void transientQoS0Sub() { verifySubAck(subAckMessage, qos); verifyEvent(MQTT_SESSION_START, SUB_ACKED); shouldCleanSubs = true; - boolean isSub = - transientSessionHandler.isSubscribing(subMessage.payload().topicSubscriptions().get(0).topicFilter()); - assertTrue(isSub); + boolean hasSub = + transientSessionHandler.hasSubscribed(subMessage.payload().topicSubscriptions().get(0).topicFilter()); + assertTrue(hasSub); } @Test diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5PersistentSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5PersistentSessionHandlerTest.java index c39ed7f47..2ecd15adf 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5PersistentSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5PersistentSessionHandlerTest.java @@ -50,7 +50,6 @@ import com.baidu.bifromq.mqtt.handler.ChannelAttrs; import com.baidu.bifromq.mqtt.handler.TenantSettings; import com.baidu.bifromq.mqtt.handler.v5.reason.MQTT5DisconnectReasonCode; -import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.mqtt.utils.MQTTMessageUtils; import com.baidu.bifromq.plugin.eventcollector.mqttbroker.pushhandling.QoS1Confirmed; import com.baidu.bifromq.plugin.eventcollector.mqttbroker.pushhandling.QoS2Confirmed; @@ -85,51 +84,10 @@ @Slf4j public class MQTT5PersistentSessionHandlerTest extends BaseSessionHandlerTest { - private MQTT5PersistentSessionHandler persistentSessionHandler; - private final int sessionExpirySeconds = 120; - @BeforeMethod(alwaysRun = true) public void setup(Method method) { super.setup(method); - sessionContext = MQTTSessionContext.builder() - .serverId(serverId) - .ticker(testTicker) - .defaultKeepAliveTimeSeconds(2) - .distClient(distClient) - .retainClient(retainClient) - .localDistService(localDistService) - .authProvider(authProvider) - .localSessionRegistry(localSessionRegistry) - .sessionDictClient(sessionDictClient) - .clientBalancer(clientBalancer) - .inboxClient(inboxClient) - .eventCollector(eventCollector) - .resourceThrottler(resourceThrottler) - .settingProvider(settingProvider) - .build(); - // common mocks - mockSettings(); - ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelActive(ctx); - ctx.pipeline().addLast(MQTT5PersistentSessionHandler.builder() - .settings(new TenantSettings(tenantId, settingProvider)) - .tenantMeter(tenantMeter) - .oomCondition(oomCondition) - .connMsg(MqttMessageBuilders.connect() - .protocolVersion(MqttVersion.MQTT_5) - .build()) - .userSessionId(userSessionId(clientInfo)) - .keepAliveTimeSeconds(120) - .sessionExpirySeconds(sessionExpirySeconds) - .clientInfo(clientInfo) - .willMessage(null) - .ctx(ctx) - .build()); - ctx.pipeline().remove(this); - } - }; + ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler(); mockInboxCreate(CreateReply.Code.OK); mockInboxReader(); channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() { @@ -143,7 +101,6 @@ protected void initChannel(Channel ch) { pipeline.addLast(sessionHandlerAdder); } }); - persistentSessionHandler = (MQTT5PersistentSessionHandler) channel.pipeline().last(); } @@ -154,6 +111,31 @@ public void tearDown(Method method) { fetchHints.clear(); } + @Override + protected ChannelDuplexHandler buildChannelHandler() { + return new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.pipeline().addLast(MQTT5PersistentSessionHandler.builder() + .settings(new TenantSettings(tenantId, settingProvider)) + .tenantMeter(tenantMeter) + .oomCondition(oomCondition) + .connMsg(MqttMessageBuilders.connect() + .protocolVersion(MqttVersion.MQTT_5) + .build()) + .userSessionId(userSessionId(clientInfo)) + .keepAliveTimeSeconds(120) + .sessionExpirySeconds(120) + .clientInfo(clientInfo) + .willMessage(null) + .ctx(ctx) + .build()); + ctx.pipeline().remove(this); + } + }; + } + // =============================================== connect & disconnect ============================================= @Test diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java index bf5b438e1..b8aab9c9e 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5TransientSessionHandlerTest.java @@ -76,7 +76,6 @@ import com.baidu.bifromq.mqtt.handler.TenantSettings; import com.baidu.bifromq.mqtt.handler.v5.reason.MQTT5SubAckReasonCode; import com.baidu.bifromq.mqtt.session.IMQTTTransientSession; -import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.mqtt.utils.MQTTMessageUtils; import com.baidu.bifromq.plugin.authprovider.type.CheckResult; import com.baidu.bifromq.plugin.authprovider.type.Granted; @@ -129,48 +128,7 @@ public class MQTT5TransientSessionHandlerTest extends BaseSessionHandlerTest { @BeforeMethod(alwaysRun = true) public void setup(Method method) { super.setup(method); - int keepAlive = 2; - sessionContext = MQTTSessionContext.builder() - .serverId(serverId) - .ticker(testTicker) - .defaultKeepAliveTimeSeconds(keepAlive) - .distClient(distClient) - .retainClient(retainClient) - .authProvider(authProvider) - .localDistService(localDistService) - .localSessionRegistry(localSessionRegistry) - .sessionDictClient(sessionDictClient) - .clientBalancer(clientBalancer) - .eventCollector(eventCollector) - .resourceThrottler(resourceThrottler) - .settingProvider(settingProvider) - .build(); - // common mocks - mockSettings(); - MqttProperties mqttProperties = new MqttProperties(); - mqttProperties.add(new MqttProperties.IntegerProperty(TOPIC_ALIAS_MAXIMUM.value(), 10)); - ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() { - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelActive(ctx); - ctx.pipeline() - .addLast(MQTT5TransientSessionHandler.builder() - .settings(new TenantSettings(tenantId, settingProvider)) - .tenantMeter(tenantMeter) - .oomCondition(oomCondition) - .connMsg(MqttMessageBuilders.connect() - .protocolVersion(MqttVersion.MQTT_5) - .properties(mqttProperties) - .build()) - .userSessionId(userSessionId(clientInfo)) - .keepAliveTimeSeconds(120) - .clientInfo(clientInfo) - .willMessage(null) - .ctx(ctx) - .build()); - ctx.pipeline().remove(this); - } - }; + ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler(); mockSessionReg(); channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() { @Override @@ -202,8 +160,34 @@ public void tearDown(Method method) { super.tearDown(method); } + @Override + protected ChannelDuplexHandler buildChannelHandler() { + MqttProperties mqttProperties = new MqttProperties(); + mqttProperties.add(new MqttProperties.IntegerProperty(TOPIC_ALIAS_MAXIMUM.value(), 10)); + return new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + ctx.pipeline().addLast(MQTT5TransientSessionHandler.builder() + .settings(new TenantSettings(tenantId, settingProvider)) + .tenantMeter(tenantMeter) + .oomCondition(oomCondition) + .connMsg(MqttMessageBuilders.connect() + .protocolVersion(MqttVersion.MQTT_5) + .properties(mqttProperties) + .build()) + .userSessionId(userSessionId(clientInfo)) + .keepAliveTimeSeconds(120) + .clientInfo(clientInfo) + .willMessage(null) + .ctx(ctx) + .build()); + ctx.pipeline().remove(this); + } + }; + } -// =============================================== sub & unSub ====================================================== + // =============================================== sub & unSub ====================================================== @Test public void transientMixedSub() { @@ -236,6 +220,31 @@ public void transientMixedSubWithDistSubFailed() { shouldCleanSubs = true; } + @Test + public void transientSubExceedInboxLimit() { + mockCheckPermission(true); + mockDistMatch(true); + mockRetainMatch(); + int settingLimit = 10; + String[] tfs = new String[settingLimit]; + int[] qos = new int[10]; + for (int index = 0; index < 10; index++) { + tfs[index] = "t/" + index; + qos[index] = 0; + } + MqttSubscribeMessage subMessage = MQTTMessageUtils.qoSMqttSubMessages(tfs, qos); + channel.writeInbound(subMessage); + MqttSubAckMessage subAckMessage = channel.readOutbound(); + verifySubAck(subAckMessage, new int[settingLimit]); + + subMessage = MQTTMessageUtils.qoSMqttSubMessages(new String[]{"anotherTFS"}, new int[]{0}); + channel.writeInbound(subMessage); + subAckMessage = channel.readOutbound(); + verifySubAck(subAckMessage, new int[] {MQTT5SubAckReasonCode.QuotaExceeded.value()}); + verifyEvent(MQTT_SESSION_START, SUB_ACKED, SUB_ACKED); + shouldCleanSubs = true; + } + @Test public void transientUnSub() { diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/service/LocalDistServiceTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/service/LocalDistServiceTest.java index e080f0d80..c15b81415 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/service/LocalDistServiceTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/service/LocalDistServiceTest.java @@ -96,7 +96,7 @@ public void checkMatchInfoForSharedSub() { ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build(); when(session.clientInfo()).thenReturn(clientInfo); when(session.channelId()).thenReturn(channelId); - when(session.isSubscribing(topicFilter)).thenReturn(true); + when(session.hasSubscribed(topicFilter)).thenReturn(true); when(localSessionRegistry.get(channelId)).thenReturn(session); long reqId = System.nanoTime(); localDistService.match(reqId, topicFilter, 1, session); @@ -189,7 +189,7 @@ public void unmatchSharedSubTopicFilter() { ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build(); when(session.clientInfo()).thenReturn(clientInfo); when(session.channelId()).thenReturn(channelId); - when(session.isSubscribing(topicFilter)).thenReturn(false); + when(session.hasSubscribed(topicFilter)).thenReturn(false); when(localSessionRegistry.get(channelId)).thenReturn(session); long reqId = System.nanoTime(); localDistService.unmatch(reqId, topicFilter, 1L, session);