diff --git a/Source/MQTTnet/Formatter/MqttSubAckPacketFactory.cs b/Source/MQTTnet/Formatter/MqttSubAckPacketFactory.cs index 4f5436d92..88c695f91 100644 --- a/Source/MQTTnet/Formatter/MqttSubAckPacketFactory.cs +++ b/Source/MQTTnet/Formatter/MqttSubAckPacketFactory.cs @@ -4,7 +4,7 @@ using System; using MQTTnet.Packets; -using MQTTnet.Server; +using MQTTnet.Server.Internal; namespace MQTTnet.Formatter { diff --git a/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs new file mode 100644 index 000000000..4b3b1a4be --- /dev/null +++ b/Source/MQTTnet/Server/IMqttRetainedMessagesManager.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Server +{ + public interface IMqttRetainedMessagesManager + { + Task ClearMessages(); + + Task> GetMessages(CancellationToken cancellationToken = default); + + Task LoadMessages(IEnumerable subscriptions, CancellationToken cancellationToken = default); + + Task Start(); + + Task Stop(); + + Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs index e3dee096e..6bdd2dec6 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSessionsManager.cs @@ -29,7 +29,7 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification readonly MqttNetSourceLogger _logger; readonly MqttServerOptions _options; - readonly MqttRetainedMessagesManager _retainedMessagesManager; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; readonly IMqttNetLogger _rootLogger; // The _sessions dictionary contains all session, the _subscriberSessions hash set contains subscriber sessions only. @@ -41,7 +41,7 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification public MqttClientSessionsManager( MqttServerOptions options, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttRetainedMessagesManager retainedMessagesManager, MqttServerEventContainer eventContainer, IMqttNetLogger logger) { diff --git a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs index b409d959f..de6d20754 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs @@ -10,16 +10,15 @@ using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; +using MQTTnet.Server.Internal; namespace MQTTnet.Server { public sealed class MqttClientSubscriptionsManager : IDisposable { - static readonly List EmptySubscriptionIdentifiers = new List(); - readonly MqttServerEventContainer _eventContainer; readonly Dictionary> _noWildcardSubscriptionsByTopicHash = new Dictionary>(); - readonly MqttRetainedMessagesManager _retainedMessagesManager; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; readonly MqttSession _session; @@ -37,7 +36,7 @@ public sealed class MqttClientSubscriptionsManager : IDisposable public MqttClientSubscriptionsManager( MqttSession session, MqttServerEventContainer eventContainer, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttRetainedMessagesManager retainedMessagesManager, ISubscriptionChangedNotification subscriptionChangedNotification) { _session = session ?? throw new ArgumentNullException(nameof(session)); @@ -85,7 +84,7 @@ public CheckSubscriptionsResult CheckSubscriptions(string topic, ulong topicHash var senderIsReceiver = string.Equals(senderId, _session.Id); var maxQoSLevel = -1; // Not subscribed. - HashSet subscriptionIdentifiers = null; + var subscriptionIdentifiers = new HashSet(); var retainAsPublished = false; foreach (var subscription in possibleSubscriptions) @@ -114,11 +113,6 @@ public CheckSubscriptionsResult CheckSubscriptions(string topic, ulong topicHash if (subscription.Identifier > 0) { - if (subscriptionIdentifiers == null) - { - subscriptionIdentifiers = new HashSet(); - } - subscriptionIdentifiers.Add(subscription.Identifier); } } @@ -132,7 +126,7 @@ public CheckSubscriptionsResult CheckSubscriptions(string topic, ulong topicHash { IsSubscribed = true, RetainAsPublished = retainAsPublished, - SubscriptionIdentifiers = subscriptionIdentifiers?.ToList() ?? EmptySubscriptionIdentifiers, + SubscriptionIdentifiers = subscriptionIdentifiers.ToList(), // Start with the same QoS as the publisher. QualityOfServiceLevel = qualityOfServiceLevel @@ -167,12 +161,11 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket throw new ArgumentNullException(nameof(subscribePacket)); } - var retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false); var result = new SubscribeResult(subscribePacket.TopicFilters.Count); - var addedSubscriptions = new List(); - var finalTopicFilters = new List(); - + var addedSubscriptions = new List(subscribePacket.TopicFilters.Count); + var finalTopicFilters = new List(subscribePacket.TopicFilters.Count); + var subscriptionRetainedMessages = new List(); // The topic filters are order by its QoS so that the higher QoS will win over a // lower one. foreach (var topicFilterItem in subscribePacket.TopicFilters.OrderByDescending(f => f.QualityOfServiceLevel)) @@ -197,12 +190,46 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket continue; } - var createSubscriptionResult = CreateSubscription(topicFilter, subscribePacket.SubscriptionIdentifier, subscriptionEventArgs.Response.ReasonCode); + var qualtityOfService = SubscribeReasonCodeToQualityOfServiceLevel(subscriptionEventArgs.Response.ReasonCode); + var createSubscriptionResult = CreateSubscription(topicFilter, subscribePacket.SubscriptionIdentifier, qualtityOfService); addedSubscriptions.Add(topicFilter.Topic); finalTopicFilters.Add(topicFilter); - FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result); + if (createSubscriptionResult.Subscription.RetainHandling == MqttRetainHandling.DoNotSendOnSubscribe) + { + // This is a MQTT V5+ feature. + continue; + } + + if (createSubscriptionResult.Subscription.RetainHandling == MqttRetainHandling.SendAtSubscribeIfNewSubscriptionOnly && !createSubscriptionResult.IsNewSubscription) + { + // This is a MQTT V5+ feature. + continue; + } + + subscriptionRetainedMessages.Add(createSubscriptionResult); + } + + await _retainedMessagesManager.LoadMessages(subscriptionRetainedMessages, cancellationToken); + + foreach (var subscriptionRetainedMessage in subscriptionRetainedMessages) + { + foreach (var retainedMessage in subscriptionRetainedMessage.RetainedMessages) + { + var retainedMessageMatch = new MqttRetainedMessageMatch(retainedMessage, subscriptionRetainedMessage.Subscription.GrantedQualityOfServiceLevel); + if (retainedMessageMatch.SubscriptionQualityOfServiceLevel > retainedMessageMatch.ApplicationMessage.QualityOfServiceLevel) + { + // UPGRADING the QoS is not allowed! + // From MQTT spec: Subscribing to a Topic Filter at QoS 2 is equivalent to saying + // "I would like to receive Messages matching this filter at the QoS with which they were published". + // This means a publisher is responsible for determining the maximum QoS a Message can be delivered at, + // but a subscriber is able to require that the Server downgrades the QoS to one more suitable for its usage. + retainedMessageMatch.SubscriptionQualityOfServiceLevel = retainedMessageMatch.ApplicationMessage.QualityOfServiceLevel; + } + + result.RetainedMessages.Add(retainedMessageMatch); + } } // This call will add the new subscription to the internal storage. @@ -307,27 +334,8 @@ public async Task Unsubscribe(MqttUnsubscribePacket unsubscri return result; } - CreateSubscriptionResult CreateSubscription(MqttTopicFilter topicFilter, uint subscriptionIdentifier, MqttSubscribeReasonCode reasonCode) + SubscriptionRetainedMessagesResult CreateSubscription(MqttTopicFilter topicFilter, uint subscriptionIdentifier, MqttQualityOfServiceLevel grantedQualityOfServiceLevel) { - MqttQualityOfServiceLevel grantedQualityOfServiceLevel; - - if (reasonCode == MqttSubscribeReasonCode.GrantedQoS0) - { - grantedQualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce; - } - else if (reasonCode == MqttSubscribeReasonCode.GrantedQoS1) - { - grantedQualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce; - } - else if (reasonCode == MqttSubscribeReasonCode.GrantedQoS2) - { - grantedQualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce; - } - else - { - throw new InvalidOperationException(); - } - var subscription = new MqttSubscription( topicFilter.Topic, topicFilter.NoLocal, @@ -339,7 +347,6 @@ CreateSubscriptionResult CreateSubscription(MqttTopicFilter topicFilter, uint su bool isNewSubscription; // Add to subscriptions and maintain topic hash dictionaries - using (_subscriptionsLock.EnterAsync(CancellationToken.None).GetAwaiter().GetResult()) { MqttSubscription.CalculateTopicHash(topicFilter.Topic, out var topicHash, out var topicHashMask, out var hasWildcard); @@ -391,64 +398,21 @@ CreateSubscriptionResult CreateSubscription(MqttTopicFilter topicFilter, uint su } } - return new CreateSubscriptionResult - { - IsNewSubscription = isNewSubscription, - Subscription = subscription - }; + return new SubscriptionRetainedMessagesResult(subscription, isNewSubscription); } - static void FilterRetainedApplicationMessages( - IList retainedMessages, - CreateSubscriptionResult createSubscriptionResult, - SubscribeResult subscribeResult) + private static MqttQualityOfServiceLevel SubscribeReasonCodeToQualityOfServiceLevel(MqttSubscribeReasonCode reasonCode) { - for (var index = retainedMessages.Count - 1; index >= 0; index--) + switch (reasonCode) { - var retainedMessage = retainedMessages[index]; - if (retainedMessage == null) - { - continue; - } - - if (createSubscriptionResult.Subscription.RetainHandling == MqttRetainHandling.DoNotSendOnSubscribe) - { - // This is a MQTT V5+ feature. - continue; - } - - if (createSubscriptionResult.Subscription.RetainHandling == MqttRetainHandling.SendAtSubscribeIfNewSubscriptionOnly && !createSubscriptionResult.IsNewSubscription) - { - // This is a MQTT V5+ feature. - continue; - } - - if (MqttTopicFilterComparer.Compare(retainedMessage.Topic, createSubscriptionResult.Subscription.Topic) != MqttTopicFilterCompareResult.IsMatch) - { - continue; - } - - var retainedMessageMatch = new MqttRetainedMessageMatch(retainedMessage, createSubscriptionResult.Subscription.GrantedQualityOfServiceLevel); - if (retainedMessageMatch.SubscriptionQualityOfServiceLevel > retainedMessageMatch.ApplicationMessage.QualityOfServiceLevel) - { - // UPGRADING the QoS is not allowed! - // From MQTT spec: Subscribing to a Topic Filter at QoS 2 is equivalent to saying - // "I would like to receive Messages matching this filter at the QoS with which they were published". - // This means a publisher is responsible for determining the maximum QoS a Message can be delivered at, - // but a subscriber is able to require that the Server downgrades the QoS to one more suitable for its usage. - retainedMessageMatch.SubscriptionQualityOfServiceLevel = retainedMessageMatch.ApplicationMessage.QualityOfServiceLevel; - } - - if (subscribeResult.RetainedMessages == null) - { - subscribeResult.RetainedMessages = new List(); - } - - subscribeResult.RetainedMessages.Add(retainedMessageMatch); - - // Clear the retained message from the list because the client should receive every message only - // one time even if multiple subscriptions affect them. - retainedMessages[index] = null; + case MqttSubscribeReasonCode.GrantedQoS0: + return MqttQualityOfServiceLevel.AtMostOnce; + case MqttSubscribeReasonCode.GrantedQoS1: + return MqttQualityOfServiceLevel.AtLeastOnce; + case MqttSubscribeReasonCode.GrantedQoS2: + return MqttQualityOfServiceLevel.ExactlyOnce; + default: + throw new InvalidOperationException(); } } @@ -495,12 +459,5 @@ async Task InterceptUnsubscribe(string topi return clientUnsubscribingTopicEventArgs; } - - sealed class CreateSubscriptionResult - { - public bool IsNewSubscription { get; set; } - - public MqttSubscription Subscription { get; set; } - } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs index a9b86a7a7..2d8694582 100644 --- a/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs @@ -5,13 +5,14 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using MQTTnet.Diagnostics; using MQTTnet.Internal; namespace MQTTnet.Server { - public sealed class MqttRetainedMessagesManager + public sealed class MqttRetainedMessagesManager: IMqttRetainedMessagesManager { readonly Dictionary _messages = new Dictionary(4096); readonly AsyncLock _storageAccessLock = new AsyncLock(); @@ -53,6 +54,16 @@ public async Task Start() } } + public Task Stop() + { +#if NET461_OR_GREATER + return Task.CompletedTask; +#else + return Task.FromResult(0); +#endif + + } + public async Task UpdateMessage(string clientId, MqttApplicationMessage applicationMessage) { if (applicationMessage == null) @@ -114,7 +125,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat } } - public Task> GetMessages() + public Task> GetMessages(CancellationToken cancellationToken = default) { lock (_messages) { @@ -135,5 +146,26 @@ public async Task ClearMessages() await _eventContainer.RetainedMessagesClearedEvent.InvokeAsync(EventArgs.Empty).ConfigureAwait(false); } } + + public async Task LoadMessages(IEnumerable subscriptions, CancellationToken cancellationToken = default) + { + var allRetainedMessages = await GetMessages(cancellationToken); + + for (var i = allRetainedMessages.Count - 1; i >= 0; i--) + { + var retainedMessage = allRetainedMessages[i]; + foreach (var subscription in subscriptions) + { + if (MqttTopicFilterComparer.Compare(retainedMessage.Topic, subscription.Subscription.Topic) == MqttTopicFilterCompareResult.IsMatch) + { + subscription.RetainedMessages.Add(retainedMessage); + + // Skip the following subscriptions, as each message may only be sent once. + // Following subscriptions do not have to be checked, as the message is sent anyway. + break; + } + } + } + } } } diff --git a/Source/MQTTnet/Server/Internal/MqttSession.cs b/Source/MQTTnet/Server/Internal/MqttSession.cs index b40c26455..718d6cdde 100644 --- a/Source/MQTTnet/Server/Internal/MqttSession.cs +++ b/Source/MQTTnet/Server/Internal/MqttSession.cs @@ -12,6 +12,7 @@ using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; +using MQTTnet.Server.Internal; namespace MQTTnet.Server { @@ -36,7 +37,7 @@ public MqttSession( IDictionary items, MqttServerOptions serverOptions, MqttServerEventContainer eventContainer, - MqttRetainedMessagesManager retainedMessagesManager, + IMqttRetainedMessagesManager retainedMessagesManager, MqttClientSessionsManager clientSessionsManager) { Id = clientId ?? throw new ArgumentNullException(nameof(clientId)); diff --git a/Source/MQTTnet/Server/Internal/SubscribeResult.cs b/Source/MQTTnet/Server/Internal/SubscribeResult.cs index 9c828edad..27d2c2980 100644 --- a/Source/MQTTnet/Server/Internal/SubscribeResult.cs +++ b/Source/MQTTnet/Server/Internal/SubscribeResult.cs @@ -6,22 +6,23 @@ using MQTTnet.Packets; using MQTTnet.Protocol; -namespace MQTTnet.Server +namespace MQTTnet.Server.Internal { public sealed class SubscribeResult { public SubscribeResult(int topicsCount) { ReasonCodes = new List(topicsCount); + RetainedMessages = new List(); } - + public bool CloseConnection { get; set; } - + public List ReasonCodes { get; set; } public string ReasonString { get; set; } - public List RetainedMessages { get; set; } + public List RetainedMessages { get; } public List UserProperties { get; set; } } diff --git a/Source/MQTTnet/Server/MqttServer.cs b/Source/MQTTnet/Server/MqttServer.cs index d84826140..abdd3c275 100644 --- a/Source/MQTTnet/Server/MqttServer.cs +++ b/Source/MQTTnet/Server/MqttServer.cs @@ -25,7 +25,7 @@ public class MqttServer : Disposable readonly MqttServerKeepAliveMonitor _keepAliveMonitor; readonly MqttNetSourceLogger _logger; readonly MqttServerOptions _options; - readonly MqttRetainedMessagesManager _retainedMessagesManager; + readonly IMqttRetainedMessagesManager _retainedMessagesManager; readonly IMqttNetLogger _rootLogger; CancellationTokenSource _cancellationTokenSource; @@ -44,7 +44,7 @@ public MqttServer(MqttServerOptions options, IEnumerable ada _rootLogger = logger ?? throw new ArgumentNullException(nameof(logger)); _logger = logger.WithSource(nameof(MqttServer)); - _retainedMessagesManager = new MqttRetainedMessagesManager(_eventContainer, _rootLogger); + _retainedMessagesManager = options.RetainedMessagesManager ?? new MqttRetainedMessagesManager(_eventContainer, _rootLogger); _clientSessionsManager = new MqttClientSessionsManager(options, _retainedMessagesManager, _eventContainer, _rootLogger); _keepAliveMonitor = new MqttServerKeepAliveMonitor(options, _clientSessionsManager, _rootLogger); } @@ -287,7 +287,7 @@ public async Task StopAsync() _cancellationTokenSource?.Dispose(); _cancellationTokenSource = null; } - + await _retainedMessagesManager.Stop().ConfigureAwait(false); await _eventContainer.StoppedEvent.InvokeAsync(EventArgs.Empty).ConfigureAwait(false); _logger.Info("Stopped."); diff --git a/Source/MQTTnet/Server/Options/MqttServerOptions.cs b/Source/MQTTnet/Server/Options/MqttServerOptions.cs index 59e46ab65..6c74575a1 100644 --- a/Source/MQTTnet/Server/Options/MqttServerOptions.cs +++ b/Source/MQTTnet/Server/Options/MqttServerOptions.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Server public sealed class MqttServerOptions { public TimeSpan DefaultCommunicationTimeout { get; set; } = TimeSpan.FromSeconds(100); - + public MqttServerTcpEndpointOptions DefaultEndpointOptions { get; } = new MqttServerTcpEndpointOptions(); public bool EnablePersistentSessions { get; set; } @@ -20,6 +20,8 @@ public sealed class MqttServerOptions public MqttPendingMessagesOverflowStrategy PendingMessagesOverflowStrategy { get; set; } = MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage; + public IMqttRetainedMessagesManager RetainedMessagesManager { get; set; } + public MqttServerTlsTcpEndpointOptions TlsEndpointOptions { get; } = new MqttServerTlsTcpEndpointOptions(); /// diff --git a/Source/MQTTnet/Server/SubscriptionRetainedMessagesResult.cs b/Source/MQTTnet/Server/SubscriptionRetainedMessagesResult.cs new file mode 100644 index 000000000..53c76355d --- /dev/null +++ b/Source/MQTTnet/Server/SubscriptionRetainedMessagesResult.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace MQTTnet.Server +{ + public readonly struct SubscriptionRetainedMessagesResult + { + public SubscriptionRetainedMessagesResult(MqttSubscription subscription, bool isNewSubscription) + { + Subscription = subscription ?? throw new ArgumentNullException(nameof(subscription)); + IsNewSubscription = isNewSubscription; + RetainedMessages = new List(); + } + + public bool IsNewSubscription { get; } + + public MqttSubscription Subscription { get; } + + public List RetainedMessages { get; } + } +} \ No newline at end of file