Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ protected final CompletableFuture<IMQTTProtocolHelper.SubResult> subTopicFilter(
return CompletableFuture.completedFuture(EXCEED_LIMIT);
}
tenantMeter.recordCount(MqttTransientSubCount);
int maxTopicFiltersPerInbox = settings.maxTopicFiltersPerSub;
int maxTopicFiltersPerInbox = settings.maxTopicFiltersPerInbox;
if (topicFilters.size() >= maxTopicFiltersPerInbox) {
return CompletableFuture.completedFuture(EXCEED_LIMIT);
}
Expand Down Expand Up @@ -265,7 +265,7 @@ protected CompletableFuture<IMQTTProtocolHelper.UnsubResult> unsubTopicFilter(lo
}

@Override
public boolean isSubscribing(String topicFilter) {
public boolean hasSubscribed(String topicFilter) {
return topicFilters.containsKey(topicFilter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

package com.baidu.bifromq.mqtt.handler;

import com.baidu.bifromq.plugin.settingprovider.ISettingProvider;
import com.baidu.bifromq.type.QoS;

import static com.baidu.bifromq.plugin.settingprovider.Setting.DebugModeEnabled;
import static com.baidu.bifromq.plugin.settingprovider.Setting.ForceTransient;
import static com.baidu.bifromq.plugin.settingprovider.Setting.InBoundBandWidth;
Expand All @@ -22,6 +25,7 @@
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxResendTimes;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxSessionExpirySeconds;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicAlias;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerInbox;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerSub;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLength;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLevelLength;
Expand All @@ -41,9 +45,6 @@
import static com.baidu.bifromq.plugin.settingprovider.Setting.SubscriptionIdentifierEnabled;
import static com.baidu.bifromq.plugin.settingprovider.Setting.WildcardSubscriptionEnabled;

import com.baidu.bifromq.plugin.settingprovider.ISettingProvider;
import com.baidu.bifromq.type.QoS;

public class TenantSettings {
public final boolean mqtt3Enabled;
public final boolean mqtt4Enabled;
Expand Down Expand Up @@ -72,6 +73,7 @@ public class TenantSettings {
public final int inboxQueueLength;
public final boolean inboxDropOldest;
public final int retainMatchLimit;
public final int maxTopicFiltersPerInbox;

public TenantSettings(String tenantId, ISettingProvider provider) {
mqtt3Enabled = provider.provide(MQTT3Enabled, tenantId);
Expand Down Expand Up @@ -101,5 +103,6 @@ public TenantSettings(String tenantId, ISettingProvider provider) {
inboxQueueLength = provider.provide(SessionInboxSize, tenantId);
inboxDropOldest = provider.provide(QoS0DropOldest, tenantId);
retainMatchLimit = provider.provide(RetainMessageMatchLimit, tenantId);
maxTopicFiltersPerInbox = provider.provide(MaxTopicFiltersPerInbox, tenantId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public CheckReply.Code checkMatchInfo(String tenantId, MatchInfo matchInfo) {
return CheckReply.Code.NO_RECEIVER;
}
if (session instanceof IMQTTTransientSession transientSession) {
return transientSession.isSubscribing(matchInfo.getMatcher().getMqttTopicFilter())
return transientSession.hasSubscribed(matchInfo.getMatcher().getMqttTopicFilter())
? CheckReply.Code.OK : CheckReply.Code.NO_SUB;
} else {
// should not be here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public interface IMQTTTransientSession extends IMQTTSession {
*/
Set<MatchedTopicFilter> publish(TopicMessagePack messagePack, Set<MatchedTopicFilter> matchedTopicFilters);

boolean isSubscribing(String topicFilter);
boolean hasSubscribed(String topicFilter);

/**
* The matched topic filter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static com.baidu.bifromq.plugin.settingprovider.Setting.DebugModeEnabled;
import static com.baidu.bifromq.plugin.settingprovider.Setting.ForceTransient;
import static com.baidu.bifromq.plugin.settingprovider.Setting.InBoundBandWidth;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerInbox;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicFiltersPerSub;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLength;
import static com.baidu.bifromq.plugin.settingprovider.Setting.MaxTopicLevelLength;
Expand Down Expand Up @@ -99,6 +100,7 @@
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import io.micrometer.core.instrument.Timer;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.mqtt.MqttSubAckMessage;
import io.netty.handler.codec.mqtt.MqttUnsubAckMessage;
Expand All @@ -118,7 +120,7 @@
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.OngoingStubbing;

public class BaseSessionHandlerTest extends MockableTest {
public abstract class BaseSessionHandlerTest extends MockableTest {

protected final String tenantId = "tenantId";
protected final String serverId = "serverId";
Expand Down Expand Up @@ -178,8 +180,26 @@ public void setup(Method method) {
when(tenantMeter.timer(any())).thenReturn(mock(Timer.class));
when(oomCondition.meet()).thenReturn(false);
when(clientBalancer.needRedirect(any())).thenReturn(Optional.empty());
sessionContext = MQTTSessionContext.builder()
.serverId(serverId)
.ticker(testTicker)
.defaultKeepAliveTimeSeconds(2)
.distClient(distClient)
.retainClient(retainClient)
.authProvider(authProvider)
.localDistService(localDistService)
.localSessionRegistry(localSessionRegistry)
.sessionDictClient(sessionDictClient)
.clientBalancer(clientBalancer)
.inboxClient(inboxClient)
.eventCollector(eventCollector)
.resourceThrottler(resourceThrottler)
.settingProvider(settingProvider)
.build();
mockSettings();
}

protected abstract ChannelDuplexHandler buildChannelHandler();

protected void verifySubAck(MqttSubAckMessage subAckMessage, int[] expectedReasonCodes) {
assertEquals(subAckMessage.payload().reasonCodes().size(), expectedReasonCodes.length);
Expand Down Expand Up @@ -220,6 +240,7 @@ protected void mockSettings() {
Mockito.lenient().when(settingProvider.provide(eq(RetainEnabled), anyString())).thenReturn(true);
Mockito.lenient().when(settingProvider.provide(eq(RetainMessageMatchLimit), anyString())).thenReturn(10);
Mockito.lenient().when(settingProvider.provide(eq(MaxTopicFiltersPerSub), anyString())).thenReturn(10);
Mockito.lenient().when(settingProvider.provide(eq(MaxTopicFiltersPerInbox), anyString())).thenReturn(10);
}

protected void mockCheckPermission(boolean allow) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public class MQTTPacketFilterTest extends MockableTest {
@Test
public void mqtt3DropPacket() {
try (MockedStatic<ITenantMeter> mockedStatic = mockStatic(ITenantMeter.class)) {
// 模拟MyUtility.staticMethod()方法
// simulate MyUtility.staticMethod()
mockedStatic.when(() -> ITenantMeter.get(tenantId)).thenReturn(tenantMeter);
when(tenantMeter.timer(any())).thenReturn(timer);
MQTTPacketFilter testFilter =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,46 +93,10 @@
@Slf4j
public class MQTT3PersistentSessionHandlerTest extends BaseSessionHandlerTest {

private MQTT3PersistentSessionHandler persistentSessionHandler;

@BeforeMethod(alwaysRun = true)
public void setup(Method method) {
super.setup(method);
sessionContext = MQTTSessionContext.builder()
.serverId(serverId)
.ticker(testTicker)
.defaultKeepAliveTimeSeconds(2)
.distClient(distClient)
.retainClient(retainClient)
.localDistService(localDistService)
.authProvider(authProvider)
.localSessionRegistry(localSessionRegistry)
.sessionDictClient(sessionDictClient)
.clientBalancer(clientBalancer)
.inboxClient(inboxClient)
.eventCollector(eventCollector)
.resourceThrottler(resourceThrottler)
.settingProvider(settingProvider)
.build();
// common mocks
mockSettings();
ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.pipeline().addLast(MQTT3PersistentSessionHandler.builder()
.settings(new TenantSettings(tenantId, settingProvider))
.tenantMeter(tenantMeter)
.oomCondition(oomCondition)
.userSessionId(userSessionId(clientInfo))
.keepAliveTimeSeconds(120)
.clientInfo(clientInfo)
.willMessage(null)
.ctx(ctx)
.build());
ctx.pipeline().remove(this);
}
};
ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler();
mockInboxCreate(CreateReply.Code.OK);
mockInboxReader();
channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() {
Expand All @@ -146,7 +110,6 @@ protected void initChannel(Channel ch) {
pipeline.addLast(sessionHandlerAdder);
}
});
persistentSessionHandler = (MQTT3PersistentSessionHandler) channel.pipeline().last();
}


Expand All @@ -157,7 +120,28 @@ public void tearDown(Method method) {
fetchHints.clear();
}

// =============================================== sub & unSub =======================================================
@Override
protected ChannelDuplexHandler buildChannelHandler() {
return new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.pipeline().addLast(MQTT3PersistentSessionHandler.builder()
.settings(new TenantSettings(tenantId, settingProvider))
.tenantMeter(tenantMeter)
.oomCondition(oomCondition)
.userSessionId(userSessionId(clientInfo))
.keepAliveTimeSeconds(120)
.clientInfo(clientInfo)
.willMessage(null)
.ctx(ctx)
.build());
ctx.pipeline().remove(this);
}
};
}

// =============================================== sub & unSub =======================================================

@Test
public void persistentQoS0Sub() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@
import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
import com.baidu.bifromq.mqtt.handler.TenantSettings;
import com.baidu.bifromq.mqtt.session.IMQTTTransientSession;
import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
import com.baidu.bifromq.mqtt.utils.MQTTMessageUtils;
import com.baidu.bifromq.plugin.authprovider.type.CheckResult;
import com.baidu.bifromq.plugin.authprovider.type.Granted;
Expand Down Expand Up @@ -133,26 +132,7 @@ public class MQTT3TransientSessionHandlerTest extends BaseSessionHandlerTest {
@BeforeMethod(alwaysRun = true)
public void setup(Method method) {
super.setup(method);
int keepAlive = 2;
sessionContext =
MQTTSessionContext.builder().serverId(serverId).ticker(testTicker).defaultKeepAliveTimeSeconds(keepAlive)
.distClient(distClient).retainClient(retainClient).authProvider(authProvider)
.localDistService(localDistService).localSessionRegistry(localSessionRegistry)
.sessionDictClient(sessionDictClient).clientBalancer(clientBalancer).eventCollector(eventCollector)
.resourceThrottler(resourceThrottler).settingProvider(settingProvider).build();
// common mocks
mockSettings();
ChannelDuplexHandler sessionHandlerAdder = new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.pipeline().addLast(
MQTT3TransientSessionHandler.builder().settings(new TenantSettings(tenantId, settingProvider))
.tenantMeter(tenantMeter).oomCondition(oomCondition).userSessionId(userSessionId(clientInfo))
.keepAliveTimeSeconds(120).clientInfo(clientInfo).willMessage(null).ctx(ctx).build());
ctx.pipeline().remove(this);
}
};
ChannelDuplexHandler sessionHandlerAdder = buildChannelHandler();
mockSessionReg();
channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() {
@Override
Expand Down Expand Up @@ -184,6 +164,26 @@ public void tearDown(Method method) {
super.tearDown(method);
}

@Override
protected ChannelDuplexHandler buildChannelHandler() {
return new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
ctx.pipeline().addLast(MQTT3TransientSessionHandler.builder()
.settings(new TenantSettings(tenantId, settingProvider))
.tenantMeter(tenantMeter)
.oomCondition(oomCondition)
.userSessionId(userSessionId(clientInfo))
.keepAliveTimeSeconds(120)
.clientInfo(clientInfo)
.willMessage(null)
.ctx(ctx)
.build());
ctx.pipeline().remove(this);
}
};
}

@Test
public void handleConnect() {
Expand All @@ -210,9 +210,9 @@ public void transientQoS0Sub() {
verifySubAck(subAckMessage, qos);
verifyEvent(MQTT_SESSION_START, SUB_ACKED);
shouldCleanSubs = true;
boolean isSub =
transientSessionHandler.isSubscribing(subMessage.payload().topicSubscriptions().get(0).topicFilter());
assertTrue(isSub);
boolean hasSub =
transientSessionHandler.hasSubscribed(subMessage.payload().topicSubscriptions().get(0).topicFilter());
assertTrue(hasSub);
}

@Test
Expand Down
Loading