diff --git a/Samples/Client/Client_Publish_Samples.cs b/Samples/Client/Client_Publish_Samples.cs index 16cc67c78..66bdfbcba 100644 --- a/Samples/Client/Client_Publish_Samples.cs +++ b/Samples/Client/Client_Publish_Samples.cs @@ -31,12 +31,10 @@ public static async Task Publish_Application_Message() await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); - var applicationMessage = new MqttApplicationMessageBuilder() - .WithTopic("samples/temperature/living_room") - .WithPayload("19.5") - .Build(); - - await mqttClient.PublishAsync(applicationMessage, CancellationToken.None); + await mqttClient.PublishStringAsync( + topic: "samples/temperature/living_room", + payload: "19.5", + cancellationToken: CancellationToken.None); await mqttClient.DisconnectAsync(); @@ -61,27 +59,21 @@ public static async Task Publish_Multiple_Application_Messages() .Build(); await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); - - var applicationMessage = new MqttApplicationMessageBuilder() - .WithTopic("samples/temperature/living_room") - .WithPayload("19.5") - .Build(); - - await mqttClient.PublishAsync(applicationMessage, CancellationToken.None); - - applicationMessage = new MqttApplicationMessageBuilder() - .WithTopic("samples/temperature/living_room") - .WithPayload("20.0") - .Build(); - - await mqttClient.PublishAsync(applicationMessage, CancellationToken.None); - - applicationMessage = new MqttApplicationMessageBuilder() - .WithTopic("samples/temperature/living_room") - .WithPayload("21.0") - .Build(); - - await mqttClient.PublishAsync(applicationMessage, CancellationToken.None); + + await mqttClient.PublishStringAsync( + topic: "samples/temperature/living_room", + payload: "19.5", + cancellationToken: CancellationToken.None); + + await mqttClient.PublishStringAsync( + topic: "samples/temperature/living_room", + payload: "20.0", + cancellationToken: CancellationToken.None); + + await mqttClient.PublishStringAsync( + topic: "samples/temperature/living_room", + payload: "21.0", + cancellationToken: CancellationToken.None); await mqttClient.DisconnectAsync(); diff --git a/Samples/Diagnostics/PackageInspection_Samples.cs b/Samples/Diagnostics/PackageInspection_Samples.cs index 1ab996a28..741a7a3e4 100644 --- a/Samples/Diagnostics/PackageInspection_Samples.cs +++ b/Samples/Diagnostics/PackageInspection_Samples.cs @@ -7,6 +7,7 @@ // ReSharper disable InconsistentNaming using MQTTnet.Diagnostics.PacketInspection; +using System.Buffers; namespace MQTTnet.Samples.Diagnostics; @@ -43,11 +44,11 @@ static Task OnInspectPacket(InspectMqttPacketEventArgs eventArgs) { if (eventArgs.Direction == MqttPacketFlowDirection.Inbound) { - Console.WriteLine($"IN: {Convert.ToBase64String(eventArgs.Buffer)}"); + Console.WriteLine($"IN: {Convert.ToBase64String(eventArgs.Buffer.ToArray())}"); } else { - Console.WriteLine($"OUT: {Convert.ToBase64String(eventArgs.Buffer)}"); + Console.WriteLine($"OUT: {Convert.ToBase64String(eventArgs.Buffer.ToArray())}"); } return Task.CompletedTask; diff --git a/Samples/MQTTnet.Samples.csproj b/Samples/MQTTnet.Samples.csproj index 5fe84d380..2441f1c9c 100644 --- a/Samples/MQTTnet.Samples.csproj +++ b/Samples/MQTTnet.Samples.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Samples/RpcClient/RpcClient_Samples.cs b/Samples/RpcClient/RpcClient_Samples.cs index b50c91218..4bed81b0c 100644 --- a/Samples/RpcClient/RpcClient_Samples.cs +++ b/Samples/RpcClient/RpcClient_Samples.cs @@ -38,7 +38,7 @@ public static async Task Send_Request() { // Access to a fully featured application message is not supported for RPC calls! // The method will throw an exception when the response was not received in time. - await mqttRpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); + await mqttRpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } Console.WriteLine("The RPC call was successful."); diff --git a/Samples/Server/Server_Retained_Messages_Samples.cs b/Samples/Server/Server_Retained_Messages_Samples.cs index 4564acae9..d64fcd78d 100644 --- a/Samples/Server/Server_Retained_Messages_Samples.cs +++ b/Samples/Server/Server_Retained_Messages_Samples.cs @@ -92,8 +92,8 @@ public static async Task Persist_Retained_Messages() sealed class MqttRetainedMessageModel { public string? ContentType { get; set; } - public byte[]? CorrelationData { get; set; } - public byte[]? Payload { get; set; } + public ReadOnlyMemory CorrelationData { get; set; } + public ReadOnlySequence Payload { get; set; } public MqttPayloadFormatIndicator PayloadFormatIndicator { get; set; } public MqttQualityOfServiceLevel QualityOfServiceLevel { get; set; } public string? ResponseTopic { get; set; } @@ -110,7 +110,7 @@ public static MqttRetainedMessageModel Create(MqttApplicationMessage message) // Create a copy of the buffer from the payload segment because // it cannot be serialized and deserialized with the JSON serializer. - Payload = message.Payload.ToArray(), + Payload = new ReadOnlySequence(message.Payload.ToArray()), UserProperties = message.UserProperties, ResponseTopic = message.ResponseTopic, CorrelationData = message.CorrelationData, @@ -128,7 +128,7 @@ public MqttApplicationMessage ToApplicationMessage() return new MqttApplicationMessage { Topic = Topic, - PayloadSegment = new ArraySegment(Payload ?? Array.Empty()), + Payload = Payload, PayloadFormatIndicator = PayloadFormatIndicator, ResponseTopic = ResponseTopic, CorrelationData = CorrelationData, diff --git a/Samples/Server/Server_Simple_Samples.cs b/Samples/Server/Server_Simple_Samples.cs index e7ed4bfec..d8733fedd 100644 --- a/Samples/Server/Server_Simple_Samples.cs +++ b/Samples/Server/Server_Simple_Samples.cs @@ -44,18 +44,13 @@ public static async Task Publish_Message_From_Broker() * See _Run_Minimal_Server_ for more information. */ - using (var mqttServer = await StartMqttServer()) - { - // Create a new message using the builder as usual. - var message = new MqttApplicationMessageBuilder().WithTopic("HelloWorld").WithPayload("Test").Build(); + using var mqttServer = await StartMqttServer(); - // Now inject the new message at the broker. - await mqttServer.InjectApplicationMessage( - new InjectedMqttApplicationMessage(message) - { - SenderClientId = "SenderClientId" - }); - } + // Now inject the new message at the broker. + await mqttServer.InjectStringAsync( + clientId: "SenderClientId", + topic: "HelloWorld", + payload: "Test"); } public static async Task Run_Minimal_Server() diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 5357dd702..6b7f60ce2 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -37,7 +37,7 @@ true low low - latest-Recommended + diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index 61184c4d5..9e127efa1 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -231,8 +231,8 @@ static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) var span = output.GetSpan(buffer.Length); - buffer.Packet.AsSpan().CopyTo(span); - int offset = buffer.Packet.Count; + buffer.Packet.Span.CopyTo(span); + int offset = buffer.Packet.Length; buffer.Payload.CopyTo(destination: span.Slice(offset)); output.Advance(buffer.Length); } diff --git a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs b/Source/MQTTnet.AspnetCore/ReaderExtensions.cs index 9b4f24ca5..de55b52bf 100644 --- a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ReaderExtensions.cs @@ -48,9 +48,7 @@ public static bool TryDecode( } var bodySlice = copy.Slice(0, bodyLength); - var bodySegment = GetArraySegment(ref bodySlice); - - var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, bodySegment, headerLength + bodyLength); + var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, bodySlice, headerLength + bodyLength); if (formatter.ProtocolVersion == MqttProtocolVersion.Unknown) { formatter.DetectProtocolVersion(receivedMqttPacket); @@ -62,19 +60,7 @@ public static bool TryDecode( bytesRead = headerLength + bodyLength; return true; } - - static ArraySegment GetArraySegment(ref ReadOnlySequence input) - { - if (input.IsSingleSegment && MemoryMarshal.TryGetArray(input.First, out var segment)) - { - return segment; - } - - // Should be rare - var array = input.ToArray(); - return new ArraySegment(array); - } - + static void ThrowProtocolViolationException(ReadOnlySpan valueSpan, int index) { diff --git a/Source/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs b/Source/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs index cbc3f7996..e236284d4 100644 --- a/Source/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ChannelAdapterBenchmark.cs @@ -76,12 +76,12 @@ public void Setup() _channelAdapter = new MqttChannelAdapter(channel, serializer, new MqttNetEventLogger()); } - static byte[] Join(params ArraySegment[] chunks) + static byte[] Join(params ReadOnlyMemory[] chunks) { var buffer = new MemoryStream(); foreach (var chunk in chunks) { - buffer.Write(chunk.Array, chunk.Offset, chunk.Count); + buffer.Write(chunk.Span); } return buffer.ToArray(); diff --git a/Source/MQTTnet.Benchmarks/JsonPayloadBenchmark.cs b/Source/MQTTnet.Benchmarks/JsonPayloadBenchmark.cs new file mode 100644 index 000000000..808846c4f --- /dev/null +++ b/Source/MQTTnet.Benchmarks/JsonPayloadBenchmark.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; +using MQTTnet.Internal; +using System.Text.Json; +using System.Threading.Tasks; + +namespace MQTTnet.Benchmarks +{ + [SimpleJob(RuntimeMoniker.Net80)] + [MemoryDiagnoser] + public class JsonPayloadBenchmark : BaseBenchmark + { + [Params(1 * 1024, 4 * 1024, 8 * 1024)] + public int PayloadSize { get; set; } + private Model model; + + + [GlobalSetup] + public void Setup() + { + var stringValue = new char[PayloadSize]; + model = new Model { StringValue = new string(stringValue) }; + } + + [Benchmark] + public ValueTask SerializeToString_Payload_SendAsync() + { + string payload = JsonSerializer.Serialize(model); + var message = new MqttApplicationMessageBuilder() + .WithTopic("t") + .WithPayload(payload) + .Build(); + + // send message async + return ValueTask.CompletedTask; + } + + [Benchmark] + public ValueTask SerializeToUtf8Bytes_Payload_SendAsync() + { + byte[] payload = JsonSerializer.SerializeToUtf8Bytes(model); + var message = new MqttApplicationMessageBuilder() + .WithTopic("t") + .WithPayload(payload) + .Build(); + + // send message async + return ValueTask.CompletedTask; + } + + [Benchmark(Baseline = true)] + public async ValueTask MqttPayloadOwnerFactory_Payload_SendAsync() + { + await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(async writer => + await JsonSerializer.SerializeAsync(writer.AsStream(leaveOpen: true), model)); + + var message = new MqttApplicationMessageBuilder() + .WithTopic("t") + .WithPayload(payloadOwner.Payload) + .Build(); + + // send message async + } + + + + public class Model + { + public string StringValue { get; set; } + + public int IntValue { get; set; } + + public bool BoolValue { get; set; } + + public double DoubleValue { get; set; } + } + } +} diff --git a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index d50ca5cd9..b206fe801 100644 --- a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index 894ef19e5..30e211106 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -5,44 +5,52 @@ using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Jobs; using MQTTnet.Server; +using System.Threading.Tasks; namespace MQTTnet.Benchmarks; -[SimpleJob(RuntimeMoniker.Net60)] +[SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [RankColumn] [MemoryDiagnoser] public class MessageProcessingBenchmark : BaseBenchmark { - MqttApplicationMessage _message; IMqttClient _mqttClient; MqttServer _mqttServer; + string _payload = string.Empty; + + [Params(1 * 1024, 4 * 1024, 8 * 1024)] + public int PayloadSize { get; set; } [Benchmark] - public void Send_10000_Messages() + public async Task Send_1000_Messages() { - for (var i = 0; i < 10000; i++) + for (var i = 0; i < 1000; i++) { - _mqttClient.PublishAsync(_message).GetAwaiter().GetResult(); + await _mqttClient.PublishStringAsync("A", _payload); } } [GlobalSetup] - public void Setup() + public async Task Setup() { - var serverOptions = new MqttServerOptionsBuilder().Build(); - var serverFactory = new MqttServerFactory(); + var serverOptions = new MqttServerOptionsBuilder() + .WithDefaultEndpoint() + .Build(); + _mqttServer = serverFactory.CreateMqttServer(serverOptions); + await _mqttServer.StartAsync(); + var clientFactory = new MqttClientFactory(); _mqttClient = clientFactory.CreateMqttClient(); - _mqttServer.StartAsync().GetAwaiter().GetResult(); - - var clientOptions = new MqttClientOptionsBuilder().WithTcpServer("localhost").Build(); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); - _mqttClient.ConnectAsync(clientOptions).GetAwaiter().GetResult(); + await _mqttClient.ConnectAsync(clientOptions); - _message = new MqttApplicationMessageBuilder().WithTopic("A").Build(); + _payload = string.Empty.PadLeft(PayloadSize, '0'); } } \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs index bfa3d209c..e49dd79f0 100644 --- a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs @@ -2,11 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Text; using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Jobs; using MQTTnet.Formatter; +using System; +using System.Runtime.InteropServices; +using System.Text; namespace MQTTnet.Benchmarks { @@ -14,8 +15,7 @@ namespace MQTTnet.Benchmarks [MemoryDiagnoser] public class MqttBufferReaderBenchmark { - byte[] _buffer; - int _bufferLength; + ArraySegment _buffer; [GlobalSetup] public void GlobalSetup() @@ -23,21 +23,22 @@ public void GlobalSetup() var writer = new MqttBufferWriter(1024, 1024); writer.WriteString("hgfjkdfkjlghfdjghdfljkdfhgdlkjfshgsldkfjghsdflkjghdsflkjhrstiuoghlkfjbhnfbutghjoiöjhklötnbhtroliöuhbjntluiobkjzbhtdrlskbhtruhjkfthgbkftgjhgfiklhotriuöhbjtrsioöbtrsötrhträhtrühjtriüoätrhjtsrölbktrbnhtrulöbionhströloubinströoliubhnsöotrunbtöroisntröointröioujhgötiohjgötorshjnbgtöorihbnjtröoihbjntröobntröoibntrjhötrohjbtröoihntröoibnrtoiöbtrjnboöitrhjtnriohötrhjtöroihjtroöihjtroösibntsroönbotöirsbntöoihjntröoihntroöbtrboöitrnhoöitrhjntröoishbnjtröosbhtröbntriohjtröoijtöoitbjöotibjnhöotirhbjntroöibhnjrtoöibnhtroöibnhtörsbnhtöoirbnhtöroibntoörhjnbträöbtrbträbtrbtirbätrsibohjntrsöiobthnjiohjsrtoib"); - _buffer = writer.GetBuffer(); - _bufferLength = writer.Length; + if (MemoryMarshal.TryGetArray(writer.GetWrittenMemory(), out var segment)) + { + _buffer = segment; + } } [Benchmark] public void Use_Span() { - var span = _buffer.AsSpan(0, _bufferLength); - Encoding.UTF8.GetString(span); + Encoding.UTF8.GetString(_buffer.AsSpan()); } - + [Benchmark] public void Use_Encoding() { - Encoding.UTF8.GetString(_buffer, 0, _bufferLength); + Encoding.UTF8.GetString(_buffer.Array, _buffer.Offset, _buffer.Count); } } } \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs index 0efc7ffac..fca65aaab 100644 --- a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs @@ -15,8 +15,8 @@ namespace MQTTnet.Benchmarks public class MqttPacketReaderWriterBenchmark : BaseBenchmark { readonly byte[] _demoPayload = new byte[1024]; - - byte[] _readPayload; + + ReadOnlyMemory _readPayload; [GlobalCleanup] public void GlobalCleanup() @@ -27,7 +27,7 @@ public void GlobalCleanup() public void GlobalSetup() { TestEnvironment.EnableLogger = false; - + var writer = new MqttBufferWriter(4096, 65535); writer.WriteString("A relative short string."); writer.WriteBinary(_demoPayload); @@ -42,18 +42,16 @@ public void GlobalSetup() writer.WriteString("fjgffiogfhgfhoihgoireghreghreguhreguireoghreouighreouighreughreguiorehreuiohruiorehreuioghreug"); writer.WriteBinary(_demoPayload); - _readPayload = new ArraySegment(writer.GetBuffer(), 0, writer.Length).ToArray(); + _readPayload = writer.GetWrittenMemory(); } [Benchmark] public void Read_100_000_Messages() { - var reader = new MqttBufferReader(); - reader.SetBuffer(_readPayload, 0, _readPayload.Length); - for (var i = 0; i < 100000; i++) { - reader.Seek(0); + var reader = new MqttBufferReader(); + reader.SetBuffer(_readPayload); reader.ReadString(); reader.ReadBinaryData(); @@ -69,7 +67,7 @@ public void Read_100_000_Messages() reader.ReadBinaryData(); } } - + [Benchmark] public void Write_100_000_Messages() { diff --git a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs index 5f2242461..70fcc8e9b 100644 --- a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs @@ -35,7 +35,7 @@ public void GlobalSetup() var buffer = mqttPacketFormatter.Encode(packet); stream = new MemoryStream(); - stream.Write(buffer.Packet); + stream.Write(buffer.Packet.Span); stream.Write(buffer.Payload.ToArray()); mqttPacketFormatter.Cleanup(); } @@ -170,9 +170,7 @@ public static bool TryDecode(MqttPacketFormatterAdapter formatter, } var bodySlice = copy.Slice(0, bodyLength); - var buffer = GetMemory(bodySlice).ToArray(); - - var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, new ArraySegment(buffer, 0, buffer.Length), buffer.Length + 2); + var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, bodySlice, (int)bodySlice.Length + 2); if (formatter.ProtocolVersion == MqttProtocolVersion.Unknown) { diff --git a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs index b31782e66..1e96823df 100644 --- a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs @@ -21,7 +21,7 @@ public class SendPacketAsyncBenchmark : BaseBenchmark public void GlobalSetup() { stream = new MemoryStream(1024); - var packet = new ArraySegment(new byte[10]); + var packet = new byte[10]; buffer = new MqttPacketBuffer(packet); } @@ -60,8 +60,8 @@ static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) var span = output.GetSpan(buffer.Length); - buffer.Packet.AsSpan().CopyTo(span); - buffer.Payload.CopyTo(span.Slice(buffer.Packet.Count)); + buffer.Packet.Span.CopyTo(span); + buffer.Payload.CopyTo(span.Slice(buffer.Packet.Length)); output.Advance(buffer.Length); } diff --git a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs index 48117232c..af71d9ecc 100644 --- a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -25,7 +25,7 @@ namespace MQTTnet.Benchmarks public class SerializerBenchmark : BaseBenchmark { MqttPacket _packet; - ArraySegment _serializedPacket; + ReadOnlyMemory _serializedPacket; IMqttPacketFormatter _serializer; MqttBufferWriter _bufferWriter; @@ -68,13 +68,13 @@ public void Deserialize_10000_Messages() class BenchmarkMqttChannel : IMqttChannel { - readonly ArraySegment _buffer; + readonly ReadOnlyMemory _buffer; int _position; - public BenchmarkMqttChannel(ArraySegment buffer) + public BenchmarkMqttChannel(ReadOnlyMemory buffer) { _buffer = buffer; - _position = _buffer.Offset; + _position = 0; } public EndPoint RemoteEndPoint { get; set; } @@ -85,7 +85,7 @@ public BenchmarkMqttChannel(ArraySegment buffer) public void Reset() { - _position = _buffer.Offset; + _position = 0; } public Task ConnectAsync(CancellationToken cancellationToken) @@ -100,7 +100,8 @@ public Task DisconnectAsync(CancellationToken cancellationToken) public Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - Array.Copy(_buffer.Array, _position, buffer, offset, count); + _buffer.Slice(_position).Span.CopyTo(buffer.AsSpan(offset, count)); + _position += count; return Task.FromResult(count); diff --git a/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs b/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs index b21affed2..95b9b457a 100644 --- a/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs +++ b/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs @@ -12,7 +12,7 @@ public MqttRpcTopicPair CreateRpcTopics(TopicGenerationContext context) { ArgumentNullException.ThrowIfNull(context); - if (context.MethodName.Contains("/") || context.MethodName.Contains("+") || context.MethodName.Contains("#")) + if (context.MethodName.Contains('/') || context.MethodName.Contains('+') || context.MethodName.Contains('#')) { throw new ArgumentException("The method name cannot contain /, + or #."); } diff --git a/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs b/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs index ccece444a..877370767 100644 --- a/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs +++ b/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Protocol; using System; +using System.Buffers; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { public interface IMqttRpcClient : IDisposable - { - Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null); - - Task ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default); + { + Task> ExecuteAsync(string methodName, ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj index b38fb489d..f4353d21a 100644 --- a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj +++ b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj @@ -35,7 +35,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs index c64b93c2d..bebcf7a11 100644 --- a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs +++ b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Exceptions; -using MQTTnet.Formatter; -using MQTTnet.Internal; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { @@ -21,7 +20,7 @@ public sealed class MqttRpcClient : IMqttRpcClient readonly IMqttClient _mqttClient; readonly MqttRpcClientOptions _options; - readonly ConcurrentDictionary> _waitingCalls = new ConcurrentDictionary>(); + readonly ConcurrentDictionary>> _waitingCalls = new(); public MqttRpcClient(IMqttClient mqttClient, MqttRpcClientOptions options) { @@ -43,27 +42,7 @@ public void Dispose() _waitingCalls.Clear(); } - public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) - { - using (var timeoutToken = new CancellationTokenSource(timeout)) - { - try - { - return await ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, timeoutToken.Token).ConfigureAwait(false); - } - catch (OperationCanceledException exception) - { - if (timeoutToken.IsCancellationRequested) - { - throw new MqttCommunicationTimedOutException(exception); - } - - throw; - } - } - } - - public async Task ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + public async Task> ExecuteAsync(string methodName, ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(methodName); @@ -94,7 +73,7 @@ public async Task ExecuteAsync(string methodName, byte[] payload, MqttQu try { - var awaitable = new AsyncTaskCompletionSource(); + var awaitable = new AsyncTaskCompletionSource>(); if (!_waitingCalls.TryAdd(responseTopic, awaitable)) { @@ -106,11 +85,7 @@ public async Task ExecuteAsync(string methodName, byte[] payload, MqttQu await _mqttClient.SubscribeAsync(subscribeOptions, cancellationToken).ConfigureAwait(false); await _mqttClient.PublishAsync(requestMessage, cancellationToken).ConfigureAwait(false); - using (cancellationToken.Register( - () => - { - awaitable.TrySetCanceled(); - })) + using (cancellationToken.Register(awaitable.TrySetCanceled)) { return await awaitable.Task.ConfigureAwait(false); } @@ -129,8 +104,8 @@ Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventAr return CompletedTask.Instance; } - var payloadBuffer = eventArgs.ApplicationMessage.Payload.ToArray(); - awaitable.TrySetResult(payloadBuffer); + var payload = eventArgs.ApplicationMessage.Payload; + awaitable.TrySetResult(payload); // Set this message to handled to that other code can avoid execution etc. eventArgs.IsHandled = true; diff --git a/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs b/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs index a0b0ebfee..d7ec6d926 100644 --- a/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs +++ b/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs @@ -2,23 +2,123 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Internal; +using MQTTnet.Protocol; using System; +using System.Buffers; using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; using System.Text; +using System.Threading; using System.Threading.Tasks; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { public static class MqttRpcClientExtensions { - public static Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + /// + [Obsolete("Use the method ExecuteTimeoutAsync instead.")] + public static async Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + { + var response = await client.ExecuteTimeoutAsync(timeout, methodName, payload, qualityOfServiceLevel, parameters).ConfigureAwait(false); + return response.ToArray(); + } + + /// + [Obsolete("Use the method ExecuteTimeoutAsync instead.")] + public static async Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, ReadOnlyMemory payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + { + var response = await client.ExecuteTimeoutAsync(timeout, methodName, payload, qualityOfServiceLevel, parameters).ConfigureAwait(false); + return response.ToArray(); + } + + /// + public static Task> ExecuteTimeoutAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return MqttTimeoutAsync(timeout, cancellationToken, linkedCancellationToken + => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, linkedCancellationToken)); + } + + /// + public static Task> ExecuteTimeoutAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, Stream payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return MqttTimeoutAsync(timeout, cancellationToken, linkedCancellationToken + => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, linkedCancellationToken)); + } + + /// + public static Task> ExecuteTimeoutAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, ReadOnlyMemory payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return MqttTimeoutAsync(timeout, cancellationToken, linkedCancellationToken + => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, linkedCancellationToken)); + } + + /// + public static Task> ExecuteTimeoutAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return MqttTimeoutAsync(timeout, cancellationToken, linkedCancellationToken + => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, linkedCancellationToken)); + } + + /// + public static Task> ExecuteTimeoutAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, Func payloadFactory, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return MqttTimeoutAsync(timeout, cancellationToken, linkedCancellationToken + => client.ExecuteAsync(methodName, payloadFactory, qualityOfServiceLevel, parameters, linkedCancellationToken)); + } + + /// + private static async Task MqttTimeoutAsync(TimeSpan timeout, CancellationToken cancellationToken, Func> executor) + { + using var timeoutTokenSource = new CancellationTokenSource(timeout); + using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken); + + try + { + return await executor(linkedTokenSource.Token).ConfigureAwait(false); + } + catch (OperationCanceledException exception) when (timeoutTokenSource.IsCancellationRequested) + { + throw new MqttCommunicationTimedOutException(exception); + } + } + + public static Task> ExecuteAsync(this IMqttRpcClient client, string methodName, ReadOnlyMemory payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) { - if (client == null) throw new ArgumentNullException(nameof(client)); + return client.ExecuteAsync(methodName, new ReadOnlySequence(payload), qualityOfServiceLevel, parameters, cancellationToken); + } - var buffer = Encoding.UTF8.GetBytes(payload ?? string.Empty); + public static Task> ExecuteAsync(this IMqttRpcClient client, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return string.IsNullOrEmpty(payload) + ? client.ExecuteAsync(methodName, ReadOnlySequence.Empty, qualityOfServiceLevel, parameters, cancellationToken) + : client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken); - return client.ExecuteAsync(timeout, methodName, buffer, qualityOfServiceLevel, parameters); + async ValueTask WritePayloadAsync(PipeWriter writer) + { + Encoding.UTF8.GetBytes(payload, writer); + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + + public static Task> ExecuteAsync(this IMqttRpcClient client, string methodName, Stream payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payload); + return client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + await payload.CopyToAsync(writer, cancellationToken).ConfigureAwait(false); + } + } + + public static async Task> ExecuteAsync(this IMqttRpcClient client, string methodName, Func payloadFactory, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(client); + await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(payloadFactory, cancellationToken).ConfigureAwait(false); + return await client.ExecuteAsync(methodName, payloadOwner.Payload, qualityOfServiceLevel, parameters, cancellationToken).ConfigureAwait(false); } } } \ No newline at end of file diff --git a/Source/MQTTnet.Server/Events/ClientConnectedEventArgs.cs b/Source/MQTTnet.Server/Events/ClientConnectedEventArgs.cs index 4249eea6f..9055ebf08 100644 --- a/Source/MQTTnet.Server/Events/ClientConnectedEventArgs.cs +++ b/Source/MQTTnet.Server/Events/ClientConnectedEventArgs.cs @@ -24,7 +24,7 @@ public ClientConnectedEventArgs(MqttConnectPacket connectPacket, MqttProtocolVer SessionItems = sessionItems ?? throw new ArgumentNullException(nameof(sessionItems)); } - public byte[] AuthenticationData => _connectPacket.AuthenticationData; + public ReadOnlyMemory AuthenticationData => _connectPacket.AuthenticationData; public string AuthenticationMethod => _connectPacket.AuthenticationMethod; diff --git a/Source/MQTTnet.Server/Events/ValidatingConnectionEventArgs.cs b/Source/MQTTnet.Server/Events/ValidatingConnectionEventArgs.cs index 325973c7e..a76f242ef 100644 --- a/Source/MQTTnet.Server/Events/ValidatingConnectionEventArgs.cs +++ b/Source/MQTTnet.Server/Events/ValidatingConnectionEventArgs.cs @@ -37,7 +37,7 @@ public ValidatingConnectionEventArgs(MqttConnectPacket connectPacket, IMqttChann /// Gets or sets the authentication data. /// MQTT 5.0.0+ feature. /// - public byte[] AuthenticationData => _connectPacket.AuthenticationData; + public ReadOnlyMemory AuthenticationData => _connectPacket.AuthenticationData; /// /// Gets or sets the authentication method. @@ -93,7 +93,7 @@ public ValidatingConnectionEventArgs(MqttConnectPacket connectPacket, IMqttChann /// public uint MaximumPacketSize => _connectPacket.MaximumPacketSize; - public string Password => Encoding.UTF8.GetString(RawPassword ?? EmptyBuffer.Array); + public string Password => Encoding.UTF8.GetString(RawPassword.AsSpan()); public MqttProtocolVersion ProtocolVersion => ChannelAdapter.PacketFormatterAdapter.ProtocolVersion; diff --git a/Source/MQTTnet.Server/Internal/Formatter/MqttPublishPacketFactory.cs b/Source/MQTTnet.Server/Internal/Formatter/MqttPublishPacketFactory.cs index 03c236639..071d4cf9c 100644 --- a/Source/MQTTnet.Server/Internal/Formatter/MqttPublishPacketFactory.cs +++ b/Source/MQTTnet.Server/Internal/Formatter/MqttPublishPacketFactory.cs @@ -4,6 +4,7 @@ using MQTTnet.Exceptions; using MQTTnet.Packets; +using System.Buffers; namespace MQTTnet.Server.Internal.Formatter; @@ -18,16 +19,10 @@ public static MqttPublishPacket Create(MqttConnectPacket connectPacket) throw new MqttProtocolViolationException("The CONNECT packet contains no will message (WillFlag)."); } - ArraySegment willMessageBuffer = default; - if (connectPacket.WillMessage?.Length > 0) - { - willMessageBuffer = new ArraySegment(connectPacket.WillMessage); - } - var packet = new MqttPublishPacket { Topic = connectPacket.WillTopic, - PayloadSegment = willMessageBuffer, + Payload = new ReadOnlySequence(connectPacket.WillMessage), QualityOfServiceLevel = connectPacket.WillQoS, Retain = connectPacket.WillRetain, ContentType = connectPacket.WillContentType, diff --git a/Source/MQTTnet.Server/MQTTnet.Server.csproj b/Source/MQTTnet.Server/MQTTnet.Server.csproj index df4863607..094fb10f8 100644 --- a/Source/MQTTnet.Server/MQTTnet.Server.csproj +++ b/Source/MQTTnet.Server/MQTTnet.Server.csproj @@ -36,7 +36,7 @@ low enable disable - latest-Recommended + diff --git a/Source/MQTTnet.Server/MqttServerExtensions.cs b/Source/MQTTnet.Server/MqttServerExtensions.cs index c25c89453..1482fa5dd 100644 --- a/Source/MQTTnet.Server/MqttServerExtensions.cs +++ b/Source/MQTTnet.Server/MqttServerExtensions.cs @@ -2,10 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Text; using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; +using System.Buffers; +using System.Collections; +using System.IO.Pipelines; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; namespace MQTTnet.Server; @@ -18,33 +23,177 @@ public static Task DisconnectClientAsync(this MqttServer server, string id, Mqtt return server.DisconnectClientAsync(id, new MqttServerClientDisconnectOptions { ReasonCode = reasonCode }); } + [Obsolete("Use method InjectStringAsync() instead.")] public static Task InjectApplicationMessage( this MqttServer server, string topic, string payload = null, MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, bool retain = false) + { + return server.InjectStringAsync(string.Empty, topic, payload, qualityOfServiceLevel, retain); + } + + public static Task InjectApplicationMessage( + this MqttServer server, + string clientId, + MqttApplicationMessage applicationMessage, + IDictionary customSessionItems = default, + CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(server); + ArgumentNullException.ThrowIfNull(clientId); + ArgumentNullException.ThrowIfNull(applicationMessage); + + var injectedApplicationMessage = new InjectedMqttApplicationMessage(applicationMessage) + { + SenderClientId = clientId, + CustomSessionItems = customSessionItems, + }; + return server.InjectApplicationMessage(injectedApplicationMessage, cancellationToken); + } + + public static Task InjectSequenceAsync( + this MqttServer server, + string clientId, + string topic, + ReadOnlySequence payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { ArgumentNullException.ThrowIfNull(topic); + var applicationMessage = new MqttApplicationMessage + { + Topic = topic, + Payload = payload, + Retain = retain, + QualityOfServiceLevel = qualityOfServiceLevel + }; + + return server.InjectApplicationMessage(clientId, applicationMessage, customSessionItems: null, cancellationToken); + } + + public static async Task InjectSequenceAsync( + this MqttServer server, + string clientId, + string topic, + Func payloadFactory, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payloadFactory); + + await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(payloadFactory, cancellationToken).ConfigureAwait(false); + await server.InjectSequenceAsync(clientId, topic, payloadOwner.Payload, qualityOfServiceLevel, retain, cancellationToken).ConfigureAwait(false); + } + + public static Task InjectBinaryAsync( + this MqttServer server, + string clientId, + string topic, + ReadOnlyMemory payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + return server.InjectSequenceAsync(clientId, topic, new ReadOnlySequence(payload), qualityOfServiceLevel, retain, cancellationToken); + } + + public static async Task InjectBinaryAsync( + this MqttServer server, + string clientId, + string topic, + int payloadSize, + Action> payloadFactory, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + await using var payloadOwner = MqttPayloadOwnerFactory.CreateSingleSegment(payloadSize, payloadFactory); + await server.InjectSequenceAsync(clientId, topic, payloadOwner.Payload, qualityOfServiceLevel, retain, cancellationToken).ConfigureAwait(false); + } + + public static Task InjectStringAsync( + this MqttServer server, + string clientId, + string topic, + string payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + return string.IsNullOrEmpty(payload) + ? server.InjectSequenceAsync(clientId, topic, ReadOnlySequence.Empty, qualityOfServiceLevel, retain, cancellationToken) + : server.InjectSequenceAsync(clientId, topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); - var payloadBuffer = EmptyBuffer.Array; - if (payload is string stringPayload) + async ValueTask WritePayloadAsync(PipeWriter writer) { - payloadBuffer = Encoding.UTF8.GetBytes(stringPayload); + Encoding.UTF8.GetBytes(payload, writer); + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); } + } + + public static Task InjectJsonAsync( + this MqttServer server, + string clientId, + string topic, + TValue payload, + JsonSerializerOptions jsonSerializerOptions = default, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + return server.InjectSequenceAsync(clientId, topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); - return server.InjectApplicationMessage( - new InjectedMqttApplicationMessage( - new MqttApplicationMessage - { - Topic = topic, - PayloadSegment = new ArraySegment(payloadBuffer), - QualityOfServiceLevel = qualityOfServiceLevel, - Retain = retain - })); + async ValueTask WritePayloadAsync(PipeWriter writer) + { + var stream = writer.AsStream(leaveOpen: true); + await JsonSerializer.SerializeAsync(stream, payload, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + } } + + public static Task InjectJsonAsync( + this MqttServer server, + string clientId, + string topic, + TValue payload, + JsonTypeInfo jsonTypeInfo, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(jsonTypeInfo); + return server.InjectSequenceAsync(clientId, topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + var stream = writer.AsStream(leaveOpen: true); + await JsonSerializer.SerializeAsync(stream, payload, jsonTypeInfo, cancellationToken).ConfigureAwait(false); + } + } + + public static Task InjectStreamAsync( + this MqttServer server, + string clientId, + string topic, + Stream payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payload); + return server.InjectSequenceAsync(clientId, topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + await payload.CopyToAsync(writer, cancellationToken).ConfigureAwait(false); + } + } + + public static Task StopAsync(this MqttServer server) { ArgumentNullException.ThrowIfNull(server); diff --git a/Source/MQTTnet.TestApp/PerformanceTest.cs b/Source/MQTTnet.TestApp/PerformanceTest.cs index 644803243..337985e8f 100644 --- a/Source/MQTTnet.TestApp/PerformanceTest.cs +++ b/Source/MQTTnet.TestApp/PerformanceTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; using System.Linq; using System.Net; @@ -199,7 +200,7 @@ static MqttApplicationMessage CreateMessage() return new MqttApplicationMessage { Topic = "A/B/C", - PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes(Payload)), + Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes(Payload)), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }; } @@ -232,7 +233,7 @@ public static async Task RunQoS2Test() var message = new MqttApplicationMessage { Topic = "A/B/C", - PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes("Hello World")), + Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes("Hello World")), QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce }; @@ -283,7 +284,7 @@ public static async Task RunQoS1Test() var message = new MqttApplicationMessage { Topic = "A/B/C", - PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes("Hello World")), + Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes("Hello World")), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }; @@ -334,7 +335,7 @@ public static async Task RunQoS0Test() var message = new MqttApplicationMessage { Topic = "A/B/C", - PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes("Hello World")), + Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes("Hello World")), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }; diff --git a/Source/MQTTnet.TestApp/ServerTest.cs b/Source/MQTTnet.TestApp/ServerTest.cs index 241abcca3..4279126e6 100644 --- a/Source/MQTTnet.TestApp/ServerTest.cs +++ b/Source/MQTTnet.TestApp/ServerTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.IO; using System.Text; @@ -102,7 +103,7 @@ public static async Task RunAsync() { // Replace the payload with the timestamp. But also extending a JSON // based payload with the timestamp is a suitable use case. - e.ApplicationMessage.PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes(DateTime.Now.ToString("O"))); + e.ApplicationMessage.Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes(DateTime.Now.ToString("O"))); } if (e.ApplicationMessage.Topic == "not_allowed_topic") diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index bfd0f8431..d9400be1d 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -100,7 +100,7 @@ public async Task TestLargePacket() connection.Transport = pipe; var ctx = new MqttConnectionContext(serializer, connection); - await ctx.SendPacketAsync(new MqttPublishPacket { PayloadSegment = new byte[20_000] }, CancellationToken.None).ConfigureAwait(false); + await ctx.SendPacketAsync(new MqttPublishPacket { Payload = new ReadOnlySequence(new byte[20_000]) }, CancellationToken.None).ConfigureAwait(false); var readResult = await pipe.Send.Reader.ReadAsync(); Assert.IsTrue(readResult.Buffer.Length > 20000); diff --git a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs b/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs index 6c9cac8f8..9c54173b3 100644 --- a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs +++ b/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs @@ -18,9 +18,9 @@ public void TestTryDeserialize() { var serializer = new MqttPacketFormatterAdapter(MqttProtocolVersion.V311, new MqttBufferWriter(4096, 65535)); - var buffer = serializer.Encode(new MqttPublishPacket { Topic = "a", PayloadSegment = new byte[5] }).Join(); + var buffer = serializer.Encode(new MqttPublishPacket { Topic = "a", Payload = new ReadOnlySequence(new byte[5]) }).Join(); - var sequence = new ReadOnlySequence(buffer.Array, buffer.Offset, buffer.Count); + var sequence = new ReadOnlySequence(buffer); var part = sequence; var consumed = part.Start; diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs index b8b1ba9e2..ae3aff3ed 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs @@ -686,7 +686,7 @@ public async Task Send_Reply_For_Any_Received_Message() async Task Handler1(MqttApplicationMessageReceivedEventArgs eventArgs) { - await client1.PublishStringAsync($"reply/{eventArgs.ApplicationMessage.Topic}"); + await client1.PublishStringAsync($"reply/{eventArgs.ApplicationMessage.Topic}", default); } client1.ApplicationMessageReceivedAsync += Handler1; @@ -710,9 +710,9 @@ Task Handler2(MqttApplicationMessageReceivedEventArgs eventArgs) await Task.Delay(500); - await client2.PublishStringAsync("request/a"); - await client2.PublishStringAsync("request/b"); - await client2.PublishStringAsync("request/c"); + await client2.PublishStringAsync("request/a",default); + await client2.PublishStringAsync("request/b", default); + await client2.PublishStringAsync("request/c", default); await Task.Delay(500); @@ -751,7 +751,7 @@ public async Task Send_Reply_In_Message_Handler() // Use AtMostOnce here because with QoS 1 or even QoS 2 the process waits for // the ACK etc. The problem is that the SpinUntil below only waits until the // flag is set. It does not wait until the client has sent the ACK - await client2.PublishStringAsync("reply"); + await client2.PublishStringAsync("reply", default); } }; diff --git a/Source/MQTTnet.Tests/Clients/MqttClientOptionsBuilder_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClientOptionsBuilder_Tests.cs index 65e03f7d1..6e131a558 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClientOptionsBuilder_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClientOptionsBuilder_Tests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Linq; using System.Text; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -19,7 +20,7 @@ public void WithConnectionUri_Credential_Test() .Build(); Assert.AreEqual("user", options.Credentials.GetUserName(null)); - Assert.IsTrue(Encoding.UTF8.GetBytes("password").SequenceEqual(options.Credentials.GetPassword(null))); + Assert.IsTrue(Encoding.UTF8.GetBytes("password").AsSpan().SequenceEqual(options.Credentials.GetPassword(null))); } } } diff --git a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs index 7db513f58..0d5eb16ec 100644 --- a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs +++ b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -34,7 +35,7 @@ public async Task Inspect_Client_Packets() mqttClient.InspectPacketAsync += eventArgs => { - packets.Add(eventArgs.Direction + ":" + Convert.ToBase64String(eventArgs.Buffer)); + packets.Add(eventArgs.Direction + ":" + Convert.ToBase64String(eventArgs.Buffer.ToArray())); return CompletedTask.Instance; }; diff --git a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs index 02dddbfa8..17dc99173 100644 --- a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs @@ -22,27 +22,23 @@ public sealed class Rpc_Tests : BaseTestClass [TestMethod] public async Task Execute_Success_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = CreateTestEnvironment()) - { - await testEnvironment.StartServer(); - var responseSender = await testEnvironment.ConnectClient(); - await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); + using var testEnvironment = CreateTestEnvironment(); + await testEnvironment.StartServer(); + var responseSender = await testEnvironment.ConnectClient(); + await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); - responseSender.ApplicationMessageReceivedAsync += async e => - { - Assert.IsNull(e.ApplicationMessage.ResponseTopic); - await responseSender.PublishStringAsync(e.ApplicationMessage.Topic + "/response", "pong"); - }; + responseSender.ApplicationMessageReceivedAsync += async e => + { + Assert.IsNull(e.ApplicationMessage.ResponseTopic); + await responseSender.PublishStringAsync(e.ApplicationMessage.Topic + "/response", "pong"); + }; - var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); + var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); - using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) - { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); + using var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); + var response = await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(5), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); - Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); - } - } + Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); } [TestMethod] @@ -63,13 +59,11 @@ public async Task Execute_Success_Parameters_Propagated_Correctly() responseSender.ApplicationMessageReceivedAsync += e => responseSender.PublishStringAsync(e.ApplicationMessage.Topic + "/response", "pong"); - using (var rpcClient = await testEnvironment.ConnectRpcClient(new MqttRpcClientOptionsBuilder() - .WithTopicGenerationStrategy(new TestParametersTopicGenerationStrategy()).Build())) - { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", MqttQualityOfServiceLevel.AtMostOnce, parameters); + using var rpcClient = await testEnvironment.ConnectRpcClient(new MqttRpcClientOptionsBuilder() + .WithTopicGenerationStrategy(new TestParametersTopicGenerationStrategy()).Build()); + var response = await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(5), "ping", "", MqttQualityOfServiceLevel.AtMostOnce, parameters); - Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); - } + Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); } } @@ -132,111 +126,93 @@ public Task Execute_Success_With_QoS_2_MQTT_V5_Use_ResponseTopic() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout() { - using (var testEnvironment = CreateTestEnvironment()) - { - await testEnvironment.StartServer(); + using var testEnvironment = CreateTestEnvironment(); + await testEnvironment.StartServer(); - var requestSender = await testEnvironment.ConnectClient(); + var requestSender = await testEnvironment.ConnectClient(); - var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); - await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); - } + var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); + await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } [TestMethod] [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = new TestEnvironment(TestContext)) - { - await testEnvironment.StartServer(); - var responseSender = await testEnvironment.ConnectClient(); - await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); + using var testEnvironment = new TestEnvironment(TestContext); + await testEnvironment.StartServer(); + var responseSender = await testEnvironment.ConnectClient(); + await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping"); - responseSender.ApplicationMessageReceivedAsync += async e => - { - Assert.IsNull(e.ApplicationMessage.ResponseTopic); - await CompletedTask.Instance; - }; + responseSender.ApplicationMessageReceivedAsync += async e => + { + Assert.IsNull(e.ApplicationMessage.ResponseTopic); + await CompletedTask.Instance; + }; - var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); + var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); - using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) - { - await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); - } - } + using var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); + await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } [TestMethod] [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_With_Custom_Topic_Names() { - using (var testEnvironment = CreateTestEnvironment()) - { - await testEnvironment.StartServer(); + using var testEnvironment = CreateTestEnvironment(); + await testEnvironment.StartServer(); - var rpcClient = await testEnvironment.ConnectRpcClient(new MqttRpcClientOptionsBuilder().WithTopicGenerationStrategy(new TestTopicStrategy()).Build()); + var rpcClient = await testEnvironment.ConnectRpcClient(new MqttRpcClientOptionsBuilder().WithTopicGenerationStrategy(new TestTopicStrategy()).Build()); - await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); - } + await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(2), "ping", "", MqttQualityOfServiceLevel.AtMostOnce); } [TestMethod] public void Use_Factory() { var factory = new MqttClientFactory(); - using (var client = factory.CreateMqttClient()) - { - var rpcClient = factory.CreateMqttRpcClient(client); + using var client = factory.CreateMqttClient(); + var rpcClient = factory.CreateMqttRpcClient(client); - Assert.IsNotNull(rpcClient); - } + Assert.IsNotNull(rpcClient); } async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersion protocolVersion) { - using (var testEnvironment = CreateTestEnvironment()) - { - await testEnvironment.StartServer(); - var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); - await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping", qosLevel); + using var testEnvironment = CreateTestEnvironment(); + await testEnvironment.StartServer(); + var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); + await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping", qosLevel); - responseSender.ApplicationMessageReceivedAsync += e => responseSender.PublishStringAsync(e.ApplicationMessage.Topic + "/response", "pong"); + responseSender.ApplicationMessageReceivedAsync += e => responseSender.PublishStringAsync(e.ApplicationMessage.Topic + "/response", "pong"); - var requestSender = await testEnvironment.ConnectClient(); + var requestSender = await testEnvironment.ConnectClient(); - using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) - { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", qosLevel); + using var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); + var response = await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(5), "ping", "", qosLevel); - Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); - } - } + Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); } async Task Execute_Success_MQTT_V5(MqttQualityOfServiceLevel qosLevel) { - using (var testEnvironment = CreateTestEnvironment()) - { - await testEnvironment.StartServer(); - var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); - await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping", qosLevel); + using var testEnvironment = CreateTestEnvironment(); + await testEnvironment.StartServer(); + var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); + await responseSender.SubscribeAsync("MQTTnet.RPC/+/ping", qosLevel); - responseSender.ApplicationMessageReceivedAsync += async e => - { - await responseSender.PublishStringAsync(e.ApplicationMessage.ResponseTopic, "pong"); - }; + responseSender.ApplicationMessageReceivedAsync += async e => + { + await responseSender.PublishStringAsync(e.ApplicationMessage.ResponseTopic, "pong"); + }; - var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); + var requestSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); - using (var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build())) - { - var response = await rpcClient.ExecuteAsync(TimeSpan.FromSeconds(5), "ping", "", qosLevel); + using var rpcClient = new MqttRpcClient(requestSender, new MqttRpcClientOptionsBuilder().Build()); + var response = await rpcClient.ExecuteTimeoutAsync(TimeSpan.FromSeconds(5), "ping", "", qosLevel); - Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); - } - } + Assert.AreEqual("pong", Encoding.UTF8.GetString(response)); } class TestTopicStrategy : IMqttRpcClientTopicGenerationStrategy diff --git a/Source/MQTTnet.Tests/Formatter/MqttBufferReader_Tests.cs b/Source/MQTTnet.Tests/Formatter/MqttBufferReader_Tests.cs index 84fb48e04..2c8b6665e 100644 --- a/Source/MQTTnet.Tests/Formatter/MqttBufferReader_Tests.cs +++ b/Source/MQTTnet.Tests/Formatter/MqttBufferReader_Tests.cs @@ -2,12 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.IO; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Exceptions; using MQTTnet.Formatter; +using System; +using System.Collections.Generic; +using System.IO; namespace MQTTnet.Tests.Formatter { @@ -22,7 +22,7 @@ public void Fire_Exception_If_Not_Enough_Data() var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, 3); + reader.SetBuffer(buffer.AsMemory(0, 3)); // 1 byte is missing. reader.ReadFourByteInteger(); @@ -36,7 +36,7 @@ public void Fire_Exception_If_Not_Enough_Data_With_Longer_Buffer() var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, 3); + reader.SetBuffer(buffer.AsMemory(0, 3)); // 1 byte is missing. reader.ReadFourByteInteger(); @@ -59,7 +59,7 @@ public void Read_Remaining_Data_From_Larger_Buffer() var reader = new MqttBufferReader(); // The used buffer contains more data than used! - reader.SetBuffer(buffer, 0, 5); + reader.SetBuffer(buffer.AsMemory(0, 5)); // This should only read 5 bytes even if more data is in the buffer // due to custom bounds. @@ -127,29 +127,28 @@ public void Read_Various_Positions_and_Offsets() elementBytes = BitConverter.GetBytes(uintValue); break; case ElementReference.BufferElementType.VariableSizeInt: - { - elementNumberValue = (uint)i; - var writer = new MqttBufferWriter(4, 4); - writer.WriteVariableByteInteger(elementNumberValue); - elementSize = writer.Length; - elementBytes = new byte[elementSize]; - var buffer = writer.GetBuffer(); - Array.Copy(buffer, elementBytes, elementSize); - alreadyBigEndian = true; // nothing to swap - } + { + elementNumberValue = (uint)i; + var writer = new MqttBufferWriter(4, 4); + writer.WriteVariableByteInteger(elementNumberValue); + elementSize = writer.Length; + elementBytes = new byte[elementSize]; + writer.GetWrittenSpan().CopyTo(elementBytes); + + alreadyBigEndian = true; // nothing to swap + } break; case ElementReference.BufferElementType.String: - { - var stringLen = rnd.Next(TestString.Length); - elementStringValue = TestString.Substring(0, stringLen); // could be empty - var writer = new MqttBufferWriter(stringLen + 1, stringLen + 1); - writer.WriteString(elementStringValue); - elementSize = writer.Length; - elementBytes = new byte[elementSize]; - var buffer = writer.GetBuffer(); - Array.Copy(buffer, elementBytes, elementSize); - alreadyBigEndian = true; // nothing to swap - } + { + var stringLen = rnd.Next(TestString.Length); + elementStringValue = TestString.Substring(0, stringLen); // could be empty + var writer = new MqttBufferWriter(stringLen + 1, stringLen + 1); + writer.WriteString(elementStringValue); + elementSize = writer.Length; + elementBytes = new byte[elementSize]; + writer.GetWrittenSpan().CopyTo(elementBytes); + alreadyBigEndian = true; // nothing to swap + } break; } @@ -180,7 +179,7 @@ public void Read_Various_Positions_and_Offsets() var segmentLength = segmentEndPosition - segmentStartPosition; var reader = new MqttBufferReader(); - reader.SetBuffer(elementBuffer, segmentStartPosition, segmentLength); + reader.SetBuffer(elementBuffer.AsMemory(segmentStartPosition, segmentLength)); // read all elements in the buffer segment; values should be as expected for (var n = 0; n < elementCount; n++) @@ -195,29 +194,29 @@ public void Read_Various_Positions_and_Offsets() switch (element.Type) { case ElementReference.BufferElementType.Byte: - { - elementNumberValue = reader.ReadByte(); - } + { + elementNumberValue = reader.ReadByte(); + } break; case ElementReference.BufferElementType.TwoByteInt: - { - elementNumberValue = reader.ReadTwoByteInteger(); - } + { + elementNumberValue = reader.ReadTwoByteInteger(); + } break; case ElementReference.BufferElementType.FourByteInt: - { - elementNumberValue = reader.ReadFourByteInteger(); - } + { + elementNumberValue = reader.ReadFourByteInteger(); + } break; case ElementReference.BufferElementType.VariableSizeInt: - { - elementNumberValue = reader.ReadVariableByteInteger(); - } + { + elementNumberValue = reader.ReadVariableByteInteger(); + } break; case ElementReference.BufferElementType.String: - { - elementStringValue = reader.ReadString(); - } + { + elementStringValue = reader.ReadString(); + } break; } @@ -243,7 +242,7 @@ public void Report_Correct_Length_For_Full_Buffer() var buffer = new byte[] { 5, 6, 7, 8, 9 }; var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, 5); + reader.SetBuffer(buffer.AsMemory(0, 5)); Assert.IsFalse(reader.EndOfStream); Assert.AreEqual(5, reader.BytesLeft); @@ -257,7 +256,7 @@ public void Report_Correct_Length_For_Partial_End_Buffer() var reader = new MqttBufferReader(); // The used buffer contains more data than used! - reader.SetBuffer(buffer, 5, 5); + reader.SetBuffer(buffer.AsMemory(5, 5)); Assert.IsFalse(reader.EndOfStream); Assert.AreEqual(5, reader.BytesLeft); @@ -271,7 +270,7 @@ public void Report_Correct_Length_For_Partial_Start_Buffer() var reader = new MqttBufferReader(); // The used buffer contains more data than used! - reader.SetBuffer(buffer, 0, 5); + reader.SetBuffer(buffer.AsMemory(0, 5)); Assert.IsFalse(reader.EndOfStream); Assert.AreEqual(5, reader.BytesLeft); diff --git a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Binary_Tests.cs b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Binary_Tests.cs index 8343e2f5a..230411c40 100644 --- a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Binary_Tests.cs +++ b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Binary_Tests.cs @@ -111,7 +111,7 @@ public void DeserializeV311_MqttPublishPacket() PacketIdentifier = 123, Dup = true, Retain = true, - PayloadSegment = new ArraySegment(Encoding.ASCII.GetBytes("HELLO")), + Payload = new ReadOnlySequence(Encoding.ASCII.GetBytes("HELLO")), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, Topic = "A/B/C" }; @@ -318,7 +318,7 @@ public void Serialize_LargePacket() var publishPacket = new MqttPublishPacket { Topic = "abcdefghijklmnopqrstuvwxyz0123456789", - PayloadSegment = new ArraySegment(payload) + Payload = new ReadOnlySequence(payload) }; var serializationHelper = new MqttPacketSerializationHelper(); @@ -331,7 +331,7 @@ public void Serialize_LargePacket() CollectionAssert.AreEqual(publishPacket.Payload.ToArray(), publishPacketCopy.Payload.ToArray()); // Now modify the payload and test again. - publishPacket.PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes("MQTT")); + publishPacket.Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes("MQTT")); buffer = serializationHelper.Encode(publishPacket); var publishPacketCopy2 = serializationHelper.Decode(buffer) as MqttPublishPacket; @@ -462,7 +462,7 @@ public void SerializeV311_MqttPublishPacket() PacketIdentifier = 123, Dup = true, Retain = true, - PayloadSegment = new ArraySegment(Encoding.ASCII.GetBytes("HELLO")), + Payload = new ReadOnlySequence(Encoding.ASCII.GetBytes("HELLO")), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, Topic = "A/B/C" }; diff --git a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Tests.cs b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Tests.cs index 92ee5aa6a..70f74b6ac 100644 --- a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Tests.cs +++ b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V3_Tests.cs @@ -57,7 +57,7 @@ public void Serialize_Full_MqttConnAckPacket_V311() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(connAckPacket, MqttProtocolVersion.V311); - CollectionAssert.AreEqual(null, deserialized.AuthenticationData); // Not supported in v3.1.1 + Assert.AreEqual(true, deserialized.AuthenticationData.IsEmpty); // Not supported in v3.1.1 Assert.AreEqual(null, deserialized.AuthenticationMethod); // Not supported in v3.1.1 //Assert.AreEqual(connAckPacket.ReasonCode, deserialized.ReasonCode); Assert.AreEqual(null, deserialized.ReasonString); // Not supported in v3.1.1 @@ -78,7 +78,7 @@ public void Serialize_Full_MqttConnAckPacket_V311() Assert.AreEqual(false, deserialized.WildcardSubscriptionAvailable); Assert.IsNull(deserialized.UserProperties); // Not supported in v3.1.1 } - + [TestMethod] public void Serialize_Full_MqttConnAckPacket_V310() { @@ -111,7 +111,7 @@ public void Serialize_Full_MqttConnAckPacket_V310() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(connAckPacket, MqttProtocolVersion.V310); - CollectionAssert.AreEqual(null, deserialized.AuthenticationData); // Not supported in v3.1.1 + Assert.AreEqual(true, deserialized.AuthenticationData.IsEmpty); // Not supported in v3.1.1 Assert.AreEqual(null, deserialized.AuthenticationMethod); // Not supported in v3.1.1 //Assert.AreEqual(connAckPacket.ReasonCode, deserialized.ReasonCode); Assert.AreEqual(null, deserialized.ReasonString); // Not supported in v3.1.1 @@ -175,15 +175,16 @@ public void Serialize_Full_MqttConnectPacket_V311() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(connectPacket, MqttProtocolVersion.V311); Assert.AreEqual(connectPacket.Username, deserialized.Username); - CollectionAssert.AreEqual(connectPacket.Password, deserialized.Password); + + Assert.IsTrue(connectPacket.Password.SequenceEqual(deserialized.Password)); Assert.AreEqual(connectPacket.ClientId, deserialized.ClientId); - CollectionAssert.AreEqual(null, deserialized.AuthenticationData); // Not supported in v3.1.1 + Assert.AreEqual(true, deserialized.AuthenticationData.IsEmpty); // Not supported in v3.1.1 Assert.AreEqual(null, deserialized.AuthenticationMethod); // Not supported in v3.1.1 - Assert.AreEqual(connectPacket.CleanSession, deserialized.CleanSession); + Assert.AreEqual(connectPacket.CleanSession, deserialized.CleanSession); Assert.AreEqual(0L, deserialized.ReceiveMaximum); // Not supported in v3.1.1 Assert.AreEqual(connectPacket.WillFlag, deserialized.WillFlag); Assert.AreEqual(connectPacket.WillTopic, deserialized.WillTopic); - CollectionAssert.AreEqual(connectPacket.WillMessage, deserialized.WillMessage); + Assert.IsTrue(connectPacket.WillMessage.Span.SequenceEqual(deserialized.WillMessage.Span)); Assert.AreEqual(connectPacket.WillRetain, deserialized.WillRetain); Assert.AreEqual(connectPacket.KeepAlivePeriod, deserialized.KeepAlivePeriod); // MaximumPacketSize not available in MQTTv3. @@ -299,7 +300,7 @@ public void Serialize_Full_MqttPublishPacket_V311() PacketIdentifier = 123, Dup = true, Retain = true, - PayloadSegment = new ArraySegment(Encoding.ASCII.GetBytes("Payload")), + Payload = new ReadOnlySequence(Encoding.ASCII.GetBytes("Payload")), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, Topic = "Topic", ResponseTopic = "/Response", @@ -328,7 +329,7 @@ public void Serialize_Full_MqttPublishPacket_V311() Assert.AreEqual(publishPacket.Topic, deserialized.Topic); Assert.AreEqual(null, deserialized.ResponseTopic); // Not supported in v3.1.1. Assert.AreEqual(null, deserialized.ContentType); // Not supported in v3.1.1. - CollectionAssert.AreEqual(null, deserialized.CorrelationData); // Not supported in v3.1.1. + Assert.AreEqual(true, deserialized.CorrelationData.IsEmpty); // Not supported in v3.1.1. Assert.AreEqual(0U, deserialized.TopicAlias); // Not supported in v3.1.1. CollectionAssert.AreEqual(null, deserialized.SubscriptionIdentifiers); // Not supported in v3.1.1 Assert.AreEqual(0U, deserialized.MessageExpiryInterval); // Not supported in v3.1.1 @@ -400,7 +401,7 @@ public void Serialize_Full_MqttSubAckPacket_V311() }; var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(subAckPacket, MqttProtocolVersion.V311); - + Assert.AreEqual(subAckPacket.PacketIdentifier, deserialized.PacketIdentifier); Assert.AreEqual(null, deserialized.ReasonString); // Not supported in v3.1.1 Assert.AreEqual(subAckPacket.ReasonCodes.Count, deserialized.ReasonCodes.Count); diff --git a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V5_Tests.cs b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V5_Tests.cs index f74ba975a..342a655fb 100644 --- a/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V5_Tests.cs +++ b/Source/MQTTnet.Tests/Formatter/MqttPacketSerialization_V5_Tests.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; +using System.Linq; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -38,7 +39,7 @@ public void Serialize_Full_MqttAuthPacket_V500() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(authPacket, MqttProtocolVersion.V500); - CollectionAssert.AreEqual(authPacket.AuthenticationData, deserialized.AuthenticationData); + Assert.IsTrue(authPacket.AuthenticationData.Span.SequenceEqual(deserialized.AuthenticationData.Span)); Assert.AreEqual(authPacket.AuthenticationMethod, deserialized.AuthenticationMethod); Assert.AreEqual(authPacket.ReasonCode, deserialized.ReasonCode); Assert.AreEqual(authPacket.ReasonString, deserialized.ReasonString); @@ -74,7 +75,7 @@ public void Serialize_Full_MqttConnAckPacket_V500() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(connAckPacket, MqttProtocolVersion.V500); - CollectionAssert.AreEqual(connAckPacket.AuthenticationData, deserialized.AuthenticationData); + Assert.IsTrue(connAckPacket.AuthenticationData.Span.SequenceEqual(deserialized.AuthenticationData.Span)); Assert.AreEqual(connAckPacket.AuthenticationMethod, deserialized.AuthenticationMethod); Assert.AreEqual(connAckPacket.ReasonCode, deserialized.ReasonCode); Assert.AreEqual(connAckPacket.ReasonString, deserialized.ReasonString); @@ -132,15 +133,15 @@ public void Serialize_Full_MqttConnectPacket_V500() var deserialized = MqttPacketSerializationHelper.EncodeAndDecodePacket(connectPacket, MqttProtocolVersion.V500); Assert.AreEqual(connectPacket.Username, deserialized.Username); - CollectionAssert.AreEqual(connectPacket.Password, deserialized.Password); + Assert.IsTrue(connectPacket.Password.SequenceEqual(deserialized.Password)); Assert.AreEqual(connectPacket.ClientId, deserialized.ClientId); - CollectionAssert.AreEqual(connectPacket.AuthenticationData, deserialized.AuthenticationData); + Assert.IsTrue(connectPacket.AuthenticationData.Span.SequenceEqual(deserialized.AuthenticationData.Span)); Assert.AreEqual(connectPacket.AuthenticationMethod, deserialized.AuthenticationMethod); Assert.AreEqual(connectPacket.CleanSession, deserialized.CleanSession); Assert.AreEqual(connectPacket.ReceiveMaximum, deserialized.ReceiveMaximum); Assert.AreEqual(connectPacket.WillFlag, deserialized.WillFlag); Assert.AreEqual(connectPacket.WillTopic, deserialized.WillTopic); - CollectionAssert.AreEqual(connectPacket.WillMessage, deserialized.WillMessage); + Assert.IsTrue(connectPacket.WillMessage.Span.SequenceEqual(deserialized.WillMessage.Span)); Assert.AreEqual(connectPacket.WillRetain, deserialized.WillRetain); Assert.AreEqual(connectPacket.KeepAlivePeriod, deserialized.KeepAlivePeriod); Assert.AreEqual(connectPacket.MaximumPacketSize, deserialized.MaximumPacketSize); @@ -149,7 +150,7 @@ public void Serialize_Full_MqttConnectPacket_V500() Assert.AreEqual(connectPacket.SessionExpiryInterval, deserialized.SessionExpiryInterval); Assert.AreEqual(connectPacket.TopicAliasMaximum, deserialized.TopicAliasMaximum); Assert.AreEqual(connectPacket.WillContentType, deserialized.WillContentType); - CollectionAssert.AreEqual(connectPacket.WillCorrelationData, deserialized.WillCorrelationData); + Assert.IsTrue(connectPacket.WillCorrelationData.Span.SequenceEqual(deserialized.WillCorrelationData.Span)); Assert.AreEqual(connectPacket.WillDelayInterval, deserialized.WillDelayInterval); Assert.AreEqual(connectPacket.WillQoS, deserialized.WillQoS); Assert.AreEqual(connectPacket.WillResponseTopic, deserialized.WillResponseTopic); @@ -246,7 +247,7 @@ public void Serialize_Full_MqttPublishPacket_V500() PacketIdentifier = 123, Dup = true, Retain = true, - PayloadSegment = new ArraySegment("Payload"u8.ToArray()), + Payload = new ReadOnlySequence("Payload"u8.ToArray()), QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce, Topic = "Topic", ResponseTopic = "/Response", @@ -269,7 +270,7 @@ public void Serialize_Full_MqttPublishPacket_V500() Assert.AreEqual(publishPacket.Topic, deserialized.Topic); Assert.AreEqual(publishPacket.ResponseTopic, deserialized.ResponseTopic); Assert.AreEqual(publishPacket.ContentType, deserialized.ContentType); - CollectionAssert.AreEqual(publishPacket.CorrelationData, deserialized.CorrelationData); + Assert.IsTrue(publishPacket.CorrelationData.Span.SequenceEqual(deserialized.CorrelationData.Span)); Assert.AreEqual(publishPacket.TopicAlias, deserialized.TopicAlias); CollectionAssert.AreEqual(publishPacket.SubscriptionIdentifiers, deserialized.SubscriptionIdentifiers); Assert.AreEqual(publishPacket.MessageExpiryInterval, deserialized.MessageExpiryInterval); diff --git a/Source/MQTTnet.Tests/Formatter/MqttPacketWriter_Tests.cs b/Source/MQTTnet.Tests/Formatter/MqttPacketWriter_Tests.cs index 7a16ff30a..68d673f5c 100644 --- a/Source/MQTTnet.Tests/Formatter/MqttPacketWriter_Tests.cs +++ b/Source/MQTTnet.Tests/Formatter/MqttPacketWriter_Tests.cs @@ -36,10 +36,10 @@ public void Use_All_Data_Types() writer.WriteVariableByteInteger(1234U); writer.WriteVariableByteInteger(9876U); - var buffer = writer.GetBuffer(); + var buffer = writer.GetWrittenMemory(); var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, writer.Length); + reader.SetBuffer(buffer); Assert.AreEqual("AString", reader.ReadString()); Assert.IsTrue(reader.ReadByte() == 1); @@ -48,13 +48,13 @@ public void Use_All_Data_Types() Assert.AreEqual(1234U, reader.ReadVariableByteInteger()); Assert.AreEqual(9876U, reader.ReadVariableByteInteger()); } - + [TestMethod] [ExpectedException(typeof(MqttProtocolViolationException))] public void Throw_If_String_Too_Long() { var writer = new MqttBufferWriter(4096, 65535); - + writer.WriteString(string.Empty.PadLeft(65536)); } @@ -75,14 +75,12 @@ public void Write_And_Read_Multiple_Times() writer.WriteString("fjgffiogfhgfhoihgoireghreghreguhreguireoghreouighreouighreughreguiorehreuiohruiorehreuioghreug"); writer.WriteBinary(new byte[3]); - var readPayload = new ArraySegment(writer.GetBuffer(), 0, writer.Length).ToArray(); - - var reader = new MqttBufferReader(); - reader.SetBuffer(readPayload, 0, readPayload.Length); + var readPayload = writer.GetWrittenMemory(); for (var i = 0; i < 100000; i++) { - reader.Seek(0); + var reader = new MqttBufferReader(); + reader.SetBuffer(readPayload); reader.ReadString(); reader.ReadBinaryData(); diff --git a/Source/MQTTnet.Tests/Helpers/MqttPacketWriterExtensions.cs b/Source/MQTTnet.Tests/Helpers/MqttPacketWriterExtensions.cs index 1c6e1bf11..7153d1aef 100644 --- a/Source/MQTTnet.Tests/Helpers/MqttPacketWriterExtensions.cs +++ b/Source/MQTTnet.Tests/Helpers/MqttPacketWriterExtensions.cs @@ -4,6 +4,7 @@ using MQTTnet.Formatter; using MQTTnet.Protocol; +using System; namespace MQTTnet.Tests.Helpers { @@ -13,8 +14,8 @@ public static byte[] AddMqttHeader(this MqttBufferWriter writer, MqttControlPack { writer.WriteByte(MqttBufferWriter.BuildFixedHeader(header)); writer.WriteVariableByteInteger((uint)body.Length); - writer.WriteBinary(body, 0, body.Length); - return writer.GetBuffer(); + writer.Write(body); + return writer.GetWrittenMemory().ToArray(); } } } diff --git a/Source/MQTTnet.Tests/Internal/MqttPayloadOwnerFactory_Test.cs b/Source/MQTTnet.Tests/Internal/MqttPayloadOwnerFactory_Test.cs new file mode 100644 index 000000000..cff1553a9 --- /dev/null +++ b/Source/MQTTnet.Tests/Internal/MqttPayloadOwnerFactory_Test.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; +using System; +using System.Buffers; +using System.IO; +using System.Linq; +using System.Threading.Tasks; + +namespace MQTTnet.Tests.Internal +{ + [TestClass] + public class MqttPayloadOwnerFactory_Test + { + [TestMethod] + public async Task CreateSingleSegmentTest() + { + var size = 10; + var buffer = new byte[size]; + Random.Shared.NextBytes(buffer); + + await using var owner = MqttPayloadOwnerFactory.CreateSingleSegment(size, payload => + { + buffer.AsSpan().CopyTo(payload.Span); + }); + + Assert.AreEqual(size, owner.Payload.Length); + Assert.IsTrue(buffer.AsSpan().SequenceEqual(owner.Payload.ToArray())); + } + + [TestMethod] + public async Task CreateMultipleSegment_1x4K_Test() + { + await CreateMultipleSegmentTest(x4K: 1); + } + + [TestMethod] + public async Task CreateMultipleSegment_2x4K_Test() + { + await CreateMultipleSegmentTest(x4K: 2); + } + + [TestMethod] + public async Task CreateMultipleSegment_4x4K_Test() + { + await CreateMultipleSegmentTest(x4K: 4); + } + + [TestMethod] + public async Task CreateMultipleSegment_8x4K_Test() + { + await CreateMultipleSegmentTest(x4K: 8); + } + + private async Task CreateMultipleSegmentTest(int x4K) + { + const int size4K = 4096; + var stream = new MemoryStream(); + await using var owner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(async writer => + { + for (var i = 0; i < x4K; i++) + { + var memory = writer.GetMemory(size4K)[..size4K]; + Random.Shared.NextBytes(memory.Span); + + writer.Advance(memory.Length); + await stream.WriteAsync(memory); + } + await writer.FlushAsync(); + }); + + var buffer = stream.ToArray(); + var buffer2 = owner.Payload.ToArray(); + + Assert.AreEqual(size4K * x4K, buffer.Length); + Assert.IsTrue(buffer.SequenceEqual(buffer2)); + } + } +} diff --git a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj index c89d8057b..e0fc8ee21 100644 --- a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj +++ b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj @@ -11,7 +11,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs index 7b7570a99..25fe4a696 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Buffers; using System.Collections.Generic; using System.Linq; @@ -360,7 +361,7 @@ public async Task Publish_And_Receive_New_Properties() Assert.AreEqual(applicationMessage.ContentType, receivedMessage.ContentType); Assert.AreEqual(applicationMessage.ResponseTopic, receivedMessage.ResponseTopic); Assert.AreEqual(applicationMessage.MessageExpiryInterval, receivedMessage.MessageExpiryInterval); - CollectionAssert.AreEqual(applicationMessage.CorrelationData, receivedMessage.CorrelationData); + Assert.IsTrue(applicationMessage.CorrelationData.Span.SequenceEqual(receivedMessage.CorrelationData.Span)); CollectionAssert.AreEqual(applicationMessage.Payload.ToArray(), receivedMessage.Payload.ToArray()); CollectionAssert.AreEqual(applicationMessage.UserProperties, receivedMessage.UserProperties); } diff --git a/Source/MQTTnet.Tests/MqttApplicationMessageExtensions_Test.cs b/Source/MQTTnet.Tests/MqttApplicationMessageExtensions_Test.cs new file mode 100644 index 000000000..3ff019e23 --- /dev/null +++ b/Source/MQTTnet.Tests/MqttApplicationMessageExtensions_Test.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Text.Json; + +namespace MQTTnet.Tests +{ + [TestClass] + public sealed class MqttApplicationMessageExtensions_Test + { + [TestMethod] + public void ConvertPayloadToJson_Test() + { + var input = new TValue { IntValue = 10, StrinValue = nameof(TValue) }; + + var message = new MqttApplicationMessageBuilder() + .WithTopic("Abc") + .WithPayload(JsonSerializer.SerializeToUtf8Bytes(input)) + .Build(); + + var output = message.ConvertPayloadToJson(); + + Assert.AreEqual(input.IntValue, output.IntValue); + Assert.AreEqual(input.StrinValue, output.StrinValue); + } + + private class TValue + { + public int IntValue { get; set; } + + public string StrinValue { get; set; } + } + } +} diff --git a/Source/MQTTnet.Tests/Protocol_Tests.cs b/Source/MQTTnet.Tests/Protocol_Tests.cs index 6f5146c8d..988c306fd 100644 --- a/Source/MQTTnet.Tests/Protocol_Tests.cs +++ b/Source/MQTTnet.Tests/Protocol_Tests.cs @@ -4,6 +4,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Formatter; +using System; namespace MQTTnet.Tests { @@ -19,10 +20,10 @@ public void Encode_Four_Byte_Integer() { writer.WriteVariableByteInteger(value); - var buffer = writer.GetBuffer(); + var buffer = writer.GetWrittenMemory(); var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, writer.Length); + reader.SetBuffer(buffer); var checkValue = reader.ReadVariableByteInteger(); Assert.AreEqual(value, checkValue); @@ -40,10 +41,10 @@ public void Encode_Two_Byte_Integer() { writer.WriteTwoByteInteger(value); - var buffer = writer.GetBuffer(); + var buffer = writer.GetWrittenMemory(); var reader = new MqttBufferReader(); - reader.SetBuffer(buffer, 0, writer.Length); + reader.SetBuffer(buffer); var checkValue = reader.ReadTwoByteInteger(); Assert.AreEqual(value, checkValue); diff --git a/Source/MQTTnet.Tests/Server/General.cs b/Source/MQTTnet.Tests/Server/General.cs index 45cff1983..da5a11d16 100644 --- a/Source/MQTTnet.Tests/Server/General.cs +++ b/Source/MQTTnet.Tests/Server/General.cs @@ -225,7 +225,7 @@ public async Task Handle_Lots_Of_Parallel_Retained_Messages() // Clear retained message. await client.PublishAsync( new MqttApplicationMessageBuilder().WithTopic("r" + i2) - .WithPayload(EmptyBuffer.Array) + .WithPayload(ReadOnlyMemory.Empty) .WithRetainFlag() .WithQualityOfServiceLevel(MqttQualityOfServiceLevel.AtLeastOnce) .Build()); @@ -304,7 +304,7 @@ public async Task Intercept_Message() var server = await testEnvironment.StartServer(); server.InterceptingPublishAsync += e => { - e.ApplicationMessage.PayloadSegment = new ArraySegment(Encoding.ASCII.GetBytes("extended")); + e.ApplicationMessage.Payload = new ReadOnlySequence(Encoding.ASCII.GetBytes("extended")); return CompletedTask.Instance; }; @@ -366,7 +366,7 @@ public async Task No_Messages_If_No_Subscription() client.ConnectedAsync += async e => { - await client.PublishStringAsync("Connected"); + await client.PublishStringAsync("Connected", default); }; client.ApplicationMessageReceivedAsync += e => @@ -381,7 +381,7 @@ public async Task No_Messages_If_No_Subscription() await Task.Delay(500); - await client.PublishStringAsync("Hello"); + await client.PublishStringAsync("Hello", default); await Task.Delay(500); @@ -426,7 +426,7 @@ await server.InjectApplicationMessage( new MqttApplicationMessage { Topic = "/test/1", - PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes("true")), + Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes("true")), QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce }) { @@ -624,7 +624,7 @@ public async Task Same_Client_Id_Connect_Disconnect_Event_Order() await c2.SubscribeAsync("topic"); // r - await c2.PublishStringAsync("topic"); + await c2.PublishStringAsync("topic", default); await LongTestDelay(); flow = string.Join(string.Empty, events); @@ -720,7 +720,7 @@ public async Task Same_Client_Id_Refuse_Connection() await Task.Delay(500); - c1.PublishStringAsync("topic").Wait(); + c1.PublishStringAsync("topic", default).Wait(); await Task.Delay(500); @@ -742,7 +742,7 @@ public async Task Same_Client_Id_Refuse_Connection() flow = string.Join(string.Empty, events); Assert.AreEqual("cr", flow); - c1.PublishStringAsync("topic").Wait(); + c1.PublishStringAsync("topic", default).Wait(); await Task.Delay(500); @@ -796,6 +796,31 @@ public async Task Send_Long_Body() } } + [TestMethod] + public async Task Send_Json_Body() + { + using var testEnvironment = CreateTestEnvironment(); + string receivedBody = null; + + await testEnvironment.StartServer(); + + var client1 = await testEnvironment.ConnectClient(); + client1.ApplicationMessageReceivedAsync += e => + { + receivedBody = e.ApplicationMessage.ConvertPayloadToString(); + return CompletedTask.Instance; + }; + + await client1.SubscribeAsync("string"); + + var client2 = await testEnvironment.ConnectClient(); + await client2.PublishJsonAsync("string", true); + + await Task.Delay(TimeSpan.FromSeconds(5)); + + Assert.AreEqual("true", receivedBody); + } + [TestMethod] public async Task Set_Subscription_At_Server() { @@ -824,11 +849,11 @@ public async Task Set_Subscription_At_Server() await Task.Delay(500); - await client.PublishStringAsync("Hello"); + await client.PublishStringAsync("Hello", default); await Task.Delay(100); Assert.AreEqual(0, receivedMessages.Count); - await client.PublishStringAsync("topic1"); + await client.PublishStringAsync("topic1", default); await Task.Delay(100); Assert.AreEqual(1, receivedMessages.Count); } @@ -956,7 +981,7 @@ public async Task Disconnect_Client_with_Reason() { if (e.Buffer.Length > 0) { - if (e.Buffer[0] == (byte)MqttControlPacketType.Disconnect << 4) + if (e.Buffer.FirstSpan[0] == (byte)MqttControlPacketType.Disconnect << 4) { disconnectPacketReceived = true; } diff --git a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs index e829e2573..60264c894 100644 --- a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs @@ -27,7 +27,7 @@ public async Task Disconnect_While_Publishing() server.InterceptingPublishAsync += ev => server.DisconnectClientAsync(ev.ClientId, MqttDisconnectReasonCode.NormalDisconnection); var client = await testEnvironment.ConnectClient(); - await client.PublishStringAsync("test", qualityOfServiceLevel: MqttQualityOfServiceLevel.AtLeastOnce); + await client.PublishStringAsync("test",null, qualityOfServiceLevel: MqttQualityOfServiceLevel.AtLeastOnce); } } diff --git a/Source/MQTTnet.Tests/Server/QoS_Tests.cs b/Source/MQTTnet.Tests/Server/QoS_Tests.cs index b48a1a99c..86ab8d7f3 100644 --- a/Source/MQTTnet.Tests/Server/QoS_Tests.cs +++ b/Source/MQTTnet.Tests/Server/QoS_Tests.cs @@ -32,7 +32,7 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_0() await client1.SubscribeAsync("A"); var client2 = await testEnvironment.ConnectClient(); - await client2.PublishStringAsync("A"); + await client2.PublishStringAsync("A", null); await LongTestDelay(); @@ -59,7 +59,7 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_1() await client1.SubscribeAsync("A", MqttQualityOfServiceLevel.AtLeastOnce); var client2 = await testEnvironment.ConnectClient(); - await client2.PublishStringAsync("A", qualityOfServiceLevel: MqttQualityOfServiceLevel.AtLeastOnce); + await client2.PublishStringAsync("A",null, qualityOfServiceLevel: MqttQualityOfServiceLevel.AtLeastOnce); await LongTestDelay(); @@ -94,7 +94,7 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_2() await client1.SubscribeAsync("A", MqttQualityOfServiceLevel.ExactlyOnce); var client2 = await testEnvironment.ConnectClient(); - await client2.PublishStringAsync("A", qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); + await client2.PublishStringAsync("A",null, qualityOfServiceLevel: MqttQualityOfServiceLevel.ExactlyOnce); await LongTestDelay(); @@ -127,21 +127,17 @@ public async Task Preserve_Message_Order_For_Queued_Messages() await LongTestDelay(); // Now inject messages which are appended to the queue of the client. - await server.InjectApplicationMessage("T", "0", MqttQualityOfServiceLevel.AtLeastOnce); - - await server.InjectApplicationMessage("T", "2", MqttQualityOfServiceLevel.AtLeastOnce); - await server.InjectApplicationMessage("T", "1", MqttQualityOfServiceLevel.AtLeastOnce); - - await server.InjectApplicationMessage("T", "4", MqttQualityOfServiceLevel.AtLeastOnce); - await server.InjectApplicationMessage("T", "3", MqttQualityOfServiceLevel.AtLeastOnce); - - await server.InjectApplicationMessage("T", "6", MqttQualityOfServiceLevel.AtLeastOnce); - await server.InjectApplicationMessage("T", "5", MqttQualityOfServiceLevel.AtLeastOnce); - - await server.InjectApplicationMessage("T", "8", MqttQualityOfServiceLevel.AtLeastOnce); - await server.InjectApplicationMessage("T", "7", MqttQualityOfServiceLevel.AtLeastOnce); - - await server.InjectApplicationMessage("T", "9", MqttQualityOfServiceLevel.AtLeastOnce); + var clientId = string.Empty; + await server.InjectStringAsync(clientId, "T", "0", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "2", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "1", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "4", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "3", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "6", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "5", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "8", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "7", MqttQualityOfServiceLevel.AtLeastOnce); + await server.InjectStringAsync(clientId, "T", "9", MqttQualityOfServiceLevel.AtLeastOnce); await LongTestDelay(); diff --git a/Source/MQTTnet.Tests/Server/Security_Tests.cs b/Source/MQTTnet.Tests/Server/Security_Tests.cs index 404314c43..7683be864 100644 --- a/Source/MQTTnet.Tests/Server/Security_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Security_Tests.cs @@ -58,7 +58,7 @@ await validClient.ConnectAsync( .WithClientId("CLIENT") .Build()); - await validClient.PublishStringAsync("HELLO 1"); + await validClient.PublishStringAsync("HELLO 1", null); // The following code tries to connect a new client with the same client ID but invalid // credentials. This should block the second client but keep the first one up and running. @@ -81,11 +81,11 @@ await invalidClient.ConnectAsync( await LongTestDelay(); - await validClient.PublishStringAsync("HELLO 2"); + await validClient.PublishStringAsync("HELLO 2", null); await LongTestDelay(); - await validClient.PublishStringAsync("HELLO 3"); + await validClient.PublishStringAsync("HELLO 3", null); await LongTestDelay(); diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index 91345da56..df5ef18eb 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Linq; using System.Text; using System.Threading; @@ -282,7 +283,7 @@ public async Task Set_Session_Item() server.InterceptingPublishAsync += e => { - e.ApplicationMessage.PayloadSegment = new ArraySegment(Encoding.UTF8.GetBytes(e.SessionItems["default_payload"] as string ?? string.Empty)); + e.ApplicationMessage.Payload = new ReadOnlySequence(Encoding.UTF8.GetBytes(e.SessionItems["default_payload"] as string ?? string.Empty)); return CompletedTask.Instance; }; @@ -300,7 +301,7 @@ public async Task Set_Session_Item() Assert.AreEqual(MqttClientSubscribeResultCode.GrantedQoS0, subscribeResult.Items.First().ResultCode); var client2 = await testEnvironment.ConnectClient(); - await client2.PublishStringAsync("x"); + await client2.PublishStringAsync("x",null); await Task.Delay(1000); diff --git a/Source/MQTTnet.Tests/Server/Status_Tests.cs b/Source/MQTTnet.Tests/Server/Status_Tests.cs index ef9419eb3..251b2e199 100644 --- a/Source/MQTTnet.Tests/Server/Status_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Status_Tests.cs @@ -158,7 +158,7 @@ public async Task Track_Sent_Application_Messages() for (var i = 1; i < 25; i++) { - await c1.PublishStringAsync("a"); + await c1.PublishStringAsync("a",null); await Task.Delay(50); var clientStatus = await server.GetClientsAsync(); diff --git a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs index 9dd48c317..9f6376ea3 100644 --- a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs @@ -187,7 +187,7 @@ public async Task Intercept_Subscription() await client.SubscribeAsync("b"); - await client.PublishStringAsync("a"); + await client.PublishStringAsync("a", null); await Task.Delay(500); @@ -323,15 +323,15 @@ public async Task Subscribe_Multiple_In_Multiple_Request() var c2 = await testEnvironment.ConnectClient(); - await c2.PublishStringAsync("a"); + await c2.PublishStringAsync("a", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 1); - await c2.PublishStringAsync("b"); + await c2.PublishStringAsync("b", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 2); - await c2.PublishStringAsync("c"); + await c2.PublishStringAsync("c", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 3); } @@ -357,15 +357,15 @@ public async Task Subscribe_Multiple_In_Single_Request() var c2 = await testEnvironment.ConnectClient(); - await c2.PublishStringAsync("a"); + await c2.PublishStringAsync("a", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 1); - await c2.PublishStringAsync("b"); + await c2.PublishStringAsync("b", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 2); - await c2.PublishStringAsync("c"); + await c2.PublishStringAsync("c", null); await Task.Delay(100); Assert.AreEqual(receivedMessagesCount, 3); } diff --git a/Source/MQTTnet.Tests/Server/Tls_Tests.cs b/Source/MQTTnet.Tests/Server/Tls_Tests.cs index a87b0dc4c..e8ef20582 100644 --- a/Source/MQTTnet.Tests/Server/Tls_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Tls_Tests.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Linq; using System.Net; using System.Security.Authentication; @@ -99,7 +100,7 @@ await firstClient.PublishAsync( new MqttApplicationMessage { Topic = "TestTopic1", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence(new byte[] { 1, 2, 3, 4 }) }); await testEnvironment.Server.InjectApplicationMessage( @@ -107,7 +108,7 @@ await testEnvironment.Server.InjectApplicationMessage( new MqttApplicationMessage { Topic = "TestTopic1", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence(new byte[] { 1, 2, 3, 4 }) })); certificateProvider.CurrentCertificate = CreateCertificate(secondOid); @@ -135,7 +136,7 @@ await firstClient.PublishAsync( new MqttApplicationMessage { Topic = "TestTopic2", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence([1, 2, 3, 4]) }); await testEnvironment.Server.InjectApplicationMessage( @@ -143,7 +144,7 @@ await testEnvironment.Server.InjectApplicationMessage( new MqttApplicationMessage { Topic = "TestTopic2", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence(new byte[] { 1, 2, 3, 4 }) })); // Ensure first client still works @@ -151,7 +152,7 @@ await firstClient.PublishAsync( new MqttApplicationMessage { Topic = "TestTopic1", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence(new byte[] { 1, 2, 3, 4 }) }); await testEnvironment.Server.InjectApplicationMessage( @@ -159,7 +160,7 @@ await testEnvironment.Server.InjectApplicationMessage( new MqttApplicationMessage { Topic = "TestTopic1", - PayloadSegment = new ArraySegment(new byte[] { 1, 2, 3, 4 }) + Payload = new ReadOnlySequence(new byte[] { 1, 2, 3, 4 }) })); await Task.Delay(1000); diff --git a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs index 28d72bcf2..1cc77486b 100644 --- a/Source/MQTTnet/Adapter/MqttChannelAdapter.cs +++ b/Source/MQTTnet/Adapter/MqttChannelAdapter.cs @@ -2,6 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Channel; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Exceptions; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.Packets; using System; using System.Buffers; using System.IO; @@ -11,12 +17,6 @@ using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Channel; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Exceptions; -using MQTTnet.Formatter; -using MQTTnet.Internal; -using MQTTnet.Packets; namespace MQTTnet.Adapter; @@ -30,6 +30,7 @@ public sealed class MqttChannelAdapter : Disposable, IMqttChannelAdapter readonly MqttNetSourceLogger _logger; readonly byte[] _singleByteBuffer = new byte[1]; readonly AsyncLock _syncRoot = new(); + private BufferOwner _bodyOwner = null; Statistics _statistics; // mutable struct, don't make readonly! @@ -259,6 +260,7 @@ protected override void Dispose(bool disposing) { _channel.Dispose(); _syncRoot.Dispose(); + _bodyOwner?.Dispose(); } base.Dispose(disposing); @@ -398,24 +400,29 @@ async Task ReceiveAsync(CancellationToken cancellationToken) var fixedHeader = readFixedHeaderResult.FixedHeader; if (fixedHeader.RemainingLength == 0) { - return new ReceivedMqttPacket(fixedHeader.Flags, EmptyBuffer.ArraySegment, 2); + return new ReceivedMqttPacket(fixedHeader.Flags, ReadOnlySequence.Empty, 2); } var bodyLength = fixedHeader.RemainingLength; - var body = new byte[bodyLength]; + // Return and clear the previous body buffer + _bodyOwner?.Dispose(); + + // Re-rent a body buffer + _bodyOwner = BufferOwner.Rent(bodyLength); + var bodyBuffer = _bodyOwner.Buffer; var bodyOffset = 0; var chunkSize = Math.Min(ReadBufferSize, bodyLength); do { - var bytesLeft = body.Length - bodyOffset; + var bytesLeft = bodyLength - bodyOffset; if (chunkSize > bytesLeft) { chunkSize = bytesLeft; } - var readBytes = await _channel.ReadAsync(body, bodyOffset, chunkSize, cancellationToken).ConfigureAwait(false); + var readBytes = await _channel.ReadAsync(bodyBuffer, bodyOffset, chunkSize, cancellationToken).ConfigureAwait(false); if (cancellationToken.IsCancellationRequested) { @@ -430,10 +437,10 @@ async Task ReceiveAsync(CancellationToken cancellationToken) bodyOffset += readBytes; } while (bodyOffset < bodyLength); - PacketInspector?.FillReceiveBuffer(body); + var body = bodyBuffer.AsMemory(0, bodyLength); + PacketInspector?.FillReceiveBuffer(body.Span); - var bodySegment = new ArraySegment(body, 0, bodyLength); - return new ReceivedMqttPacket(fixedHeader.Flags, bodySegment, fixedHeader.TotalLength); + return new ReceivedMqttPacket(fixedHeader.Flags, new ReadOnlySequence(body), fixedHeader.TotalLength); } static bool WrapAndThrowException(Exception exception) @@ -484,4 +491,36 @@ public void Reset() Volatile.Write(ref _bytesSent, 0); } } + + sealed class BufferOwner : IDisposable + { + private bool _disposed = false; + + public byte[] Buffer { get; private set; } + + /// + /// rent a buffer from ArrayPool + /// + /// + /// + public static BufferOwner Rent(int minBufferSize) + { + return new BufferOwner() + { + Buffer = ArrayPool.Shared.Rent(minBufferSize) + }; + } + + /// + /// return the buffer to ArrayPool + /// + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + ArrayPool.Shared.Return(Buffer); + } + } + } } \ No newline at end of file diff --git a/Source/MQTTnet/Adapter/MqttPacketInspector.cs b/Source/MQTTnet/Adapter/MqttPacketInspector.cs index f4ea7753b..d5bac11cf 100644 --- a/Source/MQTTnet/Adapter/MqttPacketInspector.cs +++ b/Source/MQTTnet/Adapter/MqttPacketInspector.cs @@ -2,13 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.IO; -using System.Threading.Tasks; using MQTTnet.Diagnostics.Logger; using MQTTnet.Diagnostics.PacketInspection; using MQTTnet.Formatter; using MQTTnet.Internal; +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.Adapter; @@ -17,7 +19,9 @@ public sealed class MqttPacketInspector readonly AsyncEvent _asyncEvent; readonly MqttNetSourceLogger _logger; - MemoryStream _receivedPacketBuffer; + readonly Pipe _pipeIn = new(); + readonly Pipe _pipeOut = new(); + ReceiveState _receiveState = ReceiveState.Disable; public MqttPacketInspector(AsyncEvent asyncEvent, IMqttNetLogger logger) { @@ -28,63 +32,118 @@ public MqttPacketInspector(AsyncEvent asyncEvent, IM _logger = logger.WithSource(nameof(MqttPacketInspector)); } - public void BeginReceivePacket() + public async Task BeginSendPacket(MqttPacketBuffer buffer) { if (!_asyncEvent.HasHandlers) { return; } - if (_receivedPacketBuffer == null) + // Create a copy of the actual packet so that the inspector gets no access + // to the internal buffers. This is waste of memory but this feature is only + // intended for debugging etc. so that this is OK. + var writer = _pipeOut.Writer; + await writer.WriteAsync(buffer.Packet).ConfigureAwait(false); + foreach (var memory in buffer.Payload) { - _receivedPacketBuffer = new MemoryStream(); + await writer.WriteAsync(memory).ConfigureAwait(false); } - _receivedPacketBuffer?.SetLength(0); + await writer.CompleteAsync().ConfigureAwait(false); + await InspectPacketAsync(_pipeOut.Reader, MqttPacketFlowDirection.Outbound).ConfigureAwait(false); + + // reset pipe + await _pipeOut.Reader.CompleteAsync().ConfigureAwait(false); + _pipeOut.Reset(); } - public Task BeginSendPacket(MqttPacketBuffer buffer) + public void BeginReceivePacket() { - if (!_asyncEvent.HasHandlers) + if (_asyncEvent.HasHandlers) { - return CompletedTask.Instance; + // This shouldn't happen, but we need to be able to accommodate the unexpected. + if (_receiveState == ReceiveState.Fill) + { + _pipeIn.Writer.Complete(); + _pipeIn.Reader.Complete(); + _pipeIn.Reset(); + + _logger.Warning("An EndReceivePacket() operation was unexpectedly lost."); + } + + _receiveState = ReceiveState.Begin; } + else + { + _receiveState = ReceiveState.Disable; + } + } - // Create a copy of the actual packet so that the inspector gets no access - // to the internal buffers. This is waste of memory but this feature is only - // intended for debugging etc. so that this is OK. - var bufferCopy = buffer.ToArray(); + public void FillReceiveBuffer(ReadOnlySpan buffer) + { + if (_receiveState == ReceiveState.Disable) + { + return; + } - return InspectPacket(bufferCopy, MqttPacketFlowDirection.Outbound); + if (_receiveState == ReceiveState.End) + { + throw new InvalidOperationException("FillReceiveBuffer is not allowed in End state."); + } + + _pipeIn.Writer.Write(buffer); + _receiveState = ReceiveState.Fill; } - public Task EndReceivePacket() + public void FillReceiveBuffer(ReadOnlySequence buffer) { - if (!_asyncEvent.HasHandlers) + if (_receiveState == ReceiveState.Disable) + { + return; + } + + if (_receiveState == ReceiveState.End) { - return CompletedTask.Instance; + throw new InvalidOperationException("FillReceiveBuffer is not allowed in End state."); } - var buffer = _receivedPacketBuffer.ToArray(); - _receivedPacketBuffer.SetLength(0); + var writer = _pipeIn.Writer; + foreach (var memory in buffer) + { + writer.Write(memory.Span); + } - return InspectPacket(buffer, MqttPacketFlowDirection.Inbound); + _receiveState = ReceiveState.Fill; } - public void FillReceiveBuffer(byte[] buffer) + + public async Task EndReceivePacket() { - if (!_asyncEvent.HasHandlers) + if (_receiveState == ReceiveState.Disable) { return; } - _receivedPacketBuffer?.Write(buffer, 0, buffer.Length); + if (_receiveState == ReceiveState.Fill) + { + await _pipeIn.Writer.FlushAsync().ConfigureAwait(false); + await _pipeIn.Writer.CompleteAsync().ConfigureAwait(false); + await InspectPacketAsync(_pipeIn.Reader, MqttPacketFlowDirection.Inbound).ConfigureAwait(false); + + // reset pipe + await _pipeIn.Reader.CompleteAsync().ConfigureAwait(false); + _pipeIn.Reset(); + } + + _receiveState = ReceiveState.End; } - async Task InspectPacket(byte[] buffer, MqttPacketFlowDirection direction) + + async Task InspectPacketAsync(PipeReader pipeReader, MqttPacketFlowDirection direction) { try { + var buffer = await ReadBufferAsync(pipeReader, default).ConfigureAwait(false); var eventArgs = new InspectMqttPacketEventArgs(direction, buffer); await _asyncEvent.InvokeAsync(eventArgs).ConfigureAwait(false); } @@ -93,4 +152,24 @@ async Task InspectPacket(byte[] buffer, MqttPacketFlowDirection direction) _logger.Error(exception, "Error while inspecting packet."); } } + + static async ValueTask> ReadBufferAsync(PipeReader pipeReader, CancellationToken cancellationToken) + { + var readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); + while (!readResult.IsCompleted) + { + pipeReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.Start); + readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + + return readResult.Buffer; + } + + private enum ReceiveState + { + Disable, + Begin, + Fill, + End, + } } \ No newline at end of file diff --git a/Source/MQTTnet/Adapter/ReceivedMqttPacket.cs b/Source/MQTTnet/Adapter/ReceivedMqttPacket.cs index f290e5656..676df3929 100644 --- a/Source/MQTTnet/Adapter/ReceivedMqttPacket.cs +++ b/Source/MQTTnet/Adapter/ReceivedMqttPacket.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; namespace MQTTnet.Adapter; @@ -10,7 +11,7 @@ public readonly struct ReceivedMqttPacket { public static readonly ReceivedMqttPacket Empty = new(); - public ReceivedMqttPacket(byte fixedHeader, ArraySegment body, int totalLength) + public ReceivedMqttPacket(byte fixedHeader, ReadOnlySequence body, int totalLength) { FixedHeader = fixedHeader; Body = body; @@ -19,7 +20,7 @@ public ReceivedMqttPacket(byte fixedHeader, ArraySegment body, int totalLe public byte FixedHeader { get; } - public ArraySegment Body { get; } + public ReadOnlySequence Body { get; } public int TotalLength { get; } } \ No newline at end of file diff --git a/Source/MQTTnet/Connecting/MqttClientConnectResult.cs b/Source/MQTTnet/Connecting/MqttClientConnectResult.cs index 5f9d7f15a..cc67190c0 100644 --- a/Source/MQTTnet/Connecting/MqttClientConnectResult.cs +++ b/Source/MQTTnet/Connecting/MqttClientConnectResult.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -20,7 +21,7 @@ public sealed class MqttClientConnectResult /// Gets the authentication data. /// MQTTv5 only. /// - public byte[] AuthenticationData { get; internal set; } + public ReadOnlyMemory AuthenticationData { get; internal set; } /// /// Gets the authentication method. diff --git a/Source/MQTTnet/Diagnostics/PacketInspection/InspectMqttPacketEventArgs.cs b/Source/MQTTnet/Diagnostics/PacketInspection/InspectMqttPacketEventArgs.cs index ec7ec733e..16a259364 100644 --- a/Source/MQTTnet/Diagnostics/PacketInspection/InspectMqttPacketEventArgs.cs +++ b/Source/MQTTnet/Diagnostics/PacketInspection/InspectMqttPacketEventArgs.cs @@ -3,18 +3,19 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; namespace MQTTnet.Diagnostics.PacketInspection { public sealed class InspectMqttPacketEventArgs : EventArgs { - public InspectMqttPacketEventArgs(MqttPacketFlowDirection direction, byte[] buffer) + public InspectMqttPacketEventArgs(MqttPacketFlowDirection direction, ReadOnlySequence buffer) { Direction = direction; - Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); + Buffer = buffer; } - public byte[] Buffer { get; } + public ReadOnlySequence Buffer { get; } public MqttPacketFlowDirection Direction { get; } } diff --git a/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeContext.cs b/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeContext.cs index 1f93a8075..d5ae81728 100644 --- a/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeContext.cs +++ b/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeContext.cs @@ -28,7 +28,7 @@ public MqttExtendedAuthenticationExchangeContext(MqttAuthPacket authPacket, Mqtt /// Gets the authentication data. /// Hint: MQTT 5 feature only. /// - public byte[] AuthenticationData { get; } + public ReadOnlyMemory AuthenticationData { get; } /// /// Gets the authentication method. diff --git a/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeData.cs b/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeData.cs index 2c58774c2..f20327172 100644 --- a/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeData.cs +++ b/Source/MQTTnet/ExtendedAuthenticationExchange/MqttExtendedAuthenticationExchangeData.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -18,7 +19,7 @@ public class MqttExtendedAuthenticationExchangeData /// method. /// Hint: MQTT 5 feature only. /// - public byte[] AuthenticationData { get; set; } + public ReadOnlyMemory AuthenticationData { get; set; } /// /// Gets or sets the reason code. diff --git a/Source/MQTTnet/Formatter/MqttBufferReader.cs b/Source/MQTTnet/Formatter/MqttBufferReader.cs index c8102be7e..cf889a1b8 100644 --- a/Source/MQTTnet/Formatter/MqttBufferReader.cs +++ b/Source/MQTTnet/Formatter/MqttBufferReader.cs @@ -2,106 +2,104 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; using System; +using System.Buffers; using System.Runtime.CompilerServices; using System.Text; -using MQTTnet.Exceptions; -using MQTTnet.Internal; -using System.Buffers.Binary; namespace MQTTnet.Formatter { public sealed class MqttBufferReader { - byte[] _buffer = EmptyBuffer.Array; - int _maxPosition; - int _offset; - int _position; + long _position = default; + ReadOnlySequence _buffer = default; + - public int BytesLeft => _maxPosition - _position; + public long BytesLeft => _buffer.Length - _position; - public bool EndOfStream => BytesLeft == 0; + public bool EndOfStream => _buffer.Length <= _position; - public int Position => _position - _offset; + public long Position => _position; - public byte[] ReadBinaryData() + public ReadOnlySequence ReadBinaryData() { var length = ReadTwoByteInteger(); - if (length == 0) { - return EmptyBuffer.Array; + return ReadOnlySequence.Empty; } ValidateReceiveBuffer(length); + var buffer = _buffer.Slice(_position, length); - var result = GC.AllocateUninitializedArray(length); - MqttMemoryHelper.Copy(_buffer, _position, result, 0, length); _position += length; - - return result; + return buffer; } public byte ReadByte() { ValidateReceiveBuffer(1); - return _buffer[_position++]; + + var reader = new SequenceReader(_buffer); + reader.Advance(_position); + reader.TryRead(out byte value); + + _position += 1; + return value; } + public ushort ReadTwoByteInteger() + { + ValidateReceiveBuffer(2); + + var reader = new SequenceReader(_buffer); + reader.Advance(_position); + reader.TryReadBigEndian(out short value); + + _position += 2; + return Unsafe.As(ref value); + } + + public uint ReadFourByteInteger() { ValidateReceiveBuffer(4); - var value = BinaryPrimitives.ReadUInt32BigEndian(_buffer.AsSpan(_position)); + var reader = new SequenceReader(_buffer); + reader.Advance(_position); + reader.TryReadBigEndian(out int value); _position += 4; - return value; + return Unsafe.As(ref value); } - public byte[] ReadRemainingData() - { - var bufferLength = BytesLeft; - if (bufferLength == 0) - { - return EmptyBuffer.Array; - } - - var buffer = GC.AllocateUninitializedArray(bufferLength); - MqttMemoryHelper.Copy(_buffer, _position, buffer, 0, bufferLength); - _position += bufferLength; + public ReadOnlySequence ReadRemainingData() + { + var buffer = _buffer.Slice(_position); + _position = _buffer.Length; return buffer; } + public string ReadString() { var length = ReadTwoByteInteger(); - if (length == 0) { return string.Empty; } ValidateReceiveBuffer(length); - - // AsSpan() version is slightly faster. Not much but at least a little bit. - var result = Encoding.UTF8.GetString(_buffer.AsSpan(_position, length)); + var buffer = _buffer.Slice(_position, length); + var result = Encoding.UTF8.GetString(buffer); _position += length; return result; } - public ushort ReadTwoByteInteger() - { - ValidateReceiveBuffer(2); - - var value = BinaryPrimitives.ReadUInt16BigEndian(_buffer.AsSpan(_position)); - - _position += 2; - return value; - } - public uint ReadVariableByteInteger() { var multiplier = 1; @@ -124,31 +122,26 @@ public uint ReadVariableByteInteger() return value; } - public void Seek(int position) + public void SetBuffer(ReadOnlyMemory buffer) { - _position = _offset + position; + SetBuffer(new ReadOnlySequence(buffer)); } - public void SetBuffer(ArraySegment buffer) + public void SetBuffer(ReadOnlySequence buffer) { - SetBuffer(buffer.Array, buffer.Offset, buffer.Count); + _buffer = buffer; + _position = 0; } - public void SetBuffer(byte[] buffer, int offset, int length) - { - _buffer = buffer ?? throw new ArgumentNullException(nameof(buffer)); - _offset = offset; - _position = offset; - _maxPosition = offset + length; - } [MethodImpl(MethodImplOptions.AggressiveInlining)] void ValidateReceiveBuffer(int length) { + var bufferLength = _buffer.Length; var newPosition = _position + length; - if (_maxPosition < newPosition) + if (bufferLength < newPosition) { - throw new MqttProtocolViolationException($"Expected at least {newPosition} bytes but there are only {_maxPosition} bytes"); + throw new MqttProtocolViolationException($"Expected at least {newPosition} bytes but there are only {bufferLength} bytes"); } } } diff --git a/Source/MQTTnet/Formatter/MqttBufferWriter.cs b/Source/MQTTnet/Formatter/MqttBufferWriter.cs index c89ccad76..95b54095b 100644 --- a/Source/MQTTnet/Formatter/MqttBufferWriter.cs +++ b/Source/MQTTnet/Formatter/MqttBufferWriter.cs @@ -2,12 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Protocol; using System; +using System.Buffers; +using System.Buffers.Binary; using System.Runtime.CompilerServices; using System.Text; -using MQTTnet.Exceptions; -using MQTTnet.Internal; -using MQTTnet.Protocol; namespace MQTTnet.Formatter { @@ -26,6 +27,11 @@ public sealed class MqttBufferWriter byte[] _buffer; int _position; + IBufferWriter _lowLevelBufferWriter; + + public int Length { get; private set; } + + public int BufferSize => _buffer.Length; public MqttBufferWriter(int bufferSize, int maxBufferSize) { @@ -33,7 +39,6 @@ public MqttBufferWriter(int bufferSize, int maxBufferSize) _maxBufferSize = maxBufferSize; } - public int Length { get; private set; } public static byte BuildFixedHeader(MqttControlPacketType packetType, byte flags = 0) { @@ -60,9 +65,15 @@ public void Cleanup() _buffer = new byte[_maxBufferSize]; } - public byte[] GetBuffer() + + public ReadOnlySpan GetWrittenSpan() + { + return _buffer.AsSpan(0, Length); + } + + public ReadOnlyMemory GetWrittenMemory() { - return _buffer; + return _buffer.AsMemory(0, Length); } public static int GetVariableByteIntegerSize(uint value) @@ -107,67 +118,53 @@ public void Write(MqttBufferWriter propertyWriter) { ArgumentNullException.ThrowIfNull(propertyWriter); - WriteBinary(propertyWriter._buffer, 0, propertyWriter.Length); + Write(propertyWriter.GetWrittenSpan()); } - public void WriteBinary(byte[] value) + public void Write(ReadOnlySpan buffer) { - if (value == null || value.Length == 0) + if (buffer.IsEmpty) { - EnsureAdditionalCapacity(2); - - _buffer[_position] = 0; - _buffer[_position + 1] = 0; - - IncreasePosition(2); + return; } - else - { - var valueLength = value.Length; - - EnsureAdditionalCapacity(valueLength + 2); - _buffer[_position] = (byte)(valueLength >> 8); - _buffer[_position + 1] = (byte)valueLength; + var size = buffer.Length; + var span = GetSpan(size); - MqttMemoryHelper.Copy(value, 0, _buffer, _position + 2, valueLength); - IncreasePosition(valueLength + 2); - } + buffer.CopyTo(span); + Advance(size); } - public void WriteBinary(byte[] buffer, int offset, int count) + public void WriteBinary(ReadOnlySpan value) { - ArgumentNullException.ThrowIfNull(buffer); - - if (count == 0) - { - return; - } + var size = value.Length + 2; + var span = GetSpan(size); - EnsureAdditionalCapacity(count); + BinaryPrimitives.WriteUInt16BigEndian(span, (ushort)value.Length); + value.CopyTo(span[2..]); - MqttMemoryHelper.Copy(buffer, offset, _buffer, _position, count); - IncreasePosition(count); + Advance(size); } + public void WriteByte(byte @byte) { - EnsureAdditionalCapacity(1); + const int size = sizeof(byte); + var span = GetSpan(size); - _buffer[_position] = @byte; - IncreasePosition(1); + span[0] = @byte; + Advance(size); } public void WriteString(string value) { if (string.IsNullOrEmpty(value)) { - EnsureAdditionalCapacity(2); - - _buffer[_position] = 0; - _buffer[_position + 1] = 0; + const int size = 2; + var span = GetSpan(size); - IncreasePosition(2); + span.Fill(default); + Advance(size); } else { @@ -176,10 +173,9 @@ public void WriteString(string value) // So the buffer should always have much more capacity left so that a correct value // here is only waste of CPU cycles. var byteCount = value.Length * 4; + var span = GetSpan(byteCount + 2); - EnsureAdditionalCapacity(byteCount + 2); - - var writtenBytes = Encoding.UTF8.GetBytes(value, 0, value.Length, _buffer, _position + 2); + var writtenBytes = Encoding.UTF8.GetBytes(value, span[2..]); // From RFC: 1.5.4 UTF-8 Encoded String // Unless stated otherwise all UTF-8 encoded strings can have any length in the range 0 to 65,535 bytes. @@ -188,37 +184,40 @@ public void WriteString(string value) throw new MqttProtocolViolationException($"The maximum string length is 65535. The current string has a length of {writtenBytes}."); } - _buffer[_position] = (byte)(writtenBytes >> 8); - _buffer[_position + 1] = (byte)writtenBytes; + BinaryPrimitives.WriteUInt16BigEndian(span, (ushort)writtenBytes); - IncreasePosition(writtenBytes + 2); + Advance(writtenBytes + 2); } } public void WriteTwoByteInteger(ushort value) { - EnsureAdditionalCapacity(2); + const int size = sizeof(ushort); + var span = GetSpan(size); - _buffer[_position] = (byte)(value >> 8); - IncreasePosition(1); - _buffer[_position] = (byte)value; - IncreasePosition(1); + BinaryPrimitives.WriteUInt16BigEndian(span, value); + + Advance(size); } - public void WriteVariableByteInteger(uint value) + public void WriteFourByteInteger(uint value) { - if (value == 0) - { - _buffer[_position] = 0; - IncreasePosition(1); + const int size = sizeof(uint); + var span = GetSpan(size); - return; - } + BinaryPrimitives.WriteUInt32BigEndian(span, value); + + Advance(size); + } + + public void WriteVariableByteInteger(uint value) + { + EnsureCapacity(sizeof(uint)); if (value <= 127) { _buffer[_position] = (byte)value; - IncreasePosition(1); + Advance(1); return; } @@ -243,21 +242,19 @@ public void WriteVariableByteInteger(uint value) size++; } while (x > 0); - IncreasePosition(size); + Advance(size); } [MethodImpl(MethodImplOptions.AggressiveInlining)] - void EnsureAdditionalCapacity(int additionalCapacity) + Span GetSpan(int size) { - var bufferLength = _buffer.Length; - - var freeSpace = bufferLength - _position; - if (freeSpace >= additionalCapacity) + var freeSpace = _buffer.Length - _position; + if (freeSpace < size) { - return; + EnsureCapacity(_buffer.Length + size - freeSpace); } - EnsureCapacity(bufferLength + additionalCapacity - freeSpace); + return _buffer.AsSpan(_position, size); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -280,9 +277,9 @@ void EnsureCapacity(int capacity) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - void IncreasePosition(int length) + void Advance(int count) { - _position += length; + _position += count; if (_position > Length) { @@ -291,5 +288,24 @@ void IncreasePosition(int length) Length = _position; } } + + /// + /// Returns a lower-level wrapper. + /// + /// + public IBufferWriter AsLowLevelBufferWriter() + { + _lowLevelBufferWriter ??= new LowLevelBufferWriter(this); + return _lowLevelBufferWriter; + } + + + private sealed class LowLevelBufferWriter(MqttBufferWriter bufferWriter) : IBufferWriter + { + private readonly MqttBufferWriter _bufferWriter = bufferWriter; + public void Advance(int count) => _bufferWriter.Advance(count); + public Span GetSpan(int sizeHint = 0) => _bufferWriter.GetSpan(sizeHint); + public Memory GetMemory(int sizeHint = 0) => throw new NotSupportedException(); + } } } \ No newline at end of file diff --git a/Source/MQTTnet/Formatter/MqttConnectPacketFactory.cs b/Source/MQTTnet/Formatter/MqttConnectPacketFactory.cs index 89dda0598..a0137e2a9 100644 --- a/Source/MQTTnet/Formatter/MqttConnectPacketFactory.cs +++ b/Source/MQTTnet/Formatter/MqttConnectPacketFactory.cs @@ -17,7 +17,7 @@ public static MqttConnectPacket Create(MqttClientOptions clientOptions) { ClientId = clientOptions.ClientId, Username = clientOptions.Credentials?.GetUserName(clientOptions), - Password = clientOptions.Credentials?.GetPassword(clientOptions), + Password = clientOptions.Credentials == null ? default : clientOptions.Credentials.GetPassword(clientOptions), CleanSession = clientOptions.CleanSession, KeepAlivePeriod = (ushort)clientOptions.KeepAlivePeriod.TotalSeconds, AuthenticationMethod = clientOptions.AuthenticationMethod, diff --git a/Source/MQTTnet/Formatter/MqttPacketBuffer.cs b/Source/MQTTnet/Formatter/MqttPacketBuffer.cs index b87939e73..e93012bd6 100644 --- a/Source/MQTTnet/Formatter/MqttPacketBuffer.cs +++ b/Source/MQTTnet/Formatter/MqttPacketBuffer.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using MQTTnet.Internal; using System; using System.Buffers; @@ -10,28 +9,24 @@ namespace MQTTnet.Formatter { public readonly struct MqttPacketBuffer { - public MqttPacketBuffer(ArraySegment packet, ReadOnlySequence payload) - { - Packet = packet; - Payload = payload; + public int Length { get; } + + public ReadOnlyMemory Packet { get; } - Length = Packet.Count + (int)Payload.Length; + public ReadOnlySequence Payload { get; } + + public MqttPacketBuffer(ReadOnlyMemory packet) + : this(packet, ReadOnlySequence.Empty) + { } - public MqttPacketBuffer(ArraySegment packet) + public MqttPacketBuffer(ReadOnlyMemory packet, ReadOnlySequence payload) { Packet = packet; - Payload = EmptyBuffer.ReadOnlySequence; - - Length = Packet.Count; + Payload = payload; + Length = Packet.Length + (int)Payload.Length; } - public int Length { get; } - - public ArraySegment Packet { get; } - - public ReadOnlySequence Payload { get; } - public byte[] ToArray() { if (Payload.Length == 0) @@ -40,19 +35,24 @@ public byte[] ToArray() } var buffer = GC.AllocateUninitializedArray(Length); - MqttMemoryHelper.Copy(Packet.Array, Packet.Offset, buffer, 0, Packet.Count); - MqttMemoryHelper.Copy(Payload, 0, buffer, Packet.Count, (int)Payload.Length); + Packet.Span.CopyTo(buffer); + Payload.CopyTo(buffer.AsSpan(Packet.Length)); return buffer; } - public ArraySegment Join() + public ReadOnlyMemory Join() { if (Payload.Length == 0) { return Packet; } - return new ArraySegment(this.ToArray()); + + var buffer = GC.AllocateUninitializedArray(Length); + Packet.Span.CopyTo(buffer); + Payload.CopyTo(buffer.AsSpan(Packet.Length)); + + return buffer; } } } \ No newline at end of file diff --git a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs index 7022740a4..e8cb3f1af 100644 --- a/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs +++ b/Source/MQTTnet/Formatter/MqttPacketFormatterAdapter.cs @@ -82,7 +82,7 @@ public static IMqttPacketFormatter GetMqttPacketFormatter(MqttProtocolVersion pr MqttProtocolVersion ParseProtocolVersion(ReceivedMqttPacket receivedMqttPacket) { - if (receivedMqttPacket.Body.Count < 7) + if (receivedMqttPacket.Body.Length < 7) { // 2 byte protocol name length // at least 4 byte protocol name @@ -90,7 +90,7 @@ MqttProtocolVersion ParseProtocolVersion(ReceivedMqttPacket receivedMqttPacket) throw new MqttProtocolViolationException("CONNECT packet must have at least 7 bytes."); } - _bufferReader.SetBuffer(receivedMqttPacket.Body.Array, receivedMqttPacket.Body.Offset, receivedMqttPacket.Body.Count); + _bufferReader.SetBuffer(receivedMqttPacket.Body); var protocolName = _bufferReader.ReadString(); var protocolLevel = _bufferReader.ReadByte(); diff --git a/Source/MQTTnet/Formatter/V3/MqttV3PacketFormatter.cs b/Source/MQTTnet/Formatter/V3/MqttV3PacketFormatter.cs index 20dd59f49..bd2d17303 100644 --- a/Source/MQTTnet/Formatter/V3/MqttV3PacketFormatter.cs +++ b/Source/MQTTnet/Formatter/V3/MqttV3PacketFormatter.cs @@ -2,15 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Adapter; +using MQTTnet.Exceptions; +using MQTTnet.Internal; +using MQTTnet.Packets; +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; -using MQTTnet.Adapter; -using MQTTnet.Exceptions; -using MQTTnet.Packets; -using MQTTnet.Protocol; namespace MQTTnet.Formatter.V3 { @@ -115,18 +116,18 @@ public MqttPacketBuffer Encode(MqttPacket packet) _bufferWriter.WriteByte(fixedHeader); _bufferWriter.WriteVariableByteInteger(remainingLength); - var firstSegment = new ArraySegment(_bufferWriter.GetBuffer(), headerOffset, _bufferWriter.Length - headerOffset); + var firstSegment = _bufferWriter.GetWrittenMemory()[headerOffset..]; return payload.Length == 0 ? new MqttPacketBuffer(firstSegment) : new MqttPacketBuffer(firstSegment, payload); } - MqttPacket DecodeConnAckPacket(ArraySegment body) + MqttPacket DecodeConnAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttConnAckPacket(); @@ -136,11 +137,11 @@ MqttPacket DecodeConnAckPacket(ArraySegment body) return packet; } - MqttPacket DecodeConnAckPacketV311(ArraySegment body) + MqttPacket DecodeConnAckPacketV311(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttConnAckPacket(); @@ -152,11 +153,11 @@ MqttPacket DecodeConnAckPacketV311(ArraySegment body) return packet; } - MqttPacket DecodeConnectPacket(ArraySegment body) + MqttPacket DecodeConnectPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var protocolName = _bufferReader.ReadString(); var protocolVersion = _bufferReader.ReadByte(); @@ -203,7 +204,7 @@ MqttPacket DecodeConnectPacket(ArraySegment body) packet.WillRetain = willRetain; packet.WillTopic = _bufferReader.ReadString(); - packet.WillMessage = _bufferReader.ReadBinaryData(); + packet.WillMessage = _bufferReader.ReadBinaryData().Join(); } if (usernameFlag) @@ -213,18 +214,18 @@ MqttPacket DecodeConnectPacket(ArraySegment body) if (passwordFlag) { - packet.Password = _bufferReader.ReadBinaryData(); + packet.Password = _bufferReader.ReadBinaryData().ToArray(); } ValidateConnectPacket(packet); return packet; } - MqttPacket DecodePubAckPacket(ArraySegment body) + MqttPacket DecodePubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); return new MqttPubAckPacket { @@ -232,11 +233,11 @@ MqttPacket DecodePubAckPacket(ArraySegment body) }; } - MqttPacket DecodePubCompPacket(ArraySegment body) + MqttPacket DecodePubCompPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); return new MqttPubCompPacket { @@ -246,9 +247,9 @@ MqttPacket DecodePubCompPacket(ArraySegment body) MqttPacket DecodePublishPacket(ReceivedMqttPacket receivedMqttPacket) { - ThrowIfBodyIsEmpty(receivedMqttPacket.Body); + ThrowIfBodyIsEmpty(receivedMqttPacket.Body.Length); - _bufferReader.SetBuffer(receivedMqttPacket.Body.Array, receivedMqttPacket.Body.Offset, receivedMqttPacket.Body.Count); + _bufferReader.SetBuffer(receivedMqttPacket.Body); var retain = (receivedMqttPacket.FixedHeader & 0x1) > 0; var qualityOfServiceLevel = (MqttQualityOfServiceLevel)((receivedMqttPacket.FixedHeader >> 1) & 0x3); @@ -273,17 +274,17 @@ MqttPacket DecodePublishPacket(ReceivedMqttPacket receivedMqttPacket) if (!_bufferReader.EndOfStream) { - packet.PayloadSegment = new ArraySegment(_bufferReader.ReadRemainingData()); + packet.Payload = _bufferReader.ReadRemainingData(); } return packet; } - MqttPacket DecodePubRecPacket(ArraySegment body) + MqttPacket DecodePubRecPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); return new MqttPubRecPacket { @@ -291,11 +292,11 @@ MqttPacket DecodePubRecPacket(ArraySegment body) }; } - MqttPacket DecodePubRelPacket(ArraySegment body) + MqttPacket DecodePubRelPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); return new MqttPubRelPacket { @@ -303,16 +304,16 @@ MqttPacket DecodePubRelPacket(ArraySegment body) }; } - MqttPacket DecodeSubAckPacket(ArraySegment body) + MqttPacket DecodeSubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttSubAckPacket { PacketIdentifier = _bufferReader.ReadTwoByteInteger(), - ReasonCodes = new List(_bufferReader.BytesLeft) + ReasonCodes = new List((int)_bufferReader.BytesLeft) }; while (!_bufferReader.EndOfStream) @@ -323,11 +324,11 @@ MqttPacket DecodeSubAckPacket(ArraySegment body) return packet; } - MqttPacket DecodeSubscribePacket(ArraySegment body) + MqttPacket DecodeSubscribePacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttSubscribePacket { @@ -348,11 +349,11 @@ MqttPacket DecodeSubscribePacket(ArraySegment body) return packet; } - MqttPacket DecodeUnsubAckPacket(ArraySegment body) + MqttPacket DecodeUnsubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); return new MqttUnsubAckPacket { @@ -360,11 +361,11 @@ MqttPacket DecodeUnsubAckPacket(ArraySegment body) }; } - MqttPacket DecodeUnsubscribePacket(ArraySegment body) + MqttPacket DecodeUnsubscribePacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttUnsubscribePacket { @@ -454,7 +455,7 @@ byte EncodeConnectPacket(MqttConnectPacket packet, MqttBufferWriter bufferWriter if (packet.WillFlag) { bufferWriter.WriteString(packet.WillTopic); - bufferWriter.WriteBinary(packet.WillMessage); + bufferWriter.WriteBinary(packet.WillMessage.Span); } if (packet.Username != null) @@ -525,7 +526,7 @@ byte EncodeConnectPacketV311(MqttConnectPacket packet, MqttBufferWriter bufferWr if (packet.WillFlag) { bufferWriter.WriteString(packet.WillTopic); - bufferWriter.WriteBinary(packet.WillMessage); + bufferWriter.WriteBinary(packet.WillMessage.Span); } if (packet.Username != null) @@ -783,9 +784,9 @@ static byte EncodeUnsubscribePacket(MqttUnsubscribePacket packet, MqttBufferWrit } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void ThrowIfBodyIsEmpty(ArraySegment body) + static void ThrowIfBodyIsEmpty(long bodyLength) { - if (body.Count == 0) + if (bodyLength == 0) { throw new MqttProtocolViolationException("Data from the body is required but not present."); } diff --git a/Source/MQTTnet/Formatter/V5/MqttV5PacketDecoder.cs b/Source/MQTTnet/Formatter/V5/MqttV5PacketDecoder.cs index e27f129fc..eadefef5b 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV5PacketDecoder.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV5PacketDecoder.cs @@ -3,9 +3,12 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; +using System.Runtime.CompilerServices; using MQTTnet.Adapter; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -66,9 +69,9 @@ public MqttPacket Decode(ReceivedMqttPacket receivedMqttPacket) } } - MqttPacket DecodeAuthPacket(ArraySegment body) + MqttPacket DecodeAuthPacket(ReadOnlySequence body) { - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttAuthPacket(); @@ -108,11 +111,11 @@ MqttPacket DecodeAuthPacket(ArraySegment body) return packet; } - MqttPacket DecodeConnAckPacket(ArraySegment body) + MqttPacket DecodeConnAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var acknowledgeFlags = _bufferReader.ReadByte(); @@ -211,11 +214,11 @@ MqttPacket DecodeConnAckPacket(ArraySegment body) return packet; } - MqttPacket DecodeConnectPacket(ArraySegment body) + MqttPacket DecodeConnectPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttConnectPacket { @@ -333,7 +336,7 @@ MqttPacket DecodeConnectPacket(ArraySegment body) } packet.WillTopic = _bufferReader.ReadString(); - packet.WillMessage = _bufferReader.ReadBinaryData(); + packet.WillMessage = _bufferReader.ReadBinaryData().Join(); packet.WillUserProperties = willPropertiesReader.CollectedUserProperties; } @@ -344,18 +347,18 @@ MqttPacket DecodeConnectPacket(ArraySegment body) if (passwordFlag) { - packet.Password = _bufferReader.ReadBinaryData(); + packet.Password = _bufferReader.ReadBinaryData().ToArray(); } return packet; } - MqttPacket DecodeDisconnectPacket(ArraySegment body) + MqttPacket DecodeDisconnectPacket(ReadOnlySequence body) { // From RFC: 3.14.2.1 Disconnect Reason Code // Byte 1 in the Variable Header is the Disconnect Reason Code. // If the Remaining Length is less than 1 the value of 0x00 (Normal disconnection) is used. - if (body.Count == 0) + if (body.Length == 0) { return new MqttDisconnectPacket { @@ -363,7 +366,7 @@ MqttPacket DecodeDisconnectPacket(ArraySegment body) }; } - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttDisconnectPacket { @@ -396,11 +399,11 @@ MqttPacket DecodeDisconnectPacket(ArraySegment body) return packet; } - MqttPacket DecodePubAckPacket(ArraySegment body) + MqttPacket DecodePubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttPubAckPacket { @@ -433,11 +436,11 @@ MqttPacket DecodePubAckPacket(ArraySegment body) return packet; } - MqttPacket DecodePubCompPacket(ArraySegment body) + MqttPacket DecodePubCompPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttPubCompPacket { @@ -471,11 +474,11 @@ MqttPacket DecodePubCompPacket(ArraySegment body) } - MqttPacket DecodePublishPacket(byte header, ArraySegment body) + MqttPacket DecodePublishPacket(byte header, ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var retain = (header & 1) > 0; var qos = (MqttQualityOfServiceLevel)((header >> 1) & 3); @@ -540,17 +543,17 @@ MqttPacket DecodePublishPacket(byte header, ArraySegment body) if (!_bufferReader.EndOfStream) { - packet.PayloadSegment = new ArraySegment(_bufferReader.ReadRemainingData()); + packet.Payload = _bufferReader.ReadRemainingData(); } return packet; } - MqttPacket DecodePubRecPacket(ArraySegment body) + MqttPacket DecodePubRecPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttPubRecPacket { @@ -583,11 +586,11 @@ MqttPacket DecodePubRecPacket(ArraySegment body) return packet; } - MqttPacket DecodePubRelPacket(ArraySegment body) + MqttPacket DecodePubRelPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttPubRelPacket { @@ -620,11 +623,11 @@ MqttPacket DecodePubRelPacket(ArraySegment body) return packet; } - MqttPacket DecodeSubAckPacket(ArraySegment body) + MqttPacket DecodeSubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttSubAckPacket { @@ -646,7 +649,7 @@ MqttPacket DecodeSubAckPacket(ArraySegment body) packet.UserProperties = propertiesReader.CollectedUserProperties; - packet.ReasonCodes = new List(_bufferReader.BytesLeft); + packet.ReasonCodes = new List((int)_bufferReader.BytesLeft); while (!_bufferReader.EndOfStream) { var reasonCode = (MqttSubscribeReasonCode)_bufferReader.ReadByte(); @@ -656,11 +659,11 @@ MqttPacket DecodeSubAckPacket(ArraySegment body) return packet; } - MqttPacket DecodeSubscribePacket(ArraySegment body) + MqttPacket DecodeSubscribePacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttSubscribePacket { @@ -706,11 +709,11 @@ MqttPacket DecodeSubscribePacket(ArraySegment body) return packet; } - MqttPacket DecodeUnsubAckPacket(ArraySegment body) + MqttPacket DecodeUnsubAckPacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttUnsubAckPacket { @@ -732,7 +735,7 @@ MqttPacket DecodeUnsubAckPacket(ArraySegment body) packet.UserProperties = propertiesReader.CollectedUserProperties; - packet.ReasonCodes = new List(_bufferReader.BytesLeft); + packet.ReasonCodes = new List((int)_bufferReader.BytesLeft); while (!_bufferReader.EndOfStream) { @@ -743,11 +746,11 @@ MqttPacket DecodeUnsubAckPacket(ArraySegment body) return packet; } - MqttPacket DecodeUnsubscribePacket(ArraySegment body) + MqttPacket DecodeUnsubscribePacket(ReadOnlySequence body) { - ThrowIfBodyIsEmpty(body); + ThrowIfBodyIsEmpty(body.Length); - _bufferReader.SetBuffer(body.Array, body.Offset, body.Count); + _bufferReader.SetBuffer(body); var packet = new MqttUnsubscribePacket { @@ -771,9 +774,10 @@ MqttPacket DecodeUnsubscribePacket(ArraySegment body) } // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local - static void ThrowIfBodyIsEmpty(ArraySegment body) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void ThrowIfBodyIsEmpty(long bodyLength) { - if (body.Count == 0) + if (bodyLength == 0) { throw new MqttProtocolViolationException("Data from the body is required but not present."); } diff --git a/Source/MQTTnet/Formatter/V5/MqttV5PacketEncoder.cs b/Source/MQTTnet/Formatter/V5/MqttV5PacketEncoder.cs index 9d5a6d09b..25518958c 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV5PacketEncoder.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV5PacketEncoder.cs @@ -56,16 +56,17 @@ public MqttPacketBuffer Encode(MqttPacket packet) _bufferWriter.WriteByte(fixedHeader); _bufferWriter.WriteVariableByteInteger(remainingLength); - var buffer = _bufferWriter.GetBuffer(); - var firstSegment = new ArraySegment(buffer, headerOffset, _bufferWriter.Length - headerOffset); + var firstSegment = _bufferWriter.GetWrittenMemory()[headerOffset..]; - return publishPacket == null ? new MqttPacketBuffer(firstSegment) : new MqttPacketBuffer(firstSegment, publishPacket.Payload); + return publishPacket == null + ? new MqttPacketBuffer(firstSegment) + : new MqttPacketBuffer(firstSegment, publishPacket.Payload); } byte EncodeAuthPacket(MqttAuthPacket packet) { _propertiesWriter.WriteAuthenticationMethod(packet.AuthenticationMethod); - _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData); + _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData.Span); _propertiesWriter.WriteReasonString(packet.ReasonString); _propertiesWriter.WriteUserProperties(packet.UserProperties); @@ -97,7 +98,7 @@ byte EncodeConnAckPacket(MqttConnAckPacket packet) _propertiesWriter.WriteSessionExpiryInterval(packet.SessionExpiryInterval); _propertiesWriter.WriteAuthenticationMethod(packet.AuthenticationMethod); - _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData); + _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData.Span); _propertiesWriter.WriteRetainAvailable(packet.RetainAvailable); _propertiesWriter.WriteReceiveMaximum(packet.ReceiveMaximum); _propertiesWriter.WriteMaximumQoS(packet.MaximumQoS); @@ -166,7 +167,7 @@ byte EncodeConnectPacket(MqttConnectPacket packet) _propertiesWriter.WriteSessionExpiryInterval(packet.SessionExpiryInterval); _propertiesWriter.WriteAuthenticationMethod(packet.AuthenticationMethod); - _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData); + _propertiesWriter.WriteAuthenticationData(packet.AuthenticationData.Span); _propertiesWriter.WriteRequestProblemInformation(packet.RequestProblemInformation); _propertiesWriter.WriteRequestResponseInformation(packet.RequestResponseInformation); _propertiesWriter.WriteReceiveMaximum(packet.ReceiveMaximum); @@ -184,7 +185,7 @@ byte EncodeConnectPacket(MqttConnectPacket packet) _propertiesWriter.WritePayloadFormatIndicator(packet.WillPayloadFormatIndicator); _propertiesWriter.WriteMessageExpiryInterval(packet.WillMessageExpiryInterval); _propertiesWriter.WriteResponseTopic(packet.WillResponseTopic); - _propertiesWriter.WriteCorrelationData(packet.WillCorrelationData); + _propertiesWriter.WriteCorrelationData(packet.WillCorrelationData.Span); _propertiesWriter.WriteContentType(packet.WillContentType); _propertiesWriter.WriteUserProperties(packet.WillUserProperties); _propertiesWriter.WriteWillDelayInterval(packet.WillDelayInterval); @@ -193,7 +194,7 @@ byte EncodeConnectPacket(MqttConnectPacket packet) _propertiesWriter.Reset(); _bufferWriter.WriteString(packet.WillTopic); - _bufferWriter.WriteBinary(packet.WillMessage); + _bufferWriter.WriteBinary(packet.WillMessage.Span); } if (packet.Username != null) @@ -342,7 +343,7 @@ byte EncodePublishPacket(MqttPublishPacket packet) } _propertiesWriter.WriteContentType(packet.ContentType); - _propertiesWriter.WriteCorrelationData(packet.CorrelationData); + _propertiesWriter.WriteCorrelationData(packet.CorrelationData.Span); _propertiesWriter.WriteMessageExpiryInterval(packet.MessageExpiryInterval); _propertiesWriter.WritePayloadFormatIndicator(packet.PayloadFormatIndicator); _propertiesWriter.WriteResponseTopic(packet.ResponseTopic); diff --git a/Source/MQTTnet/Formatter/V5/MqttV5PropertiesReader.cs b/Source/MQTTnet/Formatter/V5/MqttV5PropertiesReader.cs index e5a3efe7e..75bf57807 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV5PropertiesReader.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV5PropertiesReader.cs @@ -3,8 +3,10 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using MQTTnet.Exceptions; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -29,7 +31,7 @@ public MqttV5PropertiesReader(MqttBufferReader body) _length = 0; } - _targetOffset = body.Position + _length; + _targetOffset = (int)(body.Position + _length); CollectedUserProperties = null; CurrentPropertyId = MqttPropertyId.None; @@ -81,9 +83,9 @@ public string ReadAssignedClientIdentifier() return _body.ReadString(); } - public byte[] ReadAuthenticationData() + public ReadOnlyMemory ReadAuthenticationData() { - return _body.ReadBinaryData(); + return _body.ReadBinaryData().Join(); } public string ReadAuthenticationMethod() @@ -96,9 +98,9 @@ public string ReadContentType() return _body.ReadString(); } - public byte[] ReadCorrelationData() + public ReadOnlyMemory ReadCorrelationData() { - return _body.ReadBinaryData(); + return _body.ReadBinaryData().Join(); } public uint ReadMaximumPacketSize() diff --git a/Source/MQTTnet/Formatter/V5/MqttV5PropertiesWriter.cs b/Source/MQTTnet/Formatter/V5/MqttV5PropertiesWriter.cs index 25546d332..992455499 100644 --- a/Source/MQTTnet/Formatter/V5/MqttV5PropertiesWriter.cs +++ b/Source/MQTTnet/Formatter/V5/MqttV5PropertiesWriter.cs @@ -2,10 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; using MQTTnet.Packets; using MQTTnet.Protocol; +using System; +using System.Buffers.Binary; +using System.Collections.Generic; namespace MQTTnet.Formatter.V5 { @@ -31,7 +32,7 @@ public void WriteAssignedClientIdentifier(string value) Write(MqttPropertyId.AssignedClientIdentifier, value); } - public void WriteAuthenticationData(byte[] value) + public void WriteAuthenticationData(ReadOnlySpan value) { Write(MqttPropertyId.AuthenticationData, value); } @@ -46,7 +47,7 @@ public void WriteContentType(string value) Write(MqttPropertyId.ContentType, value); } - public void WriteCorrelationData(byte[] value) + public void WriteCorrelationData(ReadOnlySpan value) { Write(MqttPropertyId.CorrelationData, value); } @@ -88,7 +89,7 @@ public void WriteMessageExpiryInterval(uint value) { return; } - + WriteAsFourByteInteger(MqttPropertyId.MessageExpiryInterval, value); } @@ -297,20 +298,31 @@ public void WriteWillDelayInterval(uint value) void Write(MqttPropertyId id, bool value) { - _bufferWriter.WriteByte((byte)id); - _bufferWriter.WriteByte(value ? (byte)0x1 : (byte)0x0); + Write(id, value ? (byte)0x1 : default); } void Write(MqttPropertyId id, byte value) { - _bufferWriter.WriteByte((byte)id); - _bufferWriter.WriteByte(value); + const int size = 2; + var bufferWriter = _bufferWriter.AsLowLevelBufferWriter(); + var span = bufferWriter.GetSpan(size); + + span[0] = (byte)id; + span[1] = value; + + bufferWriter.Advance(size); } void Write(MqttPropertyId id, ushort value) { - _bufferWriter.WriteByte((byte)id); - _bufferWriter.WriteTwoByteInteger(value); + const int size = 3; + var bufferWriter = _bufferWriter.AsLowLevelBufferWriter(); + var span = bufferWriter.GetSpan(size); + + span[0] = (byte)id; + BinaryPrimitives.WriteUInt16BigEndian(span[1..], value); + + bufferWriter.Advance(size); } void Write(MqttPropertyId id, string value) @@ -324,9 +336,9 @@ void Write(MqttPropertyId id, string value) _bufferWriter.WriteString(value); } - void Write(MqttPropertyId id, byte[] value) + void Write(MqttPropertyId id, ReadOnlySpan value) { - if (value == null) + if (value.IsEmpty) { return; } @@ -337,11 +349,14 @@ void Write(MqttPropertyId id, byte[] value) void WriteAsFourByteInteger(MqttPropertyId id, uint value) { - _bufferWriter.WriteByte((byte)id); - _bufferWriter.WriteByte((byte)(value >> 24)); - _bufferWriter.WriteByte((byte)(value >> 16)); - _bufferWriter.WriteByte((byte)(value >> 8)); - _bufferWriter.WriteByte((byte)value); + const int size = 5; + var bufferWriter = _bufferWriter.AsLowLevelBufferWriter(); + var span = bufferWriter.GetSpan(size); + + span[0] = (byte)id; + BinaryPrimitives.WriteUInt32BigEndian(span[1..], value); + + bufferWriter.Advance(size); } void WriteAsVariableByteInteger(MqttPropertyId id, uint value) diff --git a/Source/MQTTnet/Internal/EmptyBuffer.cs b/Source/MQTTnet/Internal/EmptyBuffer.cs deleted file mode 100644 index eb0506e94..000000000 --- a/Source/MQTTnet/Internal/EmptyBuffer.cs +++ /dev/null @@ -1,18 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Buffers; - -namespace MQTTnet.Internal -{ - public static class EmptyBuffer - { - public static readonly byte[] Array = System.Array.Empty(); - - public static readonly ArraySegment ArraySegment = new ArraySegment(Array, 0, 0); - - public static readonly ReadOnlySequence ReadOnlySequence = ReadOnlySequence.Empty; - } -} \ No newline at end of file diff --git a/Source/MQTTnet/Internal/MqttMemoryHelper.cs b/Source/MQTTnet/Internal/MqttMemoryHelper.cs index 8f3b10083..4aa30c27f 100644 --- a/Source/MQTTnet/Internal/MqttMemoryHelper.cs +++ b/Source/MQTTnet/Internal/MqttMemoryHelper.cs @@ -6,6 +6,21 @@ namespace MQTTnet.Internal { public static class MqttMemoryHelper { + public static ReadOnlyMemory Join(this ReadOnlySequence buffer) + { + if (buffer.IsEmpty) + { + return default; + } + + if (buffer.IsSingleSegment) + { + return buffer.First; + } + + return buffer.ToArray(); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void Copy(byte[] source, int sourceIndex, byte[] destination, int destinationIndex, int length) { diff --git a/Source/MQTTnet/Internal/MqttPayloadOwner.cs b/Source/MQTTnet/Internal/MqttPayloadOwner.cs new file mode 100644 index 000000000..4c27ecb7d --- /dev/null +++ b/Source/MQTTnet/Internal/MqttPayloadOwner.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public abstract class MqttPayloadOwner : IAsyncDisposable + { + private bool _disposed = false; + public static MqttPayloadOwner Empty { get; } = new EmptyPayloadOwner(); + + public abstract ReadOnlySequence Payload { get; } + + public ValueTask DisposeAsync() + { + if (!_disposed) + { + _disposed = true; + return DisposeAsync(true); + } + return ValueTask.CompletedTask; + } + + protected virtual ValueTask DisposeAsync(bool disposing) + { + return ValueTask.CompletedTask; + } + + + private sealed class EmptyPayloadOwner : MqttPayloadOwner + { + public override ReadOnlySequence Payload => ReadOnlySequence.Empty; + protected override ValueTask DisposeAsync(bool disposing) + { + return ValueTask.CompletedTask; + } + } + } +} diff --git a/Source/MQTTnet/Internal/MqttPayloadOwnerFactory.cs b/Source/MQTTnet/Internal/MqttPayloadOwnerFactory.cs new file mode 100644 index 000000000..105527620 --- /dev/null +++ b/Source/MQTTnet/Internal/MqttPayloadOwnerFactory.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public static class MqttPayloadOwnerFactory + { + /// + /// Create owner for a single segment payload + /// + /// + /// + /// + public static MqttPayloadOwner CreateSingleSegment(int payloadSize, Action> payloadFactory) + { + ArgumentNullException.ThrowIfNull(payloadFactory); + return SingleSegmentPayloadOwner.Create(payloadSize, payloadFactory); + } + + /// + /// Create owner for a multiple segments payload + /// + /// + /// + /// + public static ValueTask CreateMultipleSegmentAsync(Func payloadFactory, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payloadFactory); + return MultipleSegmentPayloadOwner.CreateAsync(payloadFactory, cancellationToken); + } + + private sealed class SingleSegmentPayloadOwner : MqttPayloadOwner + { + private readonly byte[] _buffer; + public override ReadOnlySequence Payload { get; } + + private SingleSegmentPayloadOwner(byte[] buffer, ReadOnlyMemory payload) + { + _buffer = buffer; + Payload = new ReadOnlySequence(payload); + } + + public static MqttPayloadOwner Create(int payloadSize, Action> payloadFactory) + { + byte[] buffer; + Memory payload; + + if (payloadSize <= 0) + { + buffer = Array.Empty(); + payload = Memory.Empty; + } + else + { + buffer = ArrayPool.Shared.Rent(payloadSize); + payload = buffer.AsMemory(0, payloadSize); + } + + payloadFactory.Invoke(payload); + return new SingleSegmentPayloadOwner(buffer, payload); + } + + protected override ValueTask DisposeAsync(bool disposing) + { + if (_buffer.Length > 0) + { + ArrayPool.Shared.Return(_buffer); + } + return ValueTask.CompletedTask; + } + } + + private sealed class MultipleSegmentPayloadOwner : MqttPayloadOwner + { + private readonly Pipe _pipe; + private static readonly ConcurrentQueue _pipeQueue = new(); + public override ReadOnlySequence Payload { get; } + + private MultipleSegmentPayloadOwner(Pipe pipe, ReadOnlySequence payload) + { + _pipe = pipe; + Payload = payload; + } + + public static async ValueTask CreateAsync(Func payloadFactory, CancellationToken cancellationToken) + { + if (!_pipeQueue.TryDequeue(out var pipe)) + { + pipe = new Pipe(); + } + + await payloadFactory.Invoke(pipe.Writer).ConfigureAwait(false); + await pipe.Writer.CompleteAsync().ConfigureAwait(false); + + var payload = await ReadPayloadAsync(pipe.Reader, cancellationToken).ConfigureAwait(false); + return new MultipleSegmentPayloadOwner(pipe, payload); + } + + protected override async ValueTask DisposeAsync(bool disposing) + { + await _pipe.Reader.CompleteAsync().ConfigureAwait(false); + _pipe.Reset(); + _pipeQueue.Enqueue(_pipe); + } + + private static async ValueTask> ReadPayloadAsync( + PipeReader pipeReader, + CancellationToken cancellationToken) + { + var readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); + while (!readResult.IsCompleted) + { + pipeReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.Start); + readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false); + } + + return readResult.Buffer; + } + } + } +} diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index ece6812a0..392d43174 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -2,7 +2,7 @@ - + @(ReleaseNotes, '%0a') @@ -35,7 +35,7 @@ false nuget.png true - true + MIT true 1591;NETSDK1138;NU1803;NU1901;NU1902 @@ -44,7 +44,7 @@ all true low - latest-Recommended + @@ -62,7 +62,8 @@ - + + \ No newline at end of file diff --git a/Source/MQTTnet/MqttApplicationMessage.cs b/Source/MQTTnet/MqttApplicationMessage.cs index eb402e155..e9166d1f8 100644 --- a/Source/MQTTnet/MqttApplicationMessage.cs +++ b/Source/MQTTnet/MqttApplicationMessage.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; using System; @@ -26,7 +25,7 @@ public sealed class MqttApplicationMessage /// published message. /// Hint: MQTT 5 feature only. /// - public byte[] CorrelationData { get; set; } + public ReadOnlyMemory CorrelationData { get; set; } /// /// If the DUP flag is set to 0, it indicates that this is the first occasion that the Client or Server has attempted @@ -53,7 +52,7 @@ public sealed class MqttApplicationMessage /// /// Set an ArraySegment as Payload. /// - public ArraySegment PayloadSegment + public ReadOnlyMemory PayloadSegment { set { Payload = new ReadOnlySequence(value); } } @@ -64,7 +63,7 @@ public ArraySegment PayloadSegment /// It can be used in combination with a RecyclableMemoryStream to publish /// large buffered messages without allocating large chunks of memory. /// - public ReadOnlySequence Payload { get; set; } = EmptyBuffer.ReadOnlySequence; + public ReadOnlySequence Payload { get; set; } /// /// Gets or sets the payload format indicator. diff --git a/Source/MQTTnet/MqttApplicationMessageBuilder.cs b/Source/MQTTnet/MqttApplicationMessageBuilder.cs index 2c5c2dd70..c1ac16cf2 100644 --- a/Source/MQTTnet/MqttApplicationMessageBuilder.cs +++ b/Source/MQTTnet/MqttApplicationMessageBuilder.cs @@ -2,24 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Packets; +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Runtime.InteropServices; using System.Text; -using MQTTnet.Exceptions; -using MQTTnet.Internal; -using MQTTnet.Packets; -using MQTTnet.Protocol; namespace MQTTnet { public sealed class MqttApplicationMessageBuilder { string _contentType; - byte[] _correlationData; + ReadOnlyMemory _correlationData; uint _messageExpiryInterval; MqttPayloadFormatIndicator _payloadFormatIndicator; @@ -72,7 +69,7 @@ public MqttApplicationMessageBuilder WithContentType(string contentType) /// Adds the correlation data to the message. /// MQTT 5.0.0+ feature. /// - public MqttApplicationMessageBuilder WithCorrelationData(byte[] correlationData) + public MqttApplicationMessageBuilder WithCorrelationData(ReadOnlyMemory correlationData) { _correlationData = correlationData; return this; @@ -88,48 +85,55 @@ public MqttApplicationMessageBuilder WithMessageExpiryInterval(uint messageExpir return this; } - public MqttApplicationMessageBuilder WithPayload(byte[] payload) + public MqttApplicationMessageBuilder WithPayload(ReadOnlySequence payload) { - _payload = payload == null || payload.Length == 0 ? EmptyBuffer.ReadOnlySequence : new ReadOnlySequence(payload); + _payload = payload; return this; } - public MqttApplicationMessageBuilder WithPayload(ArraySegment payloadSegment) + public MqttApplicationMessageBuilder WithPayload(ReadOnlyMemory payload) { - _payload = new ReadOnlySequence(payloadSegment); - return this; + return WithPayload(new ReadOnlySequence(payload)); } - public MqttApplicationMessageBuilder WithPayload(IEnumerable payload) + /// + /// This method causes memory allocation when transcoding the string payload. + /// * Use the method IMqttClient.PublishStringAsync() instead in client side. + /// * Use the method MqttServer.InjectStringAsync() instead in server side. + /// + /// + /// + /// + public MqttApplicationMessageBuilder WithPayload(string payload) { - if (payload == null) - { - return WithPayload(default(byte[])); - } - - if (payload is byte[] byteArray) - { - return WithPayload(byteArray); - } - - if (payload is ArraySegment arraySegment) - { - return WithPayloadSegment(arraySegment); - } - - return WithPayload(payload.ToArray()); + return string.IsNullOrEmpty(payload) ? this : WithPayload(Encoding.UTF8.GetBytes(payload)); } + /// + /// This method causes memory allocation when transcoding the stream payload. + /// * Use the method IMqttClient.PublishStreamAsync() instead in client side. + /// * Use the method MqttServer.InjectStreamAsync() instead in server side. + /// + /// + /// public MqttApplicationMessageBuilder WithPayload(Stream payload) { - return payload == null ? WithPayload(default(byte[])) : WithPayload(payload, payload.Length - payload.Position); + return payload == null ? this : WithPayload(payload, payload.Length - payload.Position); } + /// + /// This method causes memory allocation when transcoding the stream payload. + /// * Use the method IMqttClient.PublishStreamAsync() instead in client side. + /// * Use the method MqttServer.InjectStreamAsync() instead in server side. + /// + /// + /// + /// public MqttApplicationMessageBuilder WithPayload(Stream payload, long length) - { + { if (payload == null || length == 0) { - return WithPayload(default(byte[])); + return this; } var payloadBuffer = new byte[length]; @@ -148,22 +152,6 @@ public MqttApplicationMessageBuilder WithPayload(Stream payload, long length) return WithPayload(payloadBuffer); } - public MqttApplicationMessageBuilder WithPayload(string payload) - { - if (string.IsNullOrEmpty(payload)) - { - return WithPayload(default(byte[])); - } - - var payloadBuffer = Encoding.UTF8.GetBytes(payload); - return WithPayload(payloadBuffer); - } - - public MqttApplicationMessageBuilder WithPayload(ReadOnlySequence payload) - { - _payload = payload; - return this; - } /// /// Adds the payload format indicator to the message. @@ -175,16 +163,6 @@ public MqttApplicationMessageBuilder WithPayloadFormatIndicator(MqttPayloadForma return this; } - public MqttApplicationMessageBuilder WithPayloadSegment(ArraySegment payloadSegment) - { - _payload = new ReadOnlySequence(payloadSegment); - return this; - } - - public MqttApplicationMessageBuilder WithPayloadSegment(ReadOnlyMemory payloadSegment) - { - return MemoryMarshal.TryGetArray(payloadSegment, out var segment) ? WithPayloadSegment(segment) : WithPayload(payloadSegment.ToArray()); - } /// /// The quality of service level. @@ -266,11 +244,7 @@ public MqttApplicationMessageBuilder WithTopicAlias(ushort topicAlias) /// public MqttApplicationMessageBuilder WithUserProperty(string name, string value) { - if (_userProperties == null) - { - _userProperties = new List(); - } - + _userProperties ??= new List(); _userProperties.Add(new MqttUserProperty(name, value)); return this; } diff --git a/Source/MQTTnet/MqttApplicationMessageExtensions.cs b/Source/MQTTnet/MqttApplicationMessageExtensions.cs index eef59cb54..a5392aaac 100644 --- a/Source/MQTTnet/MqttApplicationMessageExtensions.cs +++ b/Source/MQTTnet/MqttApplicationMessageExtensions.cs @@ -4,6 +4,8 @@ using System; using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; namespace MQTTnet; @@ -13,11 +15,38 @@ public static string ConvertPayloadToString(this MqttApplicationMessage applicat { ArgumentNullException.ThrowIfNull(applicationMessage); - if (applicationMessage.Payload.Length == 0) - { - return null; - } + return applicationMessage.Payload.IsEmpty + ? null + : Encoding.UTF8.GetString(applicationMessage.Payload); + } + + public static TValue ConvertPayloadToJson(this MqttApplicationMessage applicationMessage, JsonTypeInfo jsonTypeInfo) + { + ArgumentNullException.ThrowIfNull(applicationMessage); + ArgumentNullException.ThrowIfNull(jsonTypeInfo); + + var readerOptions = CreateJsonReaderOptions(jsonTypeInfo.Options); + var jsonReader = new Utf8JsonReader(applicationMessage.Payload, readerOptions); + return JsonSerializer.Deserialize(ref jsonReader, jsonTypeInfo); + } - return Encoding.UTF8.GetString(applicationMessage.Payload); + public static TValue ConvertPayloadToJson(this MqttApplicationMessage applicationMessage, JsonSerializerOptions jsonSerializerOptions = null) + { + ArgumentNullException.ThrowIfNull(applicationMessage); + + var readerOptions = CreateJsonReaderOptions(jsonSerializerOptions); + var jsonReader = new Utf8JsonReader(applicationMessage.Payload, readerOptions); + return JsonSerializer.Deserialize(ref jsonReader, jsonSerializerOptions); + } + + private static JsonReaderOptions CreateJsonReaderOptions(JsonSerializerOptions jsonSerializerOptions) + { + jsonSerializerOptions ??= JsonSerializerOptions.Default; + return new JsonReaderOptions + { + MaxDepth = jsonSerializerOptions.MaxDepth, + AllowTrailingCommas = jsonSerializerOptions.AllowTrailingCommas, + CommentHandling = jsonSerializerOptions.ReadCommentHandling + }; } } \ No newline at end of file diff --git a/Source/MQTTnet/MqttApplicationMessageValidator.cs b/Source/MQTTnet/MqttApplicationMessageValidator.cs index d7eb3b611..2b7ad87bb 100644 --- a/Source/MQTTnet/MqttApplicationMessageValidator.cs +++ b/Source/MQTTnet/MqttApplicationMessageValidator.cs @@ -31,7 +31,7 @@ public static void ThrowIfNotSupported(MqttApplicationMessage applicationMessage Throw(nameof(applicationMessage.UserProperties)); } - if (applicationMessage.CorrelationData?.Any() == true) + if (applicationMessage.CorrelationData.Length > 0) { Throw(nameof(applicationMessage.CorrelationData)); } diff --git a/Source/MQTTnet/MqttClientExtensions.cs b/Source/MQTTnet/MqttClientExtensions.cs index f6d09dd08..b7486c29b 100644 --- a/Source/MQTTnet/MqttClientExtensions.cs +++ b/Source/MQTTnet/MqttClientExtensions.cs @@ -2,14 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Internal; +using MQTTnet.Packets; +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Packets; -using MQTTnet.Protocol; namespace MQTTnet; @@ -36,10 +41,10 @@ public static Task DisconnectAsync( return client.DisconnectAsync(disconnectOptions, cancellationToken); } - public static Task PublishBinaryAsync( + public static Task PublishSequenceAsync( this IMqttClient mqttClient, string topic, - IEnumerable payload = null, + ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, bool retain = false, CancellationToken cancellationToken = default) @@ -47,53 +52,134 @@ public static Task PublishBinaryAsync( ArgumentNullException.ThrowIfNull(mqttClient); ArgumentNullException.ThrowIfNull(topic); - var applicationMessage = new MqttApplicationMessageBuilder().WithTopic(topic) - .WithPayload(payload) - .WithRetainFlag(retain) - .WithQualityOfServiceLevel(qualityOfServiceLevel) - .Build(); + var applicationMessage = new MqttApplicationMessage + { + Topic = topic, + Payload = payload, + Retain = retain, + QualityOfServiceLevel = qualityOfServiceLevel + }; return mqttClient.PublishAsync(applicationMessage, cancellationToken); } - public static Task PublishSequenceAsync( + public static async Task PublishSequenceAsync( this IMqttClient mqttClient, string topic, - ReadOnlySequence payload, + Func payloadFactory, MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, bool retain = false, CancellationToken cancellationToken = default) { - ArgumentNullException.ThrowIfNull(mqttClient); - ArgumentNullException.ThrowIfNull(topic); + ArgumentNullException.ThrowIfNull(payloadFactory); - var applicationMessage = new MqttApplicationMessageBuilder().WithTopic(topic) - .WithPayload(payload) - .WithRetainFlag(retain) - .WithQualityOfServiceLevel(qualityOfServiceLevel) - .Build(); + await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(payloadFactory, cancellationToken).ConfigureAwait(false); + return await mqttClient.PublishSequenceAsync(topic, payloadOwner.Payload, qualityOfServiceLevel, retain, cancellationToken).ConfigureAwait(false); + } - return mqttClient.PublishAsync(applicationMessage, cancellationToken); + public static Task PublishBinaryAsync( + this IMqttClient mqttClient, + string topic, + ReadOnlyMemory payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + return mqttClient.PublishSequenceAsync(topic, new ReadOnlySequence(payload), qualityOfServiceLevel, retain, cancellationToken); + } + + public static async Task PublishBinaryAsync( + this IMqttClient mqttClient, + string topic, + int payloadSize, + Action> payloadFactory, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + await using var payloadOwner = MqttPayloadOwnerFactory.CreateSingleSegment(payloadSize, payloadFactory); + return await mqttClient.PublishSequenceAsync(topic, payloadOwner.Payload, qualityOfServiceLevel, retain, cancellationToken).ConfigureAwait(false); } public static Task PublishStringAsync( this IMqttClient mqttClient, string topic, - string payload = null, + string payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + return string.IsNullOrEmpty(payload) + ? mqttClient.PublishSequenceAsync(topic, ReadOnlySequence.Empty, qualityOfServiceLevel, retain, cancellationToken) + : mqttClient.PublishSequenceAsync(topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + Encoding.UTF8.GetBytes(payload, writer); + await writer.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + + public static Task PublishJsonAsync( + this IMqttClient mqttClient, + string topic, + TValue payload, + JsonSerializerOptions jsonSerializerOptions = default, MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, bool retain = false, CancellationToken cancellationToken = default) { - var payloadBuffer = Encoding.UTF8.GetBytes(payload ?? string.Empty); - return mqttClient.PublishBinaryAsync(topic, payloadBuffer, qualityOfServiceLevel, retain, cancellationToken); + return mqttClient.PublishSequenceAsync(topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + var stream = writer.AsStream(leaveOpen: true); + await JsonSerializer.SerializeAsync(stream, payload, jsonSerializerOptions, cancellationToken).ConfigureAwait(false); + } } + public static Task PublishJsonAsync( + this IMqttClient mqttClient, + string topic, + TValue payload, + JsonTypeInfo jsonTypeInfo, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(jsonTypeInfo); + return mqttClient.PublishSequenceAsync(topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + var stream = writer.AsStream(leaveOpen: true); + await JsonSerializer.SerializeAsync(stream, payload, jsonTypeInfo, cancellationToken).ConfigureAwait(false); + } + } + + public static Task PublishStreamAsync( + this IMqttClient mqttClient, + string topic, + Stream payload, + MqttQualityOfServiceLevel qualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce, + bool retain = false, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payload); + return mqttClient.PublishSequenceAsync(topic, WritePayloadAsync, qualityOfServiceLevel, retain, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + await payload.CopyToAsync(writer, cancellationToken).ConfigureAwait(false); + } + } + + public static Task ReconnectAsync(this IMqttClient client, CancellationToken cancellationToken = default) { if (client.Options == null) { - throw new InvalidOperationException( - "The MQTT client was not connected before. A reconnect is only permitted when the client was already connected or at least tried to."); + throw new InvalidOperationException("The MQTT client was not connected before. A reconnect is only permitted when the client was already connected or at least tried to."); } return client.ConnectAsync(client.Options, cancellationToken); diff --git a/Source/MQTTnet/Options/IMqttClientCredentialsProvider.cs b/Source/MQTTnet/Options/IMqttClientCredentialsProvider.cs index 394388d0c..adc0dfd9b 100644 --- a/Source/MQTTnet/Options/IMqttClientCredentialsProvider.cs +++ b/Source/MQTTnet/Options/IMqttClientCredentialsProvider.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; + namespace MQTTnet; public interface IMqttClientCredentialsProvider diff --git a/Source/MQTTnet/Options/MqttClientCredentials.cs b/Source/MQTTnet/Options/MqttClientCredentials.cs index 0de96c0fd..1e5d54e2f 100644 --- a/Source/MQTTnet/Options/MqttClientCredentials.cs +++ b/Source/MQTTnet/Options/MqttClientCredentials.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; + namespace MQTTnet; public sealed class MqttClientCredentials : IMqttClientCredentialsProvider @@ -9,7 +11,7 @@ public sealed class MqttClientCredentials : IMqttClientCredentialsProvider readonly byte[] _password; readonly string _userName; - public MqttClientCredentials(string userName, byte[] password = null) + public MqttClientCredentials(string userName, byte[] password ) { _userName = userName; _password = password; diff --git a/Source/MQTTnet/Options/MqttClientOptions.cs b/Source/MQTTnet/Options/MqttClientOptions.cs index d17dd0548..14a03b7a5 100644 --- a/Source/MQTTnet/Options/MqttClientOptions.cs +++ b/Source/MQTTnet/Options/MqttClientOptions.cs @@ -24,7 +24,7 @@ public sealed class MqttClientOptions /// Gets or sets the authentication data. /// MQTT 5.0.0+ feature. /// - public byte[] AuthenticationData { get; set; } + public ReadOnlyMemory AuthenticationData { get; set; } /// /// Gets or sets the authentication method. @@ -159,7 +159,7 @@ public sealed class MqttClientOptions /// Gets or sets the correlation data of the will message. /// MQTT 5.0.0+ feature. /// - public byte[] WillCorrelationData { get; set; } + public ReadOnlyMemory WillCorrelationData { get; set; } /// /// Gets or sets the will delay interval. @@ -177,7 +177,7 @@ public sealed class MqttClientOptions /// /// Gets or sets the payload of the will message. /// - public byte[] WillPayload { get; set; } + public ReadOnlyMemory WillPayload { get; set; } /// /// Gets or sets the payload format indicator of the will message. diff --git a/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs index b1e14a797..33d24c6b0 100644 --- a/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Options/MqttClientOptionsBuilder.cs @@ -98,7 +98,7 @@ public MqttClientOptionsBuilder WithAddressFamily(AddressFamily addressFamily) return this; } - public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data) + public MqttClientOptionsBuilder WithAuthentication(string method, ReadOnlyMemory data) { _options.AuthenticationMethod = method; _options.AuthenticationData = data; @@ -191,17 +191,12 @@ public MqttClientOptionsBuilder WithConnectionUri(string uri) public MqttClientOptionsBuilder WithCredentials(string username, string password) { - byte[] passwordBuffer = null; - - if (password != null) - { - passwordBuffer = Encoding.UTF8.GetBytes(password); - } - - return WithCredentials(username, passwordBuffer); + return password == null + ? WithCredentials(username, default(byte[])) + : WithCredentials(username, Encoding.UTF8.GetBytes(password)); } - public MqttClientOptionsBuilder WithCredentials(string username, byte[] password = null) + public MqttClientOptionsBuilder WithCredentials(string username, byte[] password = default) { return WithCredentials(new MqttClientCredentials(username, password)); } @@ -395,7 +390,7 @@ public MqttClientOptionsBuilder WithWillContentType(string willContentType) return this; } - public MqttClientOptionsBuilder WithWillCorrelationData(byte[] willCorrelationData) + public MqttClientOptionsBuilder WithWillCorrelationData(ReadOnlyMemory willCorrelationData) { _options.WillCorrelationData = willCorrelationData; return this; @@ -413,33 +408,15 @@ public MqttClientOptionsBuilder WithWillMessageExpiryInterval(uint willMessageEx return this; } - public MqttClientOptionsBuilder WithWillPayload(byte[] willPayload) + public MqttClientOptionsBuilder WithWillPayload(ReadOnlyMemory willPayload) { _options.WillPayload = willPayload; return this; } - public MqttClientOptionsBuilder WithWillPayload(ArraySegment willPayload) - { - if (willPayload.Count == 0) - { - _options.WillPayload = null; - return this; - } - - _options.WillPayload = willPayload.ToArray(); - return this; - } - public MqttClientOptionsBuilder WithWillPayload(string willPayload) { - if (string.IsNullOrEmpty(willPayload)) - { - return WithWillPayload((byte[])null); - } - - _options.WillPayload = Encoding.UTF8.GetBytes(willPayload); - return this; + return string.IsNullOrEmpty(willPayload) ? this : WithWillPayload(Encoding.UTF8.GetBytes(willPayload)); } public MqttClientOptionsBuilder WithWillPayloadFormatIndicator(MqttPayloadFormatIndicator willPayloadFormatIndicator) diff --git a/Source/MQTTnet/Options/MqttClientOptionsValidator.cs b/Source/MQTTnet/Options/MqttClientOptionsValidator.cs index 19b2c91c3..0ce31882a 100644 --- a/Source/MQTTnet/Options/MqttClientOptionsValidator.cs +++ b/Source/MQTTnet/Options/MqttClientOptionsValidator.cs @@ -59,7 +59,7 @@ public static void ThrowIfNotSupported(MqttClientOptions options) // Authentication relevant properties. - if (options.AuthenticationData?.Any() == true) + if (options.AuthenticationData.Length > 0) { Throw(nameof(options.AuthenticationData)); } @@ -81,7 +81,7 @@ public static void ThrowIfNotSupported(MqttClientOptions options) Throw(nameof(options.WillContentType)); } - if (options.WillCorrelationData?.Any() == true) + if (options.WillCorrelationData.Length > 0) { Throw(nameof(options.WillCorrelationData)); } diff --git a/Source/MQTTnet/Packets/MqttAuthPacket.cs b/Source/MQTTnet/Packets/MqttAuthPacket.cs index 8c1b710e7..db2b1aa69 100644 --- a/Source/MQTTnet/Packets/MqttAuthPacket.cs +++ b/Source/MQTTnet/Packets/MqttAuthPacket.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using MQTTnet.Protocol; @@ -10,7 +11,7 @@ namespace MQTTnet.Packets; /// Added in MQTTv5.0.0. public sealed class MqttAuthPacket : MqttPacket { - public byte[] AuthenticationData { get; set; } + public ReadOnlyMemory AuthenticationData { get; set; } public string AuthenticationMethod { get; set; } diff --git a/Source/MQTTnet/Packets/MqttConnAckPacket.cs b/Source/MQTTnet/Packets/MqttConnAckPacket.cs index 3906a9004..fc54ac481 100644 --- a/Source/MQTTnet/Packets/MqttConnAckPacket.cs +++ b/Source/MQTTnet/Packets/MqttConnAckPacket.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using MQTTnet.Protocol; @@ -14,7 +15,7 @@ public sealed class MqttConnAckPacket : MqttPacket /// public string AssignedClientIdentifier { get; set; } - public byte[] AuthenticationData { get; set; } + public ReadOnlyMemory AuthenticationData { get; set; } public string AuthenticationMethod { get; set; } diff --git a/Source/MQTTnet/Packets/MqttConnectPacket.cs b/Source/MQTTnet/Packets/MqttConnectPacket.cs index f9e83f35d..538f1b84d 100644 --- a/Source/MQTTnet/Packets/MqttConnectPacket.cs +++ b/Source/MQTTnet/Packets/MqttConnectPacket.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using MQTTnet.Protocol; @@ -9,7 +10,7 @@ namespace MQTTnet.Packets { public sealed class MqttConnectPacket : MqttPacket { - public byte[] AuthenticationData { get; set; } + public ReadOnlyMemory AuthenticationData { get; set; } public string AuthenticationMethod { get; set; } @@ -20,10 +21,10 @@ public sealed class MqttConnectPacket : MqttPacket public string ClientId { get; set; } - public byte[] WillCorrelationData { get; set; } + public ReadOnlyMemory WillCorrelationData { get; set; } public ushort KeepAlivePeriod { get; set; } - + public uint MaximumPacketSize { get; set; } public byte[] Password { get; set; } @@ -50,7 +51,7 @@ public sealed class MqttConnectPacket : MqttPacket public bool WillFlag { get; set; } - public byte[] WillMessage { get; set; } + public ReadOnlyMemory WillMessage { get; set; } public uint WillMessageExpiryInterval { get; set; } @@ -65,7 +66,7 @@ public sealed class MqttConnectPacket : MqttPacket public List WillUserProperties { get; set; } public bool TryPrivate { get; set; } - + public override string ToString() { var passwordText = string.Empty; diff --git a/Source/MQTTnet/Packets/MqttPublishPacket.cs b/Source/MQTTnet/Packets/MqttPublishPacket.cs index 9edc789e2..62521f932 100644 --- a/Source/MQTTnet/Packets/MqttPublishPacket.cs +++ b/Source/MQTTnet/Packets/MqttPublishPacket.cs @@ -2,10 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Generic; -using MQTTnet.Protocol; namespace MQTTnet.Packets; @@ -13,7 +13,7 @@ public sealed class MqttPublishPacket : MqttPacketWithIdentifier { public string ContentType { get; set; } - public byte[] CorrelationData { get; set; } + public ReadOnlyMemory CorrelationData { get; set; } public bool Dup { get; set; } @@ -21,7 +21,7 @@ public sealed class MqttPublishPacket : MqttPacketWithIdentifier public MqttPayloadFormatIndicator PayloadFormatIndicator { get; set; } = MqttPayloadFormatIndicator.Unspecified; - public ArraySegment PayloadSegment { set { Payload = new ReadOnlySequence(value); } } + public ReadOnlyMemory PayloadSegment { set { Payload = new ReadOnlySequence(value); } } public ReadOnlySequence Payload { get; set; }