Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve session message injection #2117

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace MQTTnet.Server.Exceptions;

public class MqttPendingMessagesOverflowException : Exception
{
public MqttPendingMessagesOverflowException(string sessionId, MqttPendingMessagesOverflowStrategy overflowStrategy) : base(
$"Send buffer max pending messages overflow occurred for session '{sessionId}'. Strategy: {overflowStrategy}.")
{
SessionId = sessionId;
OverflowStrategy = overflowStrategy;
}

public MqttPendingMessagesOverflowStrategy OverflowStrategy { get; }

public string SessionId { get; }
}
15 changes: 11 additions & 4 deletions Source/MQTTnet.Server/Internal/MqttSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;
using MQTTnet.Server.Exceptions;

namespace MQTTnet.Server.Internal;

Expand Down Expand Up @@ -111,10 +112,11 @@ public void EnqueueControlPacket(MqttPacketBusItem packetBusItem)

public EnqueueDataPacketResult EnqueueDataPacket(MqttPacketBusItem packetBusItem)
{
if (_packetBus.ItemsCount(MqttPacketBusPartition.Data) >= _serverOptions.MaxPendingMessagesPerClient)
if (PendingDataPacketsCount >= _serverOptions.MaxPendingMessagesPerClient)
{
if (_serverOptions.PendingMessagesOverflowStrategy == MqttPendingMessagesOverflowStrategy.DropNewMessage)
{
packetBusItem.Fail(new MqttPendingMessagesOverflowException(Id, _serverOptions.PendingMessagesOverflowStrategy));
return EnqueueDataPacketResult.Dropped;
}

Expand All @@ -123,10 +125,15 @@ public EnqueueDataPacketResult EnqueueDataPacket(MqttPacketBusItem packetBusItem
// Only drop from the data partition. Dropping from control partition might break the connection
// because the client does not receive PINGREQ packets etc. any longer.
var firstItem = _packetBus.DropFirstItem(MqttPacketBusPartition.Data);
if (firstItem != null && _eventContainer.QueuedApplicationMessageOverwrittenEvent.HasHandlers)
if (firstItem != null)
{
var eventArgs = new QueueMessageOverwrittenEventArgs(Id, firstItem.Packet);
_eventContainer.QueuedApplicationMessageOverwrittenEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
firstItem.Fail(new MqttPendingMessagesOverflowException(Id, _serverOptions.PendingMessagesOverflowStrategy));

if (_eventContainer.QueuedApplicationMessageOverwrittenEvent.HasHandlers)
{
var eventArgs = new QueueMessageOverwrittenEventArgs(Id, firstItem.Packet);
_eventContainer.QueuedApplicationMessageOverwrittenEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
}
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions Source/MQTTnet.Server/Options/MqttServerOptionsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ public MqttServerOptionsBuilder WithMaxPendingMessagesPerClient(int value)
return this;
}

public MqttServerOptionsBuilder WithPendingMessagesOverflowStrategy(MqttPendingMessagesOverflowStrategy value)
{
_options.PendingMessagesOverflowStrategy = value;
return this;
}

public MqttServerOptionsBuilder WithoutDefaultEndpoint()
{
_options.DefaultEndpointOptions.IsEnabled = false;
Expand Down
53 changes: 48 additions & 5 deletions Source/MQTTnet.Server/Status/MqttSessionStatus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,65 @@ public Task DeleteAsync()
return _session.DeleteAsync();
}

public Task DeliverApplicationMessageAsync(MqttApplicationMessage applicationMessage)
/// <summary>
/// Delivers an application message immediately to the session.
/// </summary>
/// <param name="applicationMessage">The application message to deliver.</param>
/// <returns>
/// A task that represents the asynchronous operation.
/// The result contains the <see cref="InjectMqttApplicationMessageResult"/> that includes the packet identifier of the enqueued message.
/// </returns>
public async Task<InjectMqttApplicationMessageResult> DeliverApplicationMessageAsync(MqttApplicationMessage applicationMessage)
{
ArgumentNullException.ThrowIfNull(applicationMessage);

var packetBusItem = new MqttPacketBusItem(MqttPublishPacketFactory.Create(applicationMessage));
var publishPacket = MqttPublishPacketFactory.Create(applicationMessage);
var packetBusItem = new MqttPacketBusItem(publishPacket);
_session.EnqueueDataPacket(packetBusItem);

return packetBusItem.WaitAsync();
await packetBusItem.WaitAsync().ConfigureAwait(false);

var injectResult = new InjectMqttApplicationMessageResult()
{
PacketIdentifier = publishPacket.PacketIdentifier
};

return injectResult;
}

public Task EnqueueApplicationMessageAsync(MqttApplicationMessage applicationMessage)
/// <summary>
/// Attempts to enqueue an application message to the session's send buffer.
/// </summary>
/// <param name="applicationMessage">The application message to enqueue.</param>
/// <param name="injectResult"><see cref="InjectMqttApplicationMessageResult"/> that includes the packet identifier of the enqueued message.</param>
/// <returns><c>true</c> if the message was successfully enqueued; otherwise, <c>false</c>.</returns>
/// <remarks>
/// When <see cref="MqttServerOptions.PendingMessagesOverflowStrategy"/> is set to <see cref="MqttPendingMessagesOverflowStrategy.DropOldestQueuedMessage"/>,
/// this method always returns <c>true</c>.
/// However, an existing message in the queue may be <b>dropped later</b> to make room for the newly enqueued message.
/// Such dropped messages can be tracked by subscribing to <see cref="MqttServer.QueuedApplicationMessageOverwrittenAsync"/> event.
/// </remarks>
public bool TryEnqueueApplicationMessage(MqttApplicationMessage applicationMessage, out InjectMqttApplicationMessageResult injectResult)
{
ArgumentNullException.ThrowIfNull(applicationMessage);

_session.EnqueueDataPacket(new MqttPacketBusItem(MqttPublishPacketFactory.Create(applicationMessage)));
var publishPacket = MqttPublishPacketFactory.Create(applicationMessage);
var enqueueDataPacketResult = _session.EnqueueDataPacket(new MqttPacketBusItem(publishPacket));

if (enqueueDataPacketResult != EnqueueDataPacketResult.Enqueued)
{
injectResult = null;
return false;
}

injectResult = new InjectMqttApplicationMessageResult() { PacketIdentifier = publishPacket.PacketIdentifier };
return true;
}

[Obsolete("This method is obsolete. Use TryEnqueueApplicationMessage instead.")]
public Task EnqueueApplicationMessageAsync(MqttApplicationMessage applicationMessage)
{
TryEnqueueApplicationMessage(applicationMessage, out _);
return CompletedTask.Instance;
}
}
7 changes: 4 additions & 3 deletions Source/MQTTnet.Tests/BaseTestClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ namespace MQTTnet.Tests
public abstract class BaseTestClass
{
public TestContext TestContext { get; set; }

protected TestEnvironment CreateTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311)

protected TestEnvironment CreateTestEnvironment(
MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311, bool trackUnobservedTaskException = true)
{
return new TestEnvironment(TestContext, protocolVersion);
return new TestEnvironment(TestContext, protocolVersion, trackUnobservedTaskException);
}

protected Task LongTestDelay()
Expand Down
8 changes: 6 additions & 2 deletions Source/MQTTnet.Tests/Mockups/TestEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ public TestEnvironment() : this(null)
{
}

public TestEnvironment(TestContext testContext, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311)
public TestEnvironment(
TestContext testContext, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311, bool trackUnobservedTaskException = true)
{
_protocolVersion = protocolVersion;
TestContext = testContext;

TaskScheduler.UnobservedTaskException += TrackUnobservedTaskException;
if (trackUnobservedTaskException)
{
TaskScheduler.UnobservedTaskException += TrackUnobservedTaskException;
}

ServerLogger.LogMessagePublished += (s, e) =>
{
Expand Down
Loading