diff --git a/benchmark/Benchmark.cs b/benchmark/Benchmark.cs index f474675..45611a0 100644 --- a/benchmark/Benchmark.cs +++ b/benchmark/Benchmark.cs @@ -254,8 +254,27 @@ public unsafe void RunScalarDecodingBenchmarkUTF8(string[] data, int[] lengths) throw new Exception("Error"); } } - } - + } + + public unsafe void RunScalarDecodingBenchmarkUTF16(string[] data, int[] lengths) + { + for (int i = 0; i < FileContent.Length; i++) + { + string s = FileContent[i]; + char[] base64 = s.ToCharArray(); + byte[] dataoutput = output[i]; + int bytesConsumed = 0; + int bytesWritten = 0; + SimdBase64.Base64.Base64WithWhiteSpaceToBinaryScalar(base64.AsSpan(), dataoutput, out bytesConsumed, out bytesWritten, false); + if (bytesWritten != lengths[i]) + { + Console.WriteLine($"Error: {bytesWritten} != {lengths[i]}"); +#pragma warning disable CA2201 + throw new Exception("Error"); + } + } + } + public unsafe void RunSSEDecodingBenchmarkUTF8(string[] data, int[] lengths) { for (int i = 0; i < FileContent.Length; i++) @@ -275,13 +294,34 @@ public unsafe void RunSSEDecodingBenchmarkUTF8(string[] data, int[] lengths) } } + public unsafe void RunSSEDecodingBenchmarkUTF16(string[] data, int[] lengths) + { + for (int i = 0; i < FileContent.Length; i++) + { + string s = FileContent[i]; + ReadOnlySpan base64 = s.AsSpan(); + byte[] dataoutput = output[i]; + int bytesConsumed = 0; + int bytesWritten = 0; + SimdBase64.Base64.DecodeFromBase64SSE(base64, dataoutput, out bytesConsumed, out bytesWritten, false); + if (bytesWritten != lengths[i]) + { + Console.WriteLine($"Error: {bytesWritten} != {lengths[i]}"); +#pragma warning disable CA2201 + throw new Exception("Error"); + } + } + } + + + public unsafe void RunSSEDecodingBenchmarkWithAllocUTF8(string[] data, int[] lengths) { for (int i = 0; i < FileContent.Length; i++) { //string s = FileContent[i]; byte[] base64 = input[i]; - byte[] dataoutput = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64.AsSpan())]; + byte[] dataoutput = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64.AsSpan())]; //byte[] dataoutput = output[i]; int bytesConsumed = 0; int bytesWritten = 0; @@ -295,6 +335,25 @@ public unsafe void RunSSEDecodingBenchmarkWithAllocUTF8(string[] data, int[] len } } + public unsafe void RunSSEDecodingBenchmarkWithAllocUTF16(string[] data, int[] lengths) + { + for (int i = 0; i < FileContent.Length; i++) + { + string s = FileContent[i]; + char[] base64 = s.ToCharArray(); + byte[] dataoutput = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64.AsSpan())]; + int bytesConsumed = 0; + int bytesWritten = 0; + SimdBase64.Base64.DecodeFromBase64SSE(base64.AsSpan(), dataoutput, out bytesConsumed, out bytesWritten, false); + if (bytesWritten != lengths[i]) + { + Console.WriteLine($"Error: {bytesWritten} != {lengths[i]}"); +#pragma warning disable CA2201 + throw new Exception("Error"); + } + } + } + [GlobalSetup] public void Setup() { @@ -388,6 +447,20 @@ public unsafe void SSEDecodingRealDataWithAllocUTF8() RunSSEDecodingBenchmarkWithAllocUTF8(FileContent, DecodedLengths); } + [Benchmark] + [BenchmarkCategory("SSE")] + public unsafe void SSEDecodingRealDataUTF16() + { + RunSSEDecodingBenchmarkUTF16(FileContent, DecodedLengths); + } + + [Benchmark] + [BenchmarkCategory("SSE")] + public unsafe void SSEDecodingRealDataWithAllocUTF16() + { + RunSSEDecodingBenchmarkWithAllocUTF16(FileContent, DecodedLengths); + } + } public class Program { diff --git a/src/Base64SSEUTF16.cs b/src/Base64SSEUTF16.cs new file mode 100644 index 0000000..f1fc30b --- /dev/null +++ b/src/Base64SSEUTF16.cs @@ -0,0 +1,764 @@ +using System; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Buffers; +using System.Buffers.Binary; +using System.IO.Pipes; +using System.Text; +using System.Reflection; +using System.Diagnostics; +using System.Numerics; + +namespace SimdBase64 +{ + public static partial class Base64 + { + + // Caller is responsible for checking that Ssse3.IsSupported && Popcnt.IsSupported + public unsafe static OperationStatus DecodeFromBase64SSE(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { + if (isUrl) + { + return InnerDecodeFromBase64SSEUrl(source, dest, out bytesConsumed, out bytesWritten); + } + else + { + return InnerDecodeFromBase64SSERegular(source, dest, out bytesConsumed, out bytesWritten); + } + } + + private unsafe static OperationStatus InnerDecodeFromBase64SSERegular(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten) + { + // translation from ASCII to 6 bit values + bool isUrl = false; + byte[] toBase64 = Tables.ToBase64Value; + bytesConsumed = 0; + bytesWritten = 0; + const int blocksSize = 6; + Span buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + // Define pointers within the fixed blocks + fixed (char* srcInit = source) + fixed (byte* dstInit = dest) + fixed (byte* startOfBuffer = buffer) + { + char* srcEnd = srcInit + source.Length; + char* src = srcInit; + byte* dst = dstInit; + byte* dstEnd = dstInit + dest.Length; + + int whiteSpaces = 0; + int equalsigns = 0; + + int bytesToProcess = source.Length; + // skip trailing spaces + while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) + { + bytesToProcess--; + whiteSpaces++; + } + + int equallocation = bytesToProcess; // location of the first padding character if any + if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') + { + bytesToProcess -= 1; + equalsigns++; + while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) + { + bytesToProcess--; + whiteSpaces++; + } + if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') + { + equalsigns++; + bytesToProcess -= 1; + } + } + + // round up to the nearest multiple of 4, then multiply by 3 + int decoded3bitsChunksToProcess = (bytesToProcess + 3) / 4 * 3; + + byte* endOfSafe64ByteZone = + decoded3bitsChunksToProcess >= 63 ? + dst + decoded3bitsChunksToProcess - 63 : + dst; + + { + byte* bufferPtr = startOfBuffer; + + ulong bufferBytesConsumed = 0;//Only used if there is an error + ulong bufferBytesWritten = 0;//Only used if there is an error + + if (bytesToProcess >= 64) + { + char* srcEnd64 = srcInit + bytesToProcess - 64; + while (src <= srcEnd64) + { + Base64.Block64 b; + Base64.LoadBlock(&b, src); + src += 64; + bufferBytesConsumed += 64; + bool error = false; + UInt64 badCharMask = Base64.ToBase64Mask(isUrl, &b, ref error); + if (error == true) + { + src -= bufferBytesConsumed; + dst -= bufferBytesWritten; + + bytesConsumed = Math.Max(0, (int)(src - srcInit)); + bytesWritten = Math.Max(0, (int)(dst - dstInit)); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + if (badCharMask != 0) + { + // optimization opportunity: check for simple masks like those made of + // continuous 1s followed by continuous 0s. And masks containing a + // single bad character. + ulong compressedBytesCount = CompressBlock(ref b, badCharMask, bufferPtr); + bufferPtr += compressedBytesCount; + bufferBytesConsumed += compressedBytesCount; + + + } + else if (bufferPtr != startOfBuffer) + { + CopyBlock(&b, bufferPtr); + bufferPtr += 64; + bufferBytesConsumed += 64; + } + else + { + if (dst >= endOfSafe64ByteZone) + { + Base64DecodeBlockSafe(dst, &b); + } + else + { + Base64DecodeBlock(dst, &b); + } + bufferBytesWritten += 48; + dst += 48; + } + + if (bufferPtr >= (blocksSize - 1) * 64 + startOfBuffer) // We treat the last block separately later on + { + for (int i = 0; i < (blocksSize - 2); i++) // We also treat the second to last block differently! Until then it is safe to proceed: + { + Base64DecodeBlock(dst, startOfBuffer + i * 64); + bufferBytesWritten += 48; + dst += 48; + } + if (dst >= endOfSafe64ByteZone) // for the second to last block, we may need to chcek if its unsafe to proceed + { + Base64DecodeBlockSafe(dst, startOfBuffer + (blocksSize - 2) * 64); + } + else + { + Base64DecodeBlock(dst, startOfBuffer + (blocksSize - 2) * 64); + } + + + + dst += 48; + Buffer.MemoryCopy(startOfBuffer + (blocksSize - 1) * 64, startOfBuffer, 64, 64); + bufferPtr -= (blocksSize - 1) * 64; + + bufferBytesWritten = 0; + bufferBytesConsumed = 0; + } + + } + } + // Optimization note: if this is almost full, then it is worth our + // time, otherwise, we should just decode directly. + int lastBlock = (int)((bufferPtr - startOfBuffer) % 64); + // There is at some bytes remaining beyond the last 64 bit block remaining + if (lastBlock != 0 && srcEnd - src + lastBlock >= 64) // We first check if there is any error and eliminate white spaces?: + { + int lastBlockSrcCount = 0; + while ((bufferPtr - startOfBuffer) % 64 != 0 && src < srcEnd) + { + if (!IsValidBase64Index(*src)) + { + bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed); + bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + + byte val = toBase64[(int)*src]; + *bufferPtr = val; + if (val > 64) + { + bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed); + bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + bufferPtr += (val <= 63) ? 1 : 0; + src++; + lastBlockSrcCount++; + } + } + + byte* subBufferPtr = startOfBuffer; + for (; subBufferPtr + 64 <= bufferPtr; subBufferPtr += 64) + { + if (dst >= endOfSafe64ByteZone) + { + Base64DecodeBlockSafe(dst, subBufferPtr); + } + else + { + Base64DecodeBlock(dst, subBufferPtr); + } + + dst += 48;// 64 bits of base64 decodes to 48 bits + } + if ((bufferPtr - subBufferPtr) % 64 != 0) + { + while (subBufferPtr + 4 < bufferPtr) // we decode one base64 element (4 bit) at a time + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 4, 4); + + dst += 3; + subBufferPtr += 4; + } + if (subBufferPtr + 4 <= bufferPtr) // this may be the very last element, might be incomplete + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 3, 3); + + dst += 3; + subBufferPtr += 4; + } + int leftover = (int)(bufferPtr - subBufferPtr); + if (leftover > 0) + { + + while (leftover < 4 && src < srcEnd) + { + + if (!IsValidBase64Index(*src)) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.InvalidData; + } + + + byte val = toBase64[(byte)*src]; + if (val > 64) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.InvalidData; + } + subBufferPtr[leftover] = (byte)(val); + leftover += (val <= 63) ? 1 : 0; + src++; + } + + if (leftover == 1) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.NeedMoreData; + } + if (leftover == 2) + { + UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + + ((UInt32)(subBufferPtr[1]) << 2 * 6); + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Buffer.MemoryCopy(&triple, dst, 1, 1); + + dst += 1; + } + else if (leftover == 3) + { + UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + + ((UInt32)(subBufferPtr[1]) << 2 * 6) + + ((UInt32)(subBufferPtr[2]) << 1 * 6); + triple = BinaryPrimitives.ReverseEndianness(triple); + + triple >>= 8; + + Buffer.MemoryCopy(&triple, dst, 2, 2); + + dst += 2; + } + else + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 3, 3); + + dst += 3; + } + } + } + + if (src < srcEnd + equalsigns) // We finished processing 64-bit blocks, we're not quite at the end yet + { + + + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + + + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(bytesConsumed), dest.Slice(bytesWritten), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + + if (result == OperationStatus.InvalidData) + { + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + else + { + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + } + if (result == OperationStatus.Done && equalsigns > 0) + { + + // additional checks + if ((remainderBytesWritten % 3 == 0) || ((remainderBytesWritten % 3) + 1 + equalsigns != 4)) + { + result = OperationStatus.InvalidData; + } + } + return result; + } + if (equalsigns > 0) // final additional check + { + if (((int)(dst - dstInit) % 3 == 0) || (((int)(dst - dstInit) % 3) + 1 + equalsigns != 4)) + { + return OperationStatus.InvalidData; + } + } + + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.Done; + } + + } + } + + private unsafe static OperationStatus InnerDecodeFromBase64SSEUrl(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten) + { + // translation from ASCII to 6 bit values + bool isUrl = true; + byte[] toBase64 = Tables.ToBase64UrlValue; + bytesConsumed = 0; + bytesWritten = 0; + const int blocksSize = 6; + Span buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + // Define pointers within the fixed blocks + fixed (char* srcInit = source) + fixed (byte* dstInit = dest) + fixed (byte* startOfBuffer = buffer) + { + char* srcEnd = srcInit + source.Length; + char* src = srcInit; + byte* dst = dstInit; + byte* dstEnd = dstInit + dest.Length; + + int whiteSpaces = 0; + int equalsigns = 0; + + int bytesToProcess = source.Length; + // skip trailing spaces + while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) + { + bytesToProcess--; + whiteSpaces++; + } + + int equallocation = bytesToProcess; // location of the first padding character if any + if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') + { + bytesToProcess -= 1; + equalsigns++; + while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) + { + bytesToProcess--; + whiteSpaces++; + } + if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') + { + equalsigns++; + bytesToProcess -= 1; + } + } + + // round up to the nearest multiple of 4, then multiply by 3 + int decoded3bitsChunksToProcess = (bytesToProcess + 3) / 4 * 3; + + byte* endOfSafe64ByteZone = + decoded3bitsChunksToProcess >= 63 ? + dst + decoded3bitsChunksToProcess - 63 : + dst; + + { + byte* bufferPtr = startOfBuffer; + + ulong bufferBytesConsumed = 0;//Only used if there is an error + ulong bufferBytesWritten = 0;//Only used if there is an error + + if (bytesToProcess >= 64) + { + char* srcEnd64 = srcInit + bytesToProcess - 64; + while (src <= srcEnd64) + { + Base64.Block64 b; + Base64.LoadBlock(&b, src); + src += 64; + bufferBytesConsumed += 64; + bool error = false; + UInt64 badCharMask = Base64.ToBase64Mask(isUrl, &b, ref error); + if (error == true) + { + src -= bufferBytesConsumed; + dst -= bufferBytesWritten; + + bytesConsumed = Math.Max(0, (int)(src - srcInit)); + bytesWritten = Math.Max(0, (int)(dst - dstInit)); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + if (badCharMask != 0) + { + // optimization opportunity: check for simple masks like those made of + // continuous 1s followed by continuous 0s. And masks containing a + // single bad character. + ulong compressedBytesCount = CompressBlock(ref b, badCharMask, bufferPtr); + bufferPtr += compressedBytesCount; + bufferBytesConsumed += compressedBytesCount; + + + } + else if (bufferPtr != startOfBuffer) + { + CopyBlock(&b, bufferPtr); + bufferPtr += 64; + bufferBytesConsumed += 64; + } + else + { + if (dst >= endOfSafe64ByteZone) + { + Base64DecodeBlockSafe(dst, &b); + } + else + { + Base64DecodeBlock(dst, &b); + } + bufferBytesWritten += 48; + dst += 48; + } + + if (bufferPtr >= (blocksSize - 1) * 64 + startOfBuffer) // We treat the last block separately later on + { + for (int i = 0; i < (blocksSize - 2); i++) // We also treat the second to last block differently! Until then it is safe to proceed: + { + Base64DecodeBlock(dst, startOfBuffer + i * 64); + bufferBytesWritten += 48; + dst += 48; + } + if (dst >= endOfSafe64ByteZone) // for the second to last block, we may need to chcek if its unsafe to proceed + { + Base64DecodeBlockSafe(dst, startOfBuffer + (blocksSize - 2) * 64); + } + else + { + Base64DecodeBlock(dst, startOfBuffer + (blocksSize - 2) * 64); + } + + + + dst += 48; + Buffer.MemoryCopy(startOfBuffer + (blocksSize - 1) * 64, startOfBuffer, 64, 64); + bufferPtr -= (blocksSize - 1) * 64; + + bufferBytesWritten = 0; + bufferBytesConsumed = 0; + } + + } + } + // Optimization note: if this is almost full, then it is worth our + // time, otherwise, we should just decode directly. + int lastBlock = (int)((bufferPtr - startOfBuffer) % 64); + // There is at some bytes remaining beyond the last 64 bit block remaining + if (lastBlock != 0 && srcEnd - src + lastBlock >= 64) // We first check if there is any error and eliminate white spaces?: + { + int lastBlockSrcCount = 0; + while ((bufferPtr - startOfBuffer) % 64 != 0 && src < srcEnd) + { + + if (!IsValidBase64Index(*src)) + { + bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed); + bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + + byte val = toBase64[(int)*src]; + *bufferPtr = val; + if (val > 64) + { + bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed); + bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten); + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + bufferPtr += (val <= 63) ? 1 : 0; + src++; + lastBlockSrcCount++; + } + } + + byte* subBufferPtr = startOfBuffer; + for (; subBufferPtr + 64 <= bufferPtr; subBufferPtr += 64) + { + if (dst >= endOfSafe64ByteZone) + { + Base64DecodeBlockSafe(dst, subBufferPtr); + } + else + { + Base64DecodeBlock(dst, subBufferPtr); + } + + dst += 48;// 64 bits of base64 decodes to 48 bits + } + if ((bufferPtr - subBufferPtr) % 64 != 0) + { + while (subBufferPtr + 4 < bufferPtr) // we decode one base64 element (4 bit) at a time + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 4, 4); + + dst += 3; + subBufferPtr += 4; + } + if (subBufferPtr + 4 <= bufferPtr) // this may be the very last element, might be incomplete + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 3, 3); + + dst += 3; + subBufferPtr += 4; + } + int leftover = (int)(bufferPtr - subBufferPtr); + if (leftover > 0) + { + + while (leftover < 4 && src < srcEnd) + { + + if (!IsValidBase64Index(*src)) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.InvalidData; + } + + byte val = toBase64[(byte)*src]; + if (val > 64) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.InvalidData; + } + subBufferPtr[leftover] = (byte)(val); + leftover += (val <= 63) ? 1 : 0; + src++; + } + + if (leftover == 1) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.NeedMoreData; + } + if (leftover == 2) + { + UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + + ((UInt32)(subBufferPtr[1]) << 2 * 6); + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Buffer.MemoryCopy(&triple, dst, 1, 1); + + dst += 1; + } + else if (leftover == 3) + { + UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + + ((UInt32)(subBufferPtr[1]) << 2 * 6) + + ((UInt32)(subBufferPtr[2]) << 1 * 6); + triple = BinaryPrimitives.ReverseEndianness(triple); + + triple >>= 8; + + Buffer.MemoryCopy(&triple, dst, 2, 2); + + dst += 2; + } + else + { + UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + + ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + + ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + + ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) + << 8; + triple = BinaryPrimitives.ReverseEndianness(triple); + Buffer.MemoryCopy(&triple, dst, 3, 3); + + dst += 3; + } + } + } + + if (src < srcEnd + equalsigns) // We finished processing 64-bit blocks, we're not quite at the end yet + { + + + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + + + + int remainderBytesConsumed = 0; + int remainderBytesWritten = 0; + + OperationStatus result = + Base64WithWhiteSpaceToBinaryScalar(source.Slice(bytesConsumed), dest.Slice(bytesWritten), out remainderBytesConsumed, out remainderBytesWritten, isUrl); + + + if (result == OperationStatus.InvalidData) + { + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + return result; + } + else + { + bytesConsumed += remainderBytesConsumed; + bytesWritten += remainderBytesWritten; + } + if (result == OperationStatus.Done && equalsigns > 0) + { + + // additional checks + if ((remainderBytesWritten % 3 == 0) || ((remainderBytesWritten % 3) + 1 + equalsigns != 4)) + { + result = OperationStatus.InvalidData; + } + } + return result; + } + if (equalsigns > 0) // final additional check + { + if (((int)(dst - dstInit) % 3 == 0) || (((int)(dst - dstInit) % 3) + 1 + equalsigns != 4)) + { + return OperationStatus.InvalidData; + } + } + + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.Done; + } + + } + } + + } +} diff --git a/src/Base64SSE.cs b/src/Base64SSEUTF8.cs similarity index 72% rename from src/Base64SSE.cs rename to src/Base64SSEUTF8.cs index f749641..350408a 100644 --- a/src/Base64SSE.cs +++ b/src/Base64SSEUTF8.cs @@ -15,10 +15,6 @@ namespace SimdBase64 { public static partial class Base64 - - - - { /* // If needed for debugging, you can do the following: @@ -52,6 +48,25 @@ private static unsafe void LoadBlock(Block64* b, byte* src) b->chunk3 = Sse2.LoadVector128(src + 48); } + private unsafe static void LoadBlock(Block64* b, char* src) + { + // Load 128 bits (16 chars, 32 bytes) at each step from the UTF-16 source + var m1 = Sse2.LoadVector128((ushort*)src); + var m2 = Sse2.LoadVector128((ushort*)(src + 8)); + var m3 = Sse2.LoadVector128((ushort*)(src + 16)); + var m4 = Sse2.LoadVector128((ushort*)(src + 24)); + var m5 = Sse2.LoadVector128((ushort*)(src + 32)); + var m6 = Sse2.LoadVector128((ushort*)(src + 40)); + var m7 = Sse2.LoadVector128((ushort*)(src + 48)); + var m8 = Sse2.LoadVector128((ushort*)(src + 56)); + + // Pack 16-bit chars down to 8-bit chars, handling two __m128i at a time + b->chunk0 = Sse2.PackUnsignedSaturate(m1.AsInt16(), m2.AsInt16()).AsByte(); + b->chunk1 = Sse2.PackUnsignedSaturate(m3.AsInt16(), m4.AsInt16()).AsByte(); + b->chunk2 = Sse2.PackUnsignedSaturate(m5.AsInt16(), m6.AsInt16()).AsByte(); + b->chunk3 = Sse2.PackUnsignedSaturate(m7.AsInt16(), m8.AsInt16()).AsByte(); + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe ulong ToBase64Mask(bool base64Url, Block64* b, ref bool error) { @@ -263,7 +278,7 @@ public unsafe static OperationStatus DecodeFromBase64SSE(ReadOnlySpan sour { if (isUrl) { - return InnerDecodeFromBase64SSEURL(source, dest, out bytesConsumed, out bytesWritten); + return InnerDecodeFromBase64SSEUrl(source, dest, out bytesConsumed, out bytesWritten); } else { @@ -275,7 +290,7 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSERegular(ReadOnlySp { // translation from ASCII to 6 bit values bool isUrl = false; - byte[] toBase64 = isUrl == true ? Tables.ToBase64UrlValue : Tables.ToBase64Value; + byte[] toBase64 = Tables.ToBase64Value; bytesConsumed = 0; bytesWritten = 0; const int blocksSize = 6; @@ -611,351 +626,11 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSERegular(ReadOnlySp } } - private unsafe static OperationStatus InnerDecodeFromBase64SSEURL(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten) + private unsafe static OperationStatus InnerDecodeFromBase64SSEUrl(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten) { // translation from ASCII to 6 bit values bool isUrl = true; - byte[] toBase64 = isUrl == true ? Tables.ToBase64UrlValue : Tables.ToBase64Value; - bytesConsumed = 0; - bytesWritten = 0; - const int blocksSize = 6; - Span buffer = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - // Define pointers within the fixed blocks - fixed (byte* srcInit = source) - fixed (byte* dstInit = dest) - fixed (byte* startOfBuffer = buffer) - { - byte* srcEnd = srcInit + source.Length; - byte* src = srcInit; - byte* dst = dstInit; - byte* dstEnd = dstInit + dest.Length; - - int whiteSpaces = 0; - int equalsigns = 0; - - int bytesToProcess = source.Length; - // skip trailing spaces - while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) - { - bytesToProcess--; - whiteSpaces++; - } - - int equallocation = bytesToProcess; // location of the first padding character if any - if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') - { - bytesToProcess -= 1; - equalsigns++; - while (bytesToProcess > 0 && Base64.IsAsciiWhiteSpace((char)source[bytesToProcess - 1])) - { - bytesToProcess--; - whiteSpaces++; - } - if (bytesToProcess > 0 && source[bytesToProcess - 1] == '=') - { - equalsigns++; - bytesToProcess -= 1; - } - } - - // round up to the nearest multiple of 4, then multiply by 3 - int decoded3bitsChunksToProcess = (bytesToProcess + 3) / 4 * 3; - - byte* endOfSafe64ByteZone = - decoded3bitsChunksToProcess >= 63 ? - dst + decoded3bitsChunksToProcess - 63 : - dst; - - { - byte* bufferPtr = startOfBuffer; - - ulong bufferBytesConsumed = 0;//Only used if there is an error - ulong bufferBytesWritten = 0;//Only used if there is an error - - if (bytesToProcess >= 64) - { - byte* srcEnd64 = srcInit + bytesToProcess - 64; - while (src <= srcEnd64) - { - Base64.Block64 b; - Base64.LoadBlock(&b, src); - src += 64; - bufferBytesConsumed += 64; - bool error = false; - UInt64 badCharMask = Base64.ToBase64Mask(isUrl, &b, ref error); - if (error == true) - { - src -= bufferBytesConsumed; - dst -= bufferBytesWritten; - - bytesConsumed = Math.Max(0, (int)(src - srcInit)); - bytesWritten = Math.Max(0, (int)(dst - dstInit)); - - int remainderBytesConsumed = 0; - int remainderBytesWritten = 0; - - OperationStatus result = - Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); - - bytesConsumed += remainderBytesConsumed; - bytesWritten += remainderBytesWritten; - return result; - } - if (badCharMask != 0) - { - // optimization opportunity: check for simple masks like those made of - // continuous 1s followed by continuous 0s. And masks containing a - // single bad character. - ulong compressedBytesCount = CompressBlock(ref b, badCharMask, bufferPtr); - bufferPtr += compressedBytesCount; - bufferBytesConsumed += compressedBytesCount; - - - } - else if (bufferPtr != startOfBuffer) - { - CopyBlock(&b, bufferPtr); - bufferPtr += 64; - bufferBytesConsumed += 64; - } - else - { - if (dst >= endOfSafe64ByteZone) - { - Base64DecodeBlockSafe(dst, &b); - } - else - { - Base64DecodeBlock(dst, &b); - } - bufferBytesWritten += 48; - dst += 48; - } - - if (bufferPtr >= (blocksSize - 1) * 64 + startOfBuffer) // We treat the last block separately later on - { - for (int i = 0; i < (blocksSize - 2); i++) // We also treat the second to last block differently! Until then it is safe to proceed: - { - Base64DecodeBlock(dst, startOfBuffer + i * 64); - bufferBytesWritten += 48; - dst += 48; - } - if (dst >= endOfSafe64ByteZone) // for the second to last block, we may need to chcek if its unsafe to proceed - { - Base64DecodeBlockSafe(dst, startOfBuffer + (blocksSize - 2) * 64); - } - else - { - Base64DecodeBlock(dst, startOfBuffer + (blocksSize - 2) * 64); - } - - - - dst += 48; - Buffer.MemoryCopy(startOfBuffer + (blocksSize - 1) * 64, startOfBuffer, 64, 64); - bufferPtr -= (blocksSize - 1) * 64; - - bufferBytesWritten = 0; - bufferBytesConsumed = 0; - } - - } - } - // Optimization note: if this is almost full, then it is worth our - // time, otherwise, we should just decode directly. - int lastBlock = (int)((bufferPtr - startOfBuffer) % 64); - // There is at some bytes remaining beyond the last 64 bit block remaining - if (lastBlock != 0 && srcEnd - src + lastBlock >= 64) // We first check if there is any error and eliminate white spaces?: - { - int lastBlockSrcCount = 0; - while ((bufferPtr - startOfBuffer) % 64 != 0 && src < srcEnd) - { - byte val = toBase64[(int)*src]; - *bufferPtr = val; - if (val > 64) - { - bytesConsumed = Math.Max(0, (int)(src - srcInit) - lastBlockSrcCount - (int)bufferBytesConsumed); - bytesWritten = Math.Max(0, (int)(dst - dstInit) - (int)bufferBytesWritten); - - int remainderBytesConsumed = 0; - int remainderBytesWritten = 0; - - OperationStatus result = - Base64WithWhiteSpaceToBinaryScalar(source.Slice(Math.Max(0, bytesConsumed)), dest.Slice(Math.Max(0, bytesWritten)), out remainderBytesConsumed, out remainderBytesWritten, isUrl); - - bytesConsumed += remainderBytesConsumed; - bytesWritten += remainderBytesWritten; - return result; - } - bufferPtr += (val <= 63) ? 1 : 0; - src++; - lastBlockSrcCount++; - } - } - - byte* subBufferPtr = startOfBuffer; - for (; subBufferPtr + 64 <= bufferPtr; subBufferPtr += 64) - { - if (dst >= endOfSafe64ByteZone) - { - Base64DecodeBlockSafe(dst, subBufferPtr); - } - else - { - Base64DecodeBlock(dst, subBufferPtr); - } - - dst += 48;// 64 bits of base64 decodes to 48 bits - } - if ((bufferPtr - subBufferPtr) % 64 != 0) - { - while (subBufferPtr + 4 < bufferPtr) // we decode one base64 element (4 bit) at a time - { - UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + - ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + - ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + - ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) - << 8; - triple = BinaryPrimitives.ReverseEndianness(triple); - Buffer.MemoryCopy(&triple, dst, 4, 4); - - dst += 3; - subBufferPtr += 4; - } - if (subBufferPtr + 4 <= bufferPtr) // this may be the very last element, might be incomplete - { - UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + - ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + - ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + - ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) - << 8; - triple = BinaryPrimitives.ReverseEndianness(triple); - Buffer.MemoryCopy(&triple, dst, 3, 3); - - dst += 3; - subBufferPtr += 4; - } - int leftover = (int)(bufferPtr - subBufferPtr); - if (leftover > 0) - { - - while (leftover < 4 && src < srcEnd) - { - byte val = toBase64[(byte)*src]; - if (val > 64) - { - bytesConsumed = (int)(src - srcInit); - bytesWritten = (int)(dst - dstInit); - return OperationStatus.InvalidData; - } - subBufferPtr[leftover] = (byte)(val); - leftover += (val <= 63) ? 1 : 0; - src++; - } - - if (leftover == 1) - { - bytesConsumed = (int)(src - srcInit); - bytesWritten = (int)(dst - dstInit); - return OperationStatus.NeedMoreData; - } - if (leftover == 2) - { - UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + - ((UInt32)(subBufferPtr[1]) << 2 * 6); - triple = BinaryPrimitives.ReverseEndianness(triple); - triple >>= 8; - Buffer.MemoryCopy(&triple, dst, 1, 1); - - dst += 1; - } - else if (leftover == 3) - { - UInt32 triple = ((UInt32)(subBufferPtr[0]) << 3 * 6) + - ((UInt32)(subBufferPtr[1]) << 2 * 6) + - ((UInt32)(subBufferPtr[2]) << 1 * 6); - triple = BinaryPrimitives.ReverseEndianness(triple); - - triple >>= 8; - - Buffer.MemoryCopy(&triple, dst, 2, 2); - - dst += 2; - } - else - { - UInt32 triple = (((UInt32)((byte)(subBufferPtr[0])) << 3 * 6) + - ((UInt32)((byte)(subBufferPtr[1])) << 2 * 6) + - ((UInt32)((byte)(subBufferPtr[2])) << 1 * 6) + - ((UInt32)((byte)(subBufferPtr[3])) << 0 * 6)) - << 8; - triple = BinaryPrimitives.ReverseEndianness(triple); - Buffer.MemoryCopy(&triple, dst, 3, 3); - - dst += 3; - } - } - } - - if (src < srcEnd + equalsigns) // We finished processing 64-bit blocks, we're not quite at the end yet - { - - - bytesConsumed = (int)(src - srcInit); - bytesWritten = (int)(dst - dstInit); - - - - int remainderBytesConsumed = 0; - int remainderBytesWritten = 0; - - OperationStatus result = - Base64WithWhiteSpaceToBinaryScalar(source.Slice(bytesConsumed), dest.Slice(bytesWritten), out remainderBytesConsumed, out remainderBytesWritten, isUrl); - - - if (result == OperationStatus.InvalidData) - { - bytesConsumed += remainderBytesConsumed; - bytesWritten += remainderBytesWritten; - return result; - } - else - { - bytesConsumed += remainderBytesConsumed; - bytesWritten += remainderBytesWritten; - } - if (result == OperationStatus.Done && equalsigns > 0) - { - - // additional checks - if ((remainderBytesWritten % 3 == 0) || ((remainderBytesWritten % 3) + 1 + equalsigns != 4)) - { - result = OperationStatus.InvalidData; - } - } - return result; - } - if (equalsigns > 0) // final additional check - { - if (((int)(dst - dstInit) % 3 == 0) || (((int)(dst - dstInit) % 3) + 1 + equalsigns != 4)) - { - return OperationStatus.InvalidData; - } - } - - bytesConsumed = (int)(src - srcInit); - bytesWritten = (int)(dst - dstInit); - return OperationStatus.Done; - } - - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private unsafe static OperationStatus InnerDecodeFromBase64SSE(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) - { - // translation from ASCII to 6 bit values - byte[] toBase64 = isUrl == true ? Tables.ToBase64UrlValue : Tables.ToBase64Value; + byte[] toBase64 = Tables.ToBase64UrlValue; bytesConsumed = 0; bytesWritten = 0; const int blocksSize = 6; @@ -1290,8 +965,5 @@ private unsafe static OperationStatus InnerDecodeFromBase64SSE(ReadOnlySpan input) + public static int MaximalBinaryLengthFromBase64Scalar(ReadOnlySpan input) { // We follow https://infra.spec.whatwg.org/#forgiving-base64-decode int padding = 0; @@ -63,6 +61,7 @@ public static int MaximalBinaryLengthFromBase64Scalar(ReadOnlySpan input) return actualLength / 4 * 3 + (actualLength % 4) - 1; } + public unsafe static OperationStatus DecodeFromBase64Scalar(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) { @@ -205,10 +204,15 @@ public unsafe static OperationStatus DecodeFromBase64Scalar(ReadOnlySpan s } } - // like DecodeFromBase64Scalar, but it will not write past the end of the ouput buffer. - public unsafe static OperationStatus SafeDecodeFromBase64Scalar(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + + public static bool IsValidBase64Index(char b) { + return b < 256; + } + + public unsafe static OperationStatus DecodeFromBase64Scalar(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { byte[] toBase64 = isUrl != false ? Tables.ToBase64UrlValue : Tables.ToBase64Value; uint[] d0 = isUrl != false ? Base64Url.d0 : Base64Default.d0; uint[] d1 = isUrl != false ? Base64Url.d1 : Base64Default.d1; @@ -217,9 +221,158 @@ public unsafe static OperationStatus SafeDecodeFromBase64Scalar(ReadOnlySpan 64) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + + return OperationStatus.InvalidData;// Found a character that cannot be part of a valid base64 string. + } + else + { + // We have a space or a newline. We ignore it. + } + src++; + } + + // deals with reminder + if (idx != 4) + { + if (idx == 2) // we just copy directly while converting + { + triple = ((uint)buffer[0] << (3 * 6)) + ((uint)buffer[1] << (2 * 6)); // the 2 last byte are shifted 18 and 12 bits respectively + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + byte[] byteTriple = BitConverter.GetBytes(triple); + dst[0] = byteTriple[0]; + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + byte[] byteTriple = BitConverter.GetBytes(triple); + dst[0] = byteTriple[0]; // Copy only the first byte + } + dst += 1; + } + + else if (idx == 3) + { + triple = ((uint)buffer[0] << 3 * 6) + + ((uint)buffer[1] << 2 * 6) + + ((uint)buffer[2] << 1 * 6); + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 2); + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 2); + } + dst += 2; + } + else if (idx == 1) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.NeedMoreData;// The base64 input terminates with a single character, excluding padding. + } + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.Done; + } + triple = + ((uint)buffer[0] << 3 * 6) + ((uint)buffer[1] << 2 * 6) + + ((uint)buffer[2] << 1 * 6) + ((uint)buffer[3] << 0 * 6); + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 3); + + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 3); + } + dst += 3; + } + + } + } + + // like DecodeFromBase64Scalar, but it will not write past the end of the ouput buffer. + public unsafe static OperationStatus SafeDecodeFromBase64Scalar(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { + + byte[] toBase64 = isUrl != false ? Tables.ToBase64UrlValue : Tables.ToBase64Value; + uint[] d0 = isUrl != false ? Base64Url.d0 : Base64Default.d0; + uint[] d1 = isUrl != false ? Base64Url.d1 : Base64Default.d1; + uint[] d2 = isUrl != false ? Base64Url.d2 : Base64Default.d2; + uint[] d3 = isUrl != false ? Base64Url.d3 : Base64Default.d3; + + int length = source.Length; + // Define pointers within the fixed blocks fixed (byte* srcInit = source) fixed (byte* dstInit = dest) @@ -383,6 +536,181 @@ public unsafe static OperationStatus SafeDecodeFromBase64Scalar(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { + + byte[] toBase64 = isUrl != false ? Tables.ToBase64UrlValue : Tables.ToBase64Value; + uint[] d0 = isUrl != false ? Base64Url.d0 : Base64Default.d0; + uint[] d1 = isUrl != false ? Base64Url.d1 : Base64Default.d1; + uint[] d2 = isUrl != false ? Base64Url.d2 : Base64Default.d2; + uint[] d3 = isUrl != false ? Base64Url.d3 : Base64Default.d3; + + int length = source.Length; + + // Define pointers within the fixed blocks + fixed (char* srcInit = source) + fixed (byte* dstInit = dest) + + { + char* srcEnd = srcInit + length; + char* src = srcInit; + byte* dst = dstInit; + byte* dstEnd = dstInit + dest.Length; + + // Continue the implementation + uint x; + uint triple; + int idx; + byte[] buffer = new byte[4]; + + while (true) + { + // fastpath + while (src + 4 <= srcEnd && + (x = d0[*src] | d1[src[1]] | d2[src[2]] | d3[src[3]]) < 0x01FFFFFF) + { + + + if (MatchSystem(Endianness.BIG)) + { + x = BinaryPrimitives.ReverseEndianness(x); + } + if (dst + 3 > dstEnd) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.DestinationTooSmall; + } + Marshal.Copy(buffer, 0, (IntPtr)dst, 3); // optimization opportunity: copy 4 bytes + dst += 3; + src += 4; + } + idx = 0; + + char* srcCurrent = src; + + // We need at least four characters. + while (idx < 4 && src < srcEnd) + { + + + char c = (char)*src; + byte code = toBase64[c]; + buffer[idx] = code; + + + + if (code <= 63) + { + idx++; + } + else if (code > 64) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.InvalidData;// Found a character that cannot be part of a valid base64 string. + } + else + { + // We have a space or a newline. We ignore it. + } + src++; + } + + // deals with reminder + if (idx != 4) + { + if (idx == 2) // we just copy directly while converting + { + if (dst == dstEnd) + { + bytesConsumed = (int)(srcCurrent - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.DestinationTooSmall; + } + triple = ((uint)buffer[0] << (3 * 6)) + ((uint)buffer[1] << (2 * 6)); // the 2 last byte are shifted 18 and 12 bits respectively + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + byte[] byteTriple = BitConverter.GetBytes(triple); + dst[0] = byteTriple[0]; + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + byte[] byteTriple = BitConverter.GetBytes(triple); + dst[0] = byteTriple[0]; // Copy only the first byte + } + dst += 1; + } + + else if (idx == 3) // same story here + { + if (dst + 2 > dstEnd) + { + bytesConsumed = (int)(srcCurrent - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.DestinationTooSmall; + } + triple = ((uint)buffer[0] << 3 * 6) + + ((uint)buffer[1] << 2 * 6) + + ((uint)buffer[2] << 1 * 6); + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 2); + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 2); + } + dst += 2; + } + + else if (idx == 1) + { + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + + return OperationStatus.InvalidData;// The base64 input terminates with a single character, excluding padding. + } + bytesConsumed = (int)(src - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.Done;//SUCCESS + } + + if (dst + 3 >= dstEnd) + { + bytesConsumed = (int)(srcCurrent - srcInit); + bytesWritten = (int)(dst - dstInit); + return OperationStatus.DestinationTooSmall; + } + triple = + ((uint)(buffer[0]) << 3 * 6) + ((uint)(buffer[1]) << 2 * 6) + + ((uint)(buffer[2]) << 1 * 6) + ((uint)(buffer[3]) << 0 * 6); + if (MatchSystem(Endianness.BIG)) + { + triple <<= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 3); + + } + else + { + triple = BinaryPrimitives.ReverseEndianness(triple); + triple >>= 8; + Marshal.Copy(BitConverter.GetBytes(triple), 0, (IntPtr)dst, 3); + } + dst += 3; + } + + } + } + + public static OperationStatus Base64WithWhiteSpaceToBinaryScalar(ReadOnlySpan input, Span output, out int bytesConsumed, out int bytesWritten, bool isUrl = false) { int length = input.Length; @@ -445,6 +773,69 @@ public static OperationStatus Base64WithWhiteSpaceToBinaryScalar(ReadOnlySpan input, Span output, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { + int length = input.Length; + int whiteSpaces = 0; + while (length > 0 && IsAsciiWhiteSpace((char)input[length - 1])) + { + length--; + whiteSpaces++; + } + int equallocation = length; // location of the first padding character if any + int equalsigns = 0; + if (length > 0 && input[length - 1] == '=') + { + length -= 1; + equalsigns++; + while (length > 0 && IsAsciiWhiteSpace((char)input[length - 1])) + { + length--; + whiteSpaces++; + } + if (length > 0 && input[length - 1] == '=') + { + equalsigns++; + length -= 1; + } + } + if (length == 0) + { + if (equalsigns > 0) + { + bytesConsumed = equallocation; + bytesWritten = 0; + + return OperationStatus.InvalidData; + + } + bytesConsumed = 0 + whiteSpaces + equalsigns; + bytesWritten = 0; + return OperationStatus.Done; + } + + ReadOnlySpan trimmedInput = input.Slice(0, length); + + OperationStatus r = Base64.DecodeFromBase64Scalar(trimmedInput, output, out bytesConsumed, out bytesWritten, isUrl); + + if (r == OperationStatus.Done) + { + if (equalsigns > 0) + { + // Additional checks + if ((bytesWritten % 3 == 0) || (((bytesWritten % 3) + 1 + equalsigns) != 4)) + { + return OperationStatus.InvalidData; + } + } + + // Only increment bytesConsumed if decoding was successful + bytesConsumed += equalsigns + whiteSpaces; + } + return r; + } + + public static int Base64LengthFromBinary(int length, bool isUrl = false) { @@ -555,6 +946,106 @@ public unsafe static OperationStatus SafeBase64ToBinaryWithWhiteSpace(ReadOnlySp return r; } + public unsafe static OperationStatus SafeBase64ToBinaryWithWhiteSpace(ReadOnlySpan input, Span output, out int bytesConsumed, out int bytesWritten, bool isUrl = false) + { + // The implementation could be nicer, but we expect that most times, the user + // will provide us with a buffer that is large enough. + int maxLength = MaximalBinaryLengthFromBase64Scalar(input); + + if (output.Length >= maxLength) + { + // fast path + OperationStatus fastPathResult = Base64.Base64WithWhiteSpaceToBinaryScalar(input, output, out bytesConsumed, out bytesWritten, isUrl); + return fastPathResult; + } + // The output buffer is maybe too small. We will decode a truncated version of the input. + int outlen3 = output.Length / 3 * 3; // round down to multiple of 3 + int safeInputLength = Base64LengthFromBinary(outlen3); + + OperationStatus r = DecodeFromBase64Scalar(input.Slice(0, Math.Max(0, safeInputLength)), output, out bytesConsumed, out bytesWritten, isUrl); // there might be a -1 error here + + + if (r == OperationStatus.InvalidData) + { + return r; + } + int offset = (r == OperationStatus.NeedMoreData) ? 1 : + ((bytesWritten % 3) == 0 ? + 0 : (bytesWritten % 3) + 1); + + int outputIndex = bytesWritten - (bytesWritten % 3); + int inputIndex = safeInputLength; + int whiteSpaces = 0; + // offset is a value that is no larger than 3. We backtrack + // by up to offset characters + an undetermined number of + // white space characters. It is expected that the next loop + // runs at most 3 times + the number of white space characters + // in between them, so we are not worried about performance. + while (offset > 0 && inputIndex > 0) + { + char c = (char)input[--inputIndex]; + if (IsAsciiWhiteSpace(c)) + { + // skipping + } + else + { + offset--; + whiteSpaces++; + } + } + ReadOnlySpan tailInput = input.Slice(inputIndex); + int RemainingInputLength = tailInput.Length; + while (RemainingInputLength > 0 && IsAsciiWhiteSpace((char)tailInput[RemainingInputLength - 1])) + { + RemainingInputLength--; + } + int paddingCharacts = 0; + if (RemainingInputLength > 0 && tailInput[RemainingInputLength - 1] == '=') + { + RemainingInputLength--; + paddingCharacts++; + while (RemainingInputLength > 0 && IsAsciiWhiteSpace((char)tailInput[RemainingInputLength - 1])) + { + RemainingInputLength--; + whiteSpaces++; + } + if (RemainingInputLength > 0 && tailInput[RemainingInputLength - 1] == '=') + { + RemainingInputLength--; + paddingCharacts++; + } + } + + int tailBytesConsumed; + int tailBytesWritten; + + Span remainingOut = output.Slice(Math.Min(output.Length, outputIndex)); + r = SafeDecodeFromBase64Scalar(tailInput.Slice(0, RemainingInputLength), remainingOut, out tailBytesConsumed, out tailBytesWritten, isUrl); + + if (r == OperationStatus.Done && paddingCharacts > 0) + { + // additional checks: + if ((remainingOut.Length % 3 == 0) || ((remainingOut.Length % 3) + 1 + paddingCharacts != 4)) + { + r = OperationStatus.InvalidData; + } + } + + + if (r == OperationStatus.Done) + { + bytesConsumed += tailBytesConsumed + paddingCharacts + whiteSpaces; + } + else + { + bytesConsumed += tailBytesConsumed; + } + bytesWritten += tailBytesWritten; + return r; + } + + } diff --git a/test/Base64DecodingTestsUTF16.cs b/test/Base64DecodingTestsUTF16.cs new file mode 100644 index 0000000..4c81098 --- /dev/null +++ b/test/Base64DecodingTestsUTF16.cs @@ -0,0 +1,1446 @@ +namespace tests; +using System.Text; +using SimdBase64; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using System.Runtime.Intrinsics.Arm; +using System.Buffers; +using Newtonsoft.Json; + +public partial class Base64DecodingTests{ + + public delegate OperationStatus DecodeFromBase64DelegateFncFromUTF16(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); + public delegate OperationStatus DecodeFromBase64DelegateSafeFromUTF16(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); + public delegate OperationStatus Base64WithWhiteSpaceToBinaryFromUTF16(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); + + + protected static void DecodeBase64CasesUTF16(DecodeFromBase64DelegateFncFromUTF16 DecodeFromBase64Delegate) + { + var cases = new List { new char[] { (char)0x53, (char)0x53 } }; + // Define expected results for each case + var expectedResults = new List<(OperationStatus, int)> { (OperationStatus.Done, 1) }; + + for (int i = 0; i < cases.Count; i++) + { + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(cases[i].AsSpan())]; + int bytesConsumed; + int bytesWritten; + + var result = DecodeFromBase64Delegate(cases[i], buffer, out bytesConsumed, out bytesWritten, false); + + Assert.Equal(expectedResults[i].Item1, result); + Assert.Equal(expectedResults[i].Item2, bytesWritten); + } + } + + + + [Fact] + [Trait("Category", "scalar")] + public void DecodeBase64CasesScalarTUF16() + { + DecodeBase64CasesUTF16(Base64.DecodeFromBase64Scalar); + } + + [Fact] + [Trait("Category", "SSE")] + public void DecodeBase64CasesSSETUF16() + { + DecodeBase64CasesUTF16(Base64.DecodeFromBase64SSE); + } + + protected static void CompleteDecodeBase64CasesUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + List<(string decoded, string base64)> cases = new List<(string, string)> + { + ("abcd", " Y\fW\tJ\njZ A=\r= "), + }; + + foreach (var (decoded, base64) in cases) + { + // byte[] base64Bytes = Encoding.UTF8.GetBytes(base64); + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = Base64WithWhiteSpaceToBinaryFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, true); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + for (int i = 0; i < bytesWritten; i++) + { + Assert.Equal(decoded[i], (char)buffer[i]); + } + } + + foreach (var (decoded, base64) in cases) + { + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = DecodeFromBase64DelegateSafeFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, false); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + + for (int i = 0; i < bytesWritten; i++) + { + Assert.Equal(decoded[i], (char)buffer[i]); + } + + + } + + } + + [Fact] + [Trait("Category", "scalar")] + public void CompleteDecodeBase64CasesScalarUTF16() + { + CompleteDecodeBase64CasesUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void CompleteDecodeBase64CasesSSEUTF16() + { + CompleteDecodeBase64CasesUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + protected static void MoreDecodeTestsUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + List<(string decoded, string base64)> cases = new List<(string, string)> + { + ("Hello, World!", "SGVsbG8sIFdvcmxkIQ=="), + ("GeeksforGeeks", "R2Vla3Nmb3JHZWVrcw=="), + ("123456", "MTIzNDU2"), + ("Base64 Encoding", "QmFzZTY0IEVuY29kaW5n"), + ("!R~J2jL&mI]O)3=c:G3Mo)oqmJdxoprTZDyxEvU0MI.'Ww5H{G>}y;;+B8E_Ah,Ed[ PdBqY'^N>O$4:7LK1<:|7)btV@|{YWR$$Er59-XjVrFl4L}~yzTEd4'E[@k", "IVJ+SjJqTCZtSV1PKTM9YzpHM01vKW9xbUpkeG9wclRaRHl4RXZVME1JLidXdzVIe0c+fXk7OytCOEVfQWgsRWRbIFBkQnFZJ15OPk8kNDo3TEsxPDp8NylidFZAfHtZV1IkJEVyNTktWGpWckZsNEx9fnl6VEVkNCdFW0Br") + }; + + foreach (var (decoded, base64) in cases) + { + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = Base64WithWhiteSpaceToBinaryFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, false); + Assert.Equal(OperationStatus.Done, result); + Assert.True(OperationStatus.Done == result, $"Decoding string {decoded} with Length {decoded.Length} bytes went wrong"); + for (int i = 0; i < bytesWritten; i++) + { + Assert.True(decoded[i] == (char)buffer[i], $"Decoded character not equal to source at location {i}: \n Actual: {(char)buffer[i]} ,\n Expected: {decoded[i]},\n Actual string: {BitConverter.ToString(buffer)},\n Expected string :{decoded} "); + } + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + + } + + foreach (var (decoded, base64) in cases) + { + + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = DecodeFromBase64DelegateSafeFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, false); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + + for (int i = 0; i < bytesWritten; i++) + { + Assert.Equal(decoded[i], (char)buffer[i]); + } + } + + } + + [Fact] + [Trait("Category", "scalar")] + public void MoreDecodeTestsScalarUTF16() + { + MoreDecodeTestsUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "SSE")] + public void MoreDecodeTestsSSEUTF16() + { + MoreDecodeTestsUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected static void MoreDecodeTestsUrlUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + List<(string decoded, string base64)> cases = new List<(string, string)> + { + ("Hello, World!", "SGVsbG8sIFdvcmxkIQ=="), + ("GeeksforGeeks", "R2Vla3Nmb3JHZWVrcw=="), + ("123456", "MTIzNDU2"), + ("Base64 Encoding", "QmFzZTY0IEVuY29kaW5n"), + ("!R~J2jL&mI]O)3=c:G3Mo)oqmJdxoprTZDyxEvU0MI.'Ww5H{G>}y;;+B8E_Ah,Ed[ PdBqY'^N>O$4:7LK1<:|7)btV@|{YWR$$Er59-XjVrFl4L}~yzTEd4'E[@k", "IVJ-SjJqTCZtSV1PKTM9YzpHM01vKW9xbUpkeG9wclRaRHl4RXZVME1JLidXdzVIe0c-fXk7OytCOEVfQWgsRWRbIFBkQnFZJ15OPk8kNDo3TEsxPDp8NylidFZAfHtZV1IkJEVyNTktWGpWckZsNEx9fnl6VEVkNCdFW0Br") + }; + + foreach (var (decoded, base64) in cases) + { + + + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = Base64WithWhiteSpaceToBinaryFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, true); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + for (int i = 0; i < bytesWritten; i++) + { + Assert.Equal(decoded[i], (char)buffer[i]); + } + } + + foreach (var (decoded, base64) in cases) + { + + ReadOnlySpan base64Span = new ReadOnlySpan(base64.ToCharArray()); + int bytesConsumed; + int bytesWritten; + + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Span)]; + var result = DecodeFromBase64DelegateSafeFromUTF16(base64Span, buffer, out bytesConsumed, out bytesWritten, true); + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(decoded.Length, bytesWritten); + Assert.Equal(base64.Length, bytesConsumed); + + for (int i = 0; i < bytesWritten; i++) + { + Assert.Equal(decoded[i], (char)buffer[i]); + } + } + } + + [Fact] + [Trait("Category", "sse")] + public void MoreDecodeTestsUrlUTF16SSE() + { + MoreDecodeTestsUrlUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "scalar")] + public void MoreDecodeTestsUTF16UrlUTF16Scalar() + { + MoreDecodeTestsUrlUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void RoundtripBase64UTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); + + string base64String = Convert.ToBase64String(source); + + byte[] decodedBytes = new byte[len]; + + int bytesConsumed, bytesWritten; + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64String.ToCharArray(), decodedBytes.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(source, decodedBytes.AsSpan().ToArray()); + Assert.True(len == bytesWritten, $" Expected bytesWritten: {len} , Actual: {bytesWritten}"); + Assert.True(base64String.Length == bytesConsumed, $" Expected bytesConsumed: {base64String.Length} , Actual: {bytesConsumed}"); + } + } + + [Fact] + [Trait("Category", "scalar")] + public void RoundtripBase64ScalarUTF16() + { + RoundtripBase64UTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void RoundtripBase64SSEUtf16() + { + RoundtripBase64UTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void RoundtripBase64UrlUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); + + string base64String = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_'); + + byte[] decodedBytes = new byte[len]; + + int bytesConsumed, bytesWritten; + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64String.ToCharArray(), decodedBytes.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: true); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(len, bytesWritten); + Assert.Equal(base64String.Length, bytesConsumed); + Assert.Equal(source, decodedBytes.AsSpan().ToArray()); + } + } + + [Fact] + [Trait("Category", "scalar")] + public void RoundtripBase64UrlScalarUTF16() + { + RoundtripBase64UrlUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void RoundtripBase64UrlSSEUtf16() + { + RoundtripBase64UrlUTF16(Base64.DecodeFromBase64SSE, Base64.DecodeFromBase64SSE); + } + + protected static void BadPaddingBase64UTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + Random random = new Random(1234); // use deterministic seed for reproducibility + + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; + int bytesConsumed; + int bytesWritten; + + for (int trial = 0; trial < 10; trial++) + { +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + string base64 = Convert.ToBase64String(source); // Encode source bytes to Base64 + int padding = base64.EndsWith('=') ? 1 : 0; + padding += base64.EndsWith("==", StringComparison.InvariantCulture) ? 1 : 0; + + if (padding != 0) + { + try + { + + // Test adding padding characters should break decoding + List modifiedBase64 = (base64 + "=").ToList(); + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(modifiedBase64.ToArray())]; + for (int i = 0; i < 5; i++) + { + AddSpace(modifiedBase64.ToList(), random); + } + + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + modifiedBase64.ToArray(), buffer.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.InvalidData, result); + } + catch (FormatException) + { + if (padding == 2) + { +#pragma warning disable CA1303 // Do not pass literals as localized parameters + Console.WriteLine($"Wrong OperationStatus when adding one padding character to TWO padding character"); + } + else if (padding == 1) + { +#pragma warning disable CA1303 // Do not pass literals as localized parameters + Console.WriteLine($"Wrong OperationStatus when adding one padding character to ONE padding character"); + } + } + + if (padding == 2) + { + try + { + + // removing one padding characters should break decoding + List modifiedBase64 = base64.Substring(0, base64.Length - 1).ToList(); + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(modifiedBase64.ToArray())]; + for (int i = 0; i < 5; i++) + { + AddSpace(modifiedBase64.ToList(), random); + } + + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + modifiedBase64.ToArray(), buffer.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.InvalidData, result); + } + catch (FormatException) + { +#pragma warning disable CA1303 // Do not pass literals as localized parameters + Console.WriteLine($"Wrong OperationStatus when substracting one padding character"); + } + } + } + else + { + try + { + + // Test adding padding characters should break decoding + List modifiedBase64 = (base64 + "=").ToList(); + byte[] buffer = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(modifiedBase64.ToArray())]; + for (int i = 0; i < 5; i++) + { + AddSpace(modifiedBase64.ToList(), random); + } + + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + modifiedBase64.ToArray(), buffer.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.InvalidData, result); + } + catch (FormatException) + { +#pragma warning disable CA1303 // Do not pass literals as localized parameters + Console.WriteLine($"Wrong OperationStatus when adding one padding character to base64 string with no padding charater"); + } + } + + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void BadPaddingUTF16Base64Scalar() + { + BadPaddingBase64UTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void BadPaddingUTF16Base64SSE() + { + BadPaddingBase64UTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + protected void DoomedBase64RoundtripUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int trial = 0; trial < 10; trial++) + { + int bytesConsumed = 0; + int bytesWritten = 0; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + char[] base64 = Convert.ToBase64String(source).ToCharArray(); + + (char[] base64WithGarbage, int location) = AddGarbage(base64, random); + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64WithGarbage)]; + + // Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.True(OperationStatus.InvalidData == result, $"OperationStatus {result} is not Invalid Data, error at location {location}. "); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + // Also test safe decoding with a specified back_length + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.InvalidData, safeResult); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void DoomedBase64RoundtripScalarUTF16() + { + DoomedBase64RoundtripUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void DoomedBase64RoundtripSSEUTF16() + { + DoomedBase64RoundtripUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + + protected void TruncatedDoomedBase64RoundtripUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int len = 1; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int trial = 0; trial < 10; trial++) + { + + int bytesConsumed = 0; + int bytesWritten = 0; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + char[] base64 = Convert.ToBase64String(source).ToCharArray(); + + char[] base64Truncated = base64[..^3]; // removing last 3 elements with a view + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64Truncated)]; + + // Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64Truncated.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.NeedMoreData, result); + Assert.Equal((base64.Length - 4) / 4 * 3, bytesWritten); + Assert.Equal(base64Truncated.Length, bytesConsumed); + + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64Truncated.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.NeedMoreData, safeResult); + Assert.Equal((base64.Length - 4) / 4 * 3, bytesWritten); + Assert.Equal(base64Truncated.Length, bytesConsumed); + + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void TruncatedDoomedBase64RoundtripScalarUTF16() + { + TruncatedDoomedBase64RoundtripUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void TruncatedDoomedBase64RoundtripSSEUTF16() + { + TruncatedDoomedBase64RoundtripUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void RoundtripBase64WithSpacesUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int len = 0; len < 2048; len++) + { + // Initialize source buffer with random bytes + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); + + // Encode source to Base64 + string base64String = Convert.ToBase64String(source); + char[] base64 = base64String.ToCharArray(); + + for (int i = 0; i < 5; i++) + { + AddSpace(base64.ToList(), random); + } + + + // Prepare buffer for decoded bytes + byte[] decodedBytes = new byte[len]; + + // Call your custom decoding function + int bytesConsumed, bytesWritten; + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64.AsSpan(), decodedBytes.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + // Assert that decoding was successful + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(len, bytesWritten); + Assert.Equal(base64String.Length, bytesConsumed); + Assert.Equal(source, decodedBytes.AsSpan().ToArray()); + + // Safe version not working + result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64.AsSpan(), decodedBytes.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + + // Assert that decoding was successful + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(len, bytesWritten); + Assert.Equal(base64String.Length, bytesConsumed); + Assert.Equal(source, decodedBytes.AsSpan().ToArray()); + } + } + + [Fact] + [Trait("Category", "scalar")] + public void RoundtripBase64WithSpacesScalarUTF16() + { + RoundtripBase64WithSpacesUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void RoundtripBase64WithSpacesSSEUTF16() + { + RoundtripBase64WithSpacesUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void AbortedSafeRoundtripBase64UTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int offset = 1; offset <= 16; offset += 3) + { + for (int len = offset; len < 1024; len++) + { + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Initialize source buffer with random bytes + + string base64String = Convert.ToBase64String(source); + + char[] base64 = base64String.ToCharArray(); + + + + int limitedLength = len - offset; // intentionally too little + byte[] tooSmallArray = new byte[limitedLength]; + + int bytesConsumed = 0; + int bytesWritten = 0; + + var result = DecodeFromBase64DelegateSafeFromUTF16( + base64.AsSpan(), tooSmallArray.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.DestinationTooSmall, result); + Assert.Equal(source.Take(bytesWritten).ToArray(), tooSmallArray.Take(bytesWritten).ToArray()); + + + + // Now let us decode the rest !!! + ReadOnlySpan base64Remains = base64.AsSpan().Slice(bytesConsumed); + + byte[] decodedRemains = new byte[len - bytesWritten]; + + int remainingBytesConsumed = 0; + int remainingBytesWritten = 0; + + result = DecodeFromBase64DelegateSafeFromUTF16( + base64Remains, decodedRemains.AsSpan(), + out remainingBytesConsumed, out remainingBytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(len, bytesWritten + remainingBytesWritten); + Assert.Equal(source.Skip(bytesWritten).ToArray(), decodedRemains.ToArray()); + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void AbortedSafeRoundtripBase64ScalarUTF16() + { + AbortedSafeRoundtripBase64UTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void AbortedSafeRoundtripBase64SSEUTF16() + { + AbortedSafeRoundtripBase64UTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void AbortedSafeRoundtripBase64WithSpacesUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + for (int offset = 1; offset <= 16; offset += 3) + { + for (int len = offset; len < 1024; len++) + { + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Initialize source buffer with random bytes + + string base64String = Convert.ToBase64String(source); + + char[] base64 = base64String.ToCharArray(); + for (int i = 0; i < 5; i++) + { + AddSpace(base64.ToList(), random); + } + + int limitedLength = len - offset; // intentionally too little + byte[] tooSmallArray = new byte[limitedLength]; + + int bytesConsumed = 0; + int bytesWritten = 0; + + var result = DecodeFromBase64DelegateSafeFromUTF16( + base64.AsSpan(), tooSmallArray.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.DestinationTooSmall, result); + Assert.Equal(source.Take(bytesWritten).ToArray(), tooSmallArray.Take(bytesWritten).ToArray()); + + // Now let us decode the rest !!! + ReadOnlySpan base64Remains = base64.AsSpan().Slice(bytesConsumed); + + byte[] decodedRemains = new byte[len - bytesWritten]; + + int remainingBytesConsumed = 0; + int remainingBytesWritten = 0; + + result = DecodeFromBase64DelegateSafeFromUTF16( + base64Remains, decodedRemains.AsSpan(), + out remainingBytesConsumed, out remainingBytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(len, bytesWritten + remainingBytesWritten); + Assert.Equal(source.Skip(bytesWritten).ToArray(), decodedRemains.ToArray()); + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void AbortedSafeRoundtripBase64WithSpacesScalarUTF16() + { + AbortedSafeRoundtripBase64WithSpacesUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void AbortedSafeRoundtripBase64WithSpacesSSEUTF16() + { + AbortedSafeRoundtripBase64WithSpacesUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void StreamingBase64RoundtripUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + int len = 2048; + byte[] source = new byte[len]; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Initialize source buffer with random bytes + + string base64String = Convert.ToBase64String(source); + + char[] base64 = base64String.ToCharArray(); + + for (int window = 16; window <= 2048; window += 7) + { + // build a buffer with enough space to receive the decoded base64 + int bytesConsumed = 0; + int bytesWritten = 0; + + byte[] decodedBytes = new byte[len]; + int outpos = 0; + for (int pos = 0; pos < base64.Length; pos += window) + { + int windowsBytes = Math.Min(window, base64.Length - pos); + +#pragma warning disable CA1062 + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64.AsSpan().Slice(pos, windowsBytes), decodedBytes.AsSpan().Slice(outpos), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.True(result != OperationStatus.InvalidData); + + if (windowsBytes + pos == base64.Length) + { + + // We must check that the last call to base64_to_binary did not + // end with an OperationStatus.NeedMoreData error. + Assert.Equal(OperationStatus.Done, result); + } + else + { + int tailBytesToReprocess = 0; + if (result == OperationStatus.NeedMoreData) + { + tailBytesToReprocess = 1; + } + else + { + tailBytesToReprocess = (bytesWritten % 3) == 0 ? 0 : (bytesWritten % 3) + 1; + } + pos -= tailBytesToReprocess; + bytesWritten -= bytesWritten % 3; + } + outpos += bytesWritten; + } + Assert.Equal(source, decodedBytes); + } + } + + [Fact] + [Trait("Category", "scalar")] + public void StreamingBase64RoundtripScalarUTF16() + { + StreamingBase64RoundtripUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void StreamingBase64RoundtripSSEUTF16() + { + StreamingBase64RoundtripUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected static void ReadmeTestUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + int len = 2048; + string source = new string('a', len); + char[] base64 = source.ToCharArray(); + + // Calculate the required size for 'decoded' to accommodate Base64 decoding + byte[] decodedBytes = new byte[(len + 3) / 4 * 3]; + int outpos = 0; + int window = 512; + + for (int pos = 0; pos < base64.Length; pos += window) + { + int bytesConsumed = 0; + int bytesWritten = 0; + + // how many base64 characters we can process in this iteration + int windowsBytes = Math.Min(window, base64.Length - pos); +#pragma warning disable CA1062 //validate parameter 'Base64WithWhiteSpaceToBinaryFromUTF16' is non-null before using it. + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64.AsSpan().Slice(pos, windowsBytes), decodedBytes.AsSpan().Slice(outpos), + out bytesConsumed, out bytesWritten, isUrl: false); + + Assert.True(result != OperationStatus.InvalidData, $"Invalid base64 character at position {pos + bytesConsumed}"); + + // If we arrived at the end of the base64 input, we must check that the + // number of characters processed is a multiple of 4, or that we have a + // remainder of 0, 2 or 3. + // Eg we must check that the last call to base64_to_binary did not + // end with an OperationStatus.NeedMoreData error. + + if (windowsBytes + pos == base64.Length) + { + Assert.Equal(OperationStatus.Done, result); + } + else + { + // If we are not at the end, we may have to reprocess either 1, 2 or 3 + // bytes, and to drop the last 0, 2 or 3 bytes decoded. + int tailBytesToReprocess = 0; + if (result == OperationStatus.NeedMoreData) + { + tailBytesToReprocess = 1; + } + else + { + tailBytesToReprocess = (bytesWritten % 3) == 0 ? 0 : (bytesWritten % 3) + 1; + } + pos -= tailBytesToReprocess; + bytesWritten -= bytesWritten % 3; + outpos += bytesWritten; + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void ReadmeTestScalarUTF16() + { + ReadmeTestUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void ReadmeTestSSEUTF16() + { + ReadmeTestUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected static void ReadmeTestSafeUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + int len = 72; + string source = new string('a', len); + char[] base64 = source.ToCharArray(); + + byte[] decodedBytesTooSmall = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64) / 2]; // Intentionally too small + + int bytesConsumed = 0; + int bytesWritten = 0; + + var result = DecodeFromBase64DelegateSafeFromUTF16( + base64.AsSpan(), decodedBytesTooSmall.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.DestinationTooSmall, result); + + // We decoded 'limited_length' bytes to back. + // Now let us decode the rest !!! + byte[] decodedRemains = new byte[len - bytesWritten]; + ReadOnlySpan base64Remains = base64.AsSpan().Slice(bytesConsumed); + + int remainingBytesConsumed = 0; + int remainingBytesWritten = 0; + + result = DecodeFromBase64DelegateSafeFromUTF16( + base64Remains, decodedRemains.AsSpan(), + out remainingBytesConsumed, out remainingBytesWritten, isUrl: false); + + Assert.Equal(OperationStatus.Done, result); + Assert.Equal(base64.Length, remainingBytesConsumed + bytesConsumed); + Assert.Equal(Base64.MaximalBinaryLengthFromBase64Scalar(base64), remainingBytesWritten + bytesWritten); + } + + [Fact] + [Trait("Category", "scalar")] + public void ReadmeTestSafeScalarUTF16() + { + ReadmeTestSafeUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void ReadmeTestSafeSSEUTF16() + { + ReadmeTestSafeUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void DoomedBase64AtPos0(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + if (Base64WithWhiteSpaceToBinaryFromUTF16 == null || DecodeFromBase64DelegateSafeFromUTF16 == null || Base64.MaximalBinaryLengthFromBase64Scalar == null) + { +#pragma warning disable CA2208 + throw new ArgumentNullException("Unexpected null parameter"); + } + + List positions = new List(); + for (int i = 0; i < Tables.ToBase64Value.Length; i++) + { + if (Tables.ToBase64Value[i] == 255) + { + positions.Add(i); + } + } + for (int len = 57; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int i = 0; i < positions.Count; i++) + { + int bytesConsumed = 0; + int bytesWritten = 0; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + char[] base64 = Convert.ToBase64String(source).ToCharArray(); + + + + (char[] base64WithGarbage, int location) = AddGarbage(base64, random, 0); + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64)]; + + // Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.InvalidData, result); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + // Also test safe decoding with a specified back_length + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.InvalidData, safeResult); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void DoomedBase64AtPos0ScalarUTF16() + { + DoomedBase64AtPos0(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void DoomedBase64AtPos0SSEUTF16() + { + DoomedBase64AtPos0(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected static void EnronFilesTestUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + string[] fileNames = Directory.GetFiles("../../../../benchmark/data/email"); + string[] FileContent = new string[fileNames.Length]; + + for (int i = 0; i < fileNames.Length; i++) + { + FileContent[i] = File.ReadAllText(fileNames[i]); + } + + foreach (string s in FileContent) + { + char[] base64 = s.ToCharArray(); + + Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64)]; + int bytesConsumed = 0; + int bytesWritten = 0; + + var result = Base64WithWhiteSpaceToBinaryFromUTF16(base64.AsSpan(), output, out bytesConsumed, out bytesWritten, false); + + int bytesConsumedScalar = 0; + int bytesWrittenScalar = 0; + + var resultScalar = DecodeFromBase64DelegateSafeFromUTF16(base64.AsSpan(), output, out bytesConsumedScalar, out bytesWrittenScalar, false); + + Assert.True(result == resultScalar); + Assert.True(result == OperationStatus.Done); + Assert.True(bytesConsumed== bytesConsumedScalar, $"bytesConsumed: {bytesConsumed},bytesConsumedScalar:{bytesConsumedScalar}"); + Assert.True(bytesWritten== bytesWrittenScalar); + } + } + + [Fact] + [Trait("Category", "scalar")] + public void EnronFilesTestScalarUTF16() + { + EnronFilesTestUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void EnronFilesTestSSEUTF16() + { + EnronFilesTestUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + protected static void SwedenZoneBaseFileTestUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + string FilePath = "../../../../benchmark/data/dns/swedenzonebase.txt"; + // Read the contents of the file + string fileContent = File.ReadAllText(FilePath); + + // Convert file content to byte array (assuming it's base64 encoded) + char[] base64Bytes = fileContent.ToCharArray(); + + Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64Bytes)]; + + + // Decode the base64 content + int bytesConsumed, bytesWritten; + var result = Base64WithWhiteSpaceToBinaryFromUTF16(base64Bytes, output, out bytesConsumed, out bytesWritten, false); + + // Assert that the decoding was successful + + int bytesConsumedScalar = 0; + int bytesWrittenScalar = 0; + + var resultScalar = DecodeFromBase64DelegateSafeFromUTF16(base64Bytes.AsSpan(), output, out bytesConsumedScalar, out bytesWrittenScalar, false); + + Assert.True( result == resultScalar,"result != resultScalar"); + Assert.True(bytesConsumed== bytesConsumedScalar, $"bytesConsumed: {bytesConsumed},bytesConsumedScalar:{bytesConsumedScalar}"); + Assert.True(bytesWritten == bytesWrittenScalar, $"bytesWritten: {bytesWritten},bytesWrittenScalar:{bytesWrittenScalar}"); + } + + [Fact] + [Trait("Category", "scalar")] + public void SwedenZoneBaseFileTestScalarUTF16() + { + SwedenZoneBaseFileTestUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void SwedenZoneBaseFileTestSSEUTF16() + { + SwedenZoneBaseFileTestUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + + protected void DoomedPartialBufferUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16, DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + char[] VectorToBeCompressed = new char[] { + (char)0x6D,(char) 0x6A,(char) 0x6D,(char) 0x73,(char) 0x41,(char) 0x71,(char) 0x39,(char) 0x75, + (char)0x76,(char) 0x6C,(char) 0x77,(char) 0x48,(char) 0x20,(char) 0x77,(char) 0x33,(char) 0x53 + }; + + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int trial = 0; trial < 10; trial++) + { + int bytesConsumed = 0; + int bytesWritten = 0; + + int bytesConsumedSafe = 0; + int bytesWrittenSafe = 0; + +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + char[] base64 = Convert.ToBase64String(source).ToCharArray(); + + + (char[] base64WithGarbage, int location) = AddGarbage(base64, random); + + // Insert 1 to 5 copies of the vector right before the garbage + int numberOfCopies = random.Next(1, 6); // Randomly choose 1 to 5 copies + List base64WithGarbageAndTrigger = new List(base64WithGarbage); + int insertPosition = location; // Insert right before the garbage + + for (int i = 0; i < numberOfCopies; i++) + { + base64WithGarbageAndTrigger.InsertRange(insertPosition, VectorToBeCompressed); + insertPosition += VectorToBeCompressed.Length; + } + + // Update the location to reflect the new position of the garbage byte + location += insertPosition; + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64WithGarbageAndTrigger.ToArray())]; + + // Attempt to decode base64 back to binary and assert that it fails + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64WithGarbageAndTrigger.ToArray().AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.True(OperationStatus.InvalidData == result, $"OperationStatus {result} is not Invalid Data, error at location {location}. "); + Assert.Equal(insertPosition, bytesConsumed); + + // Also test safe decoding with a specified back_length + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64WithGarbageAndTrigger.ToArray().AsSpan(), back.AsSpan(), + out bytesConsumedSafe, out bytesWrittenSafe, isUrl: false); + + Assert.True(result == safeResult); + Assert.True(bytesConsumedSafe == bytesConsumed, $"bytesConsumedSafe :{bytesConsumedSafe} != bytesConsumed {bytesConsumed}"); + Assert.True(bytesWrittenSafe == bytesWritten,$"bytesWrittenSafe :{bytesWrittenSafe} != bytesWritten {bytesWritten}"); + + } + } + } + + [Fact] + [Trait("Category", "scalar")] + public void DoomedPartialBufferScalarUTF16() + { + DoomedPartialBufferUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + [Fact] + [Trait("Category", "sse")] + public void DoomedPartialBufferSSEUTF16() + { + DoomedPartialBufferUTF16(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected static void Issue511UTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinary) + { + ArgumentNullException.ThrowIfNull(Base64WithWhiteSpaceToBinary); + + char[] base64Bytes = [ + (char)0x7f, + (char)0x57, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x5a, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x57, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x20, + (char)0x5a, + (char)0x20, + (char)0x5a, + (char)0x5a, + (char)0x5a]; + ReadOnlySpan base64Span = new ReadOnlySpan(base64Bytes); + int bytesConsumed; + int bytesWritten; + byte[] buffer = new byte[48]; + var result = Base64WithWhiteSpaceToBinary(base64Span, buffer, out bytesConsumed, out bytesWritten, true); + Assert.Equal(OperationStatus.InvalidData, result); + + } + + [Fact] + [Trait("Category", "scalar")] + public void Issue511ScalarUTF16() + { + Issue511UTF16(Base64.Base64WithWhiteSpaceToBinaryScalar); + } + + + [Trait("Category", "SSE")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] + public void Issue511SSEUTF16() + { + Issue511UTF16(Base64.DecodeFromBase64SSE); + } + + + protected void TruncatedCharErrorUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16,DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + + string badNonASCIIString = "♡♡♡♡"; + + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int trial = 0; trial < 10; trial++) + { + int bytesConsumed = 0; + int bytesWritten = 0; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + string base64 = Convert.ToBase64String(source); + + int location = random.Next(0, base64.Length + 1)/4; + + char[] base64WithGarbage = base64.Insert(location, badNonASCIIString).ToCharArray(); + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64WithGarbage)]; + + // Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.True(OperationStatus.InvalidData == result, $"OperationStatus {result} is not Invalid Data, error at location {location}. "); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + // Also test safe decoding with a specified back_length + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: false); + Assert.Equal(OperationStatus.InvalidData, safeResult); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + } + } + + + } + + [Fact] + [Trait("Category", "scalar")] + public void TruncatedCharErrorScalarUTF16() + { + TruncatedCharErrorUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar,Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void TruncatedCharErrorUTF16SSE() + { + TruncatedCharErrorUTF16(Base64.DecodeFromBase64SSE,Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + protected void TruncatedCharErrorUrlUTF16(Base64WithWhiteSpaceToBinaryFromUTF16 Base64WithWhiteSpaceToBinaryFromUTF16,DecodeFromBase64DelegateSafeFromUTF16 DecodeFromBase64DelegateSafeFromUTF16) + { + + string badNonASCIIString = "♡♡♡♡"; + + for (int len = 0; len < 2048; len++) + { + byte[] source = new byte[len]; + + for (int trial = 0; trial < 10; trial++) + { + int bytesConsumed = 0; + int bytesWritten = 0; +#pragma warning disable CA5394 // Do not use insecure randomness + random.NextBytes(source); // Generate random bytes for source + + string base64 = Convert.ToBase64String(source).Replace('+', '-').Replace('/', '_'); + + int location = random.Next(0, base64.Length + 1)/4; + + char[] base64WithGarbage = base64.Insert(location, badNonASCIIString).ToCharArray(); + + // Prepare a buffer for decoding the base64 back to binary + byte[] back = new byte[Base64.MaximalBinaryLengthFromBase64Scalar(base64WithGarbage)]; + + // Attempt to decode base64 back to binary and assert that it fails with INVALID_BASE64_CHARACTER + var result = Base64WithWhiteSpaceToBinaryFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: true); + Assert.True(OperationStatus.InvalidData == result, $"OperationStatus {result} is not Invalid Data, error at location {location}. "); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + // Also test safe decoding with a specified back_length + var safeResult = DecodeFromBase64DelegateSafeFromUTF16( + base64WithGarbage.AsSpan(), back.AsSpan(), + out bytesConsumed, out bytesWritten, isUrl: true); + Assert.Equal(OperationStatus.InvalidData, safeResult); + Assert.Equal(location, bytesConsumed); + Assert.Equal(location / 4 * 3, bytesWritten); + + } + } + + + } + + [Fact] + [Trait("Category", "scalar")] + public void TruncatedCharErrorUrlScalarUTF16() + { + TruncatedCharErrorUrlUTF16(Base64.Base64WithWhiteSpaceToBinaryScalar,Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + + [Fact] + [Trait("Category", "sse")] + public void TruncatedCharErrorUrlUTF16SSE() + { + TruncatedCharErrorUrlUTF16(Base64.DecodeFromBase64SSE,Base64.SafeBase64ToBinaryWithWhiteSpace); + } + + +} + + + + + + + + diff --git a/test/Base64DecodingTests.cs b/test/Base64DecodingTestsUTF8.cs similarity index 80% rename from test/Base64DecodingTests.cs rename to test/Base64DecodingTestsUTF8.cs index 1f27a03..2f3da33 100644 --- a/test/Base64DecodingTests.cs +++ b/test/Base64DecodingTestsUTF8.cs @@ -9,136 +9,15 @@ namespace tests; using System.Buffers; using Newtonsoft.Json; -public class Base64DecodingTests +public partial class Base64DecodingTests { - Random random = new Random(12345680); - - private static readonly char[] SpaceCharacters = { ' ', '\t', '\n', '\r' }; -#pragma warning disable CA1002 - protected static void AddSpace(List list, Random random) - { - ArgumentNullException.ThrowIfNull(random); - ArgumentNullException.ThrowIfNull(list); -#pragma warning disable CA5394 // Do not use insecure randomness - int index = random.Next(list.Count + 1); // Random index to insert at -#pragma warning disable CA5394 // Do not use insecure randomness - int charIndex = random.Next(SpaceCharacters.Length); // Random space character - char spaceChar = SpaceCharacters[charIndex]; - byte[] spaceBytes = Encoding.UTF8.GetBytes(new char[] { spaceChar }); - list.Insert(index, spaceBytes[0]); - } - - public static (byte[] modifiedArray, int location) AddGarbage( - byte[] inputArray, Random gen, int? specificLocation = null, byte? specificGarbage = null) - { - ArgumentNullException.ThrowIfNull(inputArray); - ArgumentNullException.ThrowIfNull(gen); - List v = new List(inputArray); - - int len = v.Count; - int i; - - int equalSignIndex = v.FindIndex(c => c == '='); - if (equalSignIndex != -1) - { - len = equalSignIndex; // Adjust the length to before the '=' - } - - if (specificLocation.HasValue && specificLocation.Value < len) - { - i = specificLocation.Value; - } - else - { - i = gen.Next(len + 1); - } - - byte c; - if (specificGarbage.HasValue) - { - c = specificGarbage.Value; - } - else - { - do - { - c = (byte)gen.Next(256); - } while (c == '=' || SimdBase64.Tables.ToBase64Value[c] != 255); - } - - v.Insert(i, c); - - byte[] modifiedArray = v.ToArray(); - - return (modifiedArray, i); - } - - - [Flags] - public enum TestSystemRequirements - { - None = 0, - Arm64 = 1, - X64Avx512 = 2, - X64Avx2 = 4, - X64Sse = 8, - } public delegate OperationStatus DecodeFromBase64DelegateFnc(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); public delegate OperationStatus DecodeFromBase64DelegateSafe(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); public delegate int MaxBase64ToBinaryLengthDelegateFnc(ReadOnlySpan input); public delegate OperationStatus Base64WithWhiteSpaceToBinary(ReadOnlySpan source, Span dest, out int bytesConsumed, out int bytesWritten, bool isUrl); - - - protected sealed class FactOnSystemRequirementAttribute : FactAttribute - { - private TestSystemRequirements RequiredSystems; -#pragma warning disable CA1019 - public FactOnSystemRequirementAttribute(TestSystemRequirements requiredSystems) - { - RequiredSystems = requiredSystems; - - if (!IsSystemSupported(requiredSystems)) - { - Skip = "Test is skipped due to not meeting system requirements."; - } - } - - private static bool IsSystemSupported(TestSystemRequirements requiredSystems) - { - switch (RuntimeInformation.ProcessArchitecture) - { - case Architecture.Arm64: - return requiredSystems.HasFlag(TestSystemRequirements.Arm64); - case Architecture.X64: - return (requiredSystems.HasFlag(TestSystemRequirements.X64Avx512) && Vector512.IsHardwareAccelerated && System.Runtime.Intrinsics.X86.Avx512F.IsSupported) || - (requiredSystems.HasFlag(TestSystemRequirements.X64Avx2) && System.Runtime.Intrinsics.X86.Avx2.IsSupported) || - (requiredSystems.HasFlag(TestSystemRequirements.X64Sse) && System.Runtime.Intrinsics.X86.Ssse3.IsSupported && System.Runtime.Intrinsics.X86.Popcnt.IsSupported); - default: - return false; - } - } - } - - - protected sealed class TestIfCondition : FactAttribute - { -#pragma warning disable CA1019 - public TestIfCondition(Func condition, string skipReason) - { - ArgumentNullException.ThrowIfNull(condition); - // Only set the Skip property if the condition evaluates to false - if (!condition.Invoke()) - { - Skip = skipReason; - } - } - - } - - - protected static void DecodeBase64Cases(DecodeFromBase64DelegateFnc DecodeFromBase64Delegate, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void DecodeBase64CasesUTF8(DecodeFromBase64DelegateFnc DecodeFromBase64Delegate, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (DecodeFromBase64Delegate == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -166,19 +45,19 @@ protected static void DecodeBase64Cases(DecodeFromBase64DelegateFnc DecodeFromBa [Fact] [Trait("Category", "scalar")] - public void DecodeBase64CasesScalar() + public void DecodeBase64CasesScalarUTF8() { - DecodeBase64Cases(Base64.DecodeFromBase64Scalar, Base64.MaximalBinaryLengthFromBase64Scalar); + DecodeBase64CasesUTF8(Base64.DecodeFromBase64SSE, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] public void DecodeBase64CasesSSE() { - DecodeBase64Cases(Base64.DecodeFromBase64SSE, Base64.MaximalBinaryLengthFromBase64Scalar); + DecodeBase64CasesUTF8(Base64.DecodeFromBase64SSE, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void CompleteDecodeBase64Cases(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void CompleteDecodeBase64CasesUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -233,20 +112,20 @@ protected static void CompleteDecodeBase64Cases(Base64WithWhiteSpaceToBinary Bas [Fact] [Trait("Category", "scalar")] - public void CompleteDecodeBase64CasesScalar() + public void CompleteDecodeBase64CasesScalarUTF8() { - CompleteDecodeBase64Cases(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + CompleteDecodeBase64CasesUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] public void CompleteDecodeBase64CasesSSE() { - CompleteDecodeBase64Cases(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + CompleteDecodeBase64CasesUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void Issue511(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary) + protected static void Issue511UTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary) { ArgumentNullException.ThrowIfNull(Base64WithWhiteSpaceToBinary); @@ -325,21 +204,20 @@ protected static void Issue511(Base64WithWhiteSpaceToBinary Base64WithWhiteSpace [Fact] [Trait("Category", "scalar")] - public void Issue511Scalar() + public void Issue511ScalarUTF8() { - Issue511(Base64.Base64WithWhiteSpaceToBinaryScalar); + Issue511UTF8(Base64.Base64WithWhiteSpaceToBinaryScalar); } [Trait("Category", "SSE")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void Issue511SSE() + public void Issue511SSEUTF8() { - Issue511(Base64.DecodeFromBase64SSE); + Issue511UTF8(Base64.DecodeFromBase64SSE); } - - protected static void MoreDecodeTestsUrl(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void MoreDecodeTestsUrlUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -396,19 +274,19 @@ protected static void MoreDecodeTestsUrl(Base64WithWhiteSpaceToBinary Base64With [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void MoreDecodeTestsUrlSSE() + public void MoreDecodeTestsUrlSSEUTF8() { - MoreDecodeTestsUrl(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + MoreDecodeTestsUrlUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Fact] [Trait("Category", "scalar")] - public void MoreDecodeTestsUrlScalar() + public void MoreDecodeTestsUrlScalarUTF8() { - MoreDecodeTestsUrl(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + MoreDecodeTestsUrlUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void RoundtripBase64(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void RoundtripBase64UTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -440,20 +318,20 @@ protected void RoundtripBase64(Base64WithWhiteSpaceToBinary Base64WithWhiteSpace [Fact] [Trait("Category", "scalar")] - public void RoundtripBase64Scalar() + public void RoundtripBase64ScalarUTF8() { - RoundtripBase64(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64UTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void RoundtripBase64SSE() + public void RoundtripBase64SSEUTF8() { - RoundtripBase64(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64UTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void RoundtripBase64Url(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void RoundtripBase64UrlUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -484,20 +362,20 @@ protected void RoundtripBase64Url(Base64WithWhiteSpaceToBinary Base64WithWhiteSp [Fact] [Trait("Category", "scalar")] - public void RoundtripBase64UrlScalar() + public void RoundtripBase64UrlScalarUTF8() { - RoundtripBase64Url(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64UrlUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void RoundtripBase64UrlSSE() + public void RoundtripBase64UrlSSEUTF8() { - RoundtripBase64Url(Base64.DecodeFromBase64SSE, Base64.DecodeFromBase64SSE, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64UrlUTF8(Base64.DecodeFromBase64SSE, Base64.DecodeFromBase64SSE, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void BadPaddingBase64(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void BadPaddingBase64UTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -612,16 +490,16 @@ protected static void BadPaddingBase64(Base64WithWhiteSpaceToBinary Base64WithWh [Fact] [Trait("Category", "scalar")] - public void BadPaddingBase64Scalar() + public void BadPaddingBase64ScalarUTF8() { - BadPaddingBase64(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + BadPaddingBase64UTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void BadPaddingBase64SSE() + public void BadPaddingBase64SSEUTF8() { - BadPaddingBase64(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + BadPaddingBase64UTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } @@ -672,14 +550,14 @@ protected void DoomedBase64Roundtrip(Base64WithWhiteSpaceToBinary Base64WithWhit [Fact] [Trait("Category", "scalar")] - public void DoomedBase64RoundtripScalar() + public void DoomedBase64RoundtripScalarUTF8() { DoomedBase64Roundtrip(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void DoomedBase64RoundtripSSE() + public void DoomedBase64RoundtripSSEUTF8() { DoomedBase64Roundtrip(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } @@ -733,19 +611,19 @@ protected void TruncatedDoomedBase64Roundtrip(Base64WithWhiteSpaceToBinary Base6 [Fact] [Trait("Category", "scalar")] - public void TruncatedDoomedBase64RoundtripScalar() + public void TruncatedDoomedBase64RoundtripScalarUTF8() { TruncatedDoomedBase64Roundtrip(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void TruncatedDoomedBase64RoundtripSSE() + public void TruncatedDoomedBase64RoundtripSSEUTF8() { TruncatedDoomedBase64Roundtrip(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void RoundtripBase64WithSpaces(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void RoundtripBase64WithSpacesUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -799,16 +677,16 @@ protected void RoundtripBase64WithSpaces(Base64WithWhiteSpaceToBinary Base64With [Fact] [Trait("Category", "scalar")] - public void RoundtripBase64WithSpacesScalar() + public void RoundtripBase64WithSpacesScalarUTF8() { - RoundtripBase64WithSpaces(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64WithSpacesUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void RoundtripBase64WithSpacesSSE() + public void RoundtripBase64WithSpacesSSEUTF8() { - RoundtripBase64WithSpaces(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + RoundtripBase64WithSpacesUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } protected void AbortedSafeRoundtripBase64(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) @@ -867,14 +745,14 @@ protected void AbortedSafeRoundtripBase64(Base64WithWhiteSpaceToBinary Base64Wit [Fact] [Trait("Category", "scalar")] - public void AbortedSafeRoundtripBase64Scalar() + public void AbortedSafeRoundtripBase64ScalarUTF8() { AbortedSafeRoundtripBase64(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void AbortedSafeRoundtripBase64SSE() + public void AbortedSafeRoundtripBase64SSEUTF8() { AbortedSafeRoundtripBase64(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } @@ -935,19 +813,19 @@ protected void AbortedSafeRoundtripBase64WithSpaces(Base64WithWhiteSpaceToBinary [Fact] [Trait("Category", "scalar")] - public void AbortedSafeRoundtripBase64WithSpacesScalar() + public void AbortedSafeRoundtripBase64WithSpacesScalarUTF8() { AbortedSafeRoundtripBase64WithSpaces(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void AbortedSafeRoundtripBase64WithSpacesSSE() + public void AbortedSafeRoundtripBase64WithSpacesSSEUTF8() { AbortedSafeRoundtripBase64WithSpaces(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void StreamingBase64Roundtrip(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void StreamingBase64RoundtripUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { int len = 2048; byte[] source = new byte[len]; @@ -1006,20 +884,20 @@ protected void StreamingBase64Roundtrip(Base64WithWhiteSpaceToBinary Base64WithW [Fact] [Trait("Category", "scalar")] - public void StreamingBase64RoundtripScalar() + public void StreamingBase64RoundtripScalarUTF8() { - StreamingBase64Roundtrip(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + StreamingBase64RoundtripUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void StreamingBase64RoundtripSSE() + public void StreamingBase64RoundtripSSEUTF8() { - StreamingBase64Roundtrip(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + StreamingBase64RoundtripUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void ReadmeTest(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void ReadmeTestUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { int len = 2048; string source = new string('a', len); @@ -1076,20 +954,20 @@ protected static void ReadmeTest(Base64WithWhiteSpaceToBinary Base64WithWhiteSpa [Fact] [Trait("Category", "scalar")] - public void ReadmeTestScalar() + public void ReadmeTestScalarUTF8() { - ReadmeTest(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + ReadmeTestUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void ReadmeTestSSE() + public void ReadmeTestSSEUTF8() { - ReadmeTest(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + ReadmeTestUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void ReadmeTestSafe(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void ReadmeTestSafeUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { int len = 72; string source = new string('a', len); @@ -1124,20 +1002,20 @@ protected static void ReadmeTestSafe(Base64WithWhiteSpaceToBinary Base64WithWhit [Fact] [Trait("Category", "scalar")] - public void ReadmeTestSafeScalar() + public void ReadmeTestSafeScalarUTF8() { - ReadmeTestSafe(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + ReadmeTestSafeUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void ReadmeTestSafeSSE() + public void ReadmeTestSafeSSEUTF8() { - ReadmeTestSafe(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + ReadmeTestSafeUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void DoomedBase64AtPos0(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void DoomedBase64AtPos0UTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { if (Base64WithWhiteSpaceToBinary == null || DecodeFromBase64DelegateSafe == null || MaxBase64ToBinaryLengthDelegate == null) { @@ -1195,19 +1073,19 @@ protected void DoomedBase64AtPos0(Base64WithWhiteSpaceToBinary Base64WithWhiteSp [Fact] [Trait("Category", "scalar")] - public void DoomedBase64AtPos0Scalar() + public void DoomedBase64AtPos0ScalarUTF8() { - DoomedBase64AtPos0(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + DoomedBase64AtPos0UTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void DoomedBase64AtPos0SSE() + public void DoomedBase64AtPos0SSEUTF8() { - DoomedBase64AtPos0(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + DoomedBase64AtPos0UTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void EnronFilesTest(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void EnronFilesTestUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { string[] fileNames = Directory.GetFiles("../../../../benchmark/data/email"); string[] FileContent = new string[fileNames.Length]; @@ -1221,7 +1099,7 @@ protected static void EnronFilesTest(Base64WithWhiteSpaceToBinary Base64WithWhit { byte[] base64 = Encoding.UTF8.GetBytes(s); - Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64)]; + Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64)]; int bytesConsumed = 0; int bytesWritten = 0; @@ -1241,20 +1119,20 @@ protected static void EnronFilesTest(Base64WithWhiteSpaceToBinary Base64WithWhit [Fact] [Trait("Category", "scalar")] - public void EnronFilesTestScalar() + public void EnronFilesTestScalarUTF8() { - EnronFilesTest(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + EnronFilesTestUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void EnronFilesTestSSE() + public void EnronFilesTestSSEUTF8() { - EnronFilesTest(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + EnronFilesTestUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected static void SwedenZoneBaseFileTest(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected static void SwedenZoneBaseFileTestUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { string FilePath = "../../../../benchmark/data/dns/swedenzonebase.txt"; // Read the contents of the file @@ -1263,7 +1141,7 @@ protected static void SwedenZoneBaseFileTest(Base64WithWhiteSpaceToBinary Base64 // Convert file content to byte array (assuming it's base64 encoded) byte[] base64Bytes = Encoding.UTF8.GetBytes(fileContent); - Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64Bytes)]; + Span output = new byte[SimdBase64.Base64.MaximalBinaryLengthFromBase64Scalar(base64Bytes)]; // Decode the base64 content @@ -1284,21 +1162,21 @@ protected static void SwedenZoneBaseFileTest(Base64WithWhiteSpaceToBinary Base64 [Fact] [Trait("Category", "scalar")] - public void SwedenZoneBaseFileTestScalar() + public void SwedenZoneBaseFileTestScalarUTF8() { - SwedenZoneBaseFileTest(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + SwedenZoneBaseFileTestUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void SwedenZoneBaseFileTestSSE() + public void SwedenZoneBaseFileTestSSEUTF8() { - SwedenZoneBaseFileTest(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + SwedenZoneBaseFileTestUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } - protected void DoomedPartialBuffer(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) + protected void DoomedPartialBufferUTF8(Base64WithWhiteSpaceToBinary Base64WithWhiteSpaceToBinary, DecodeFromBase64DelegateSafe DecodeFromBase64DelegateSafe, MaxBase64ToBinaryLengthDelegateFnc MaxBase64ToBinaryLengthDelegate) { byte[] VectorToBeCompressed = new byte[] { 0x6D, 0x6A, 0x6D, 0x73, 0x41, 0x71, 0x39, 0x75, @@ -1364,16 +1242,16 @@ protected void DoomedPartialBuffer(Base64WithWhiteSpaceToBinary Base64WithWhiteS [Fact] [Trait("Category", "scalar")] - public void DoomedPartialBufferScalar() + public void DoomedPartialBufferScalarUTF8() { - DoomedPartialBuffer(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + DoomedPartialBufferUTF8(Base64.Base64WithWhiteSpaceToBinaryScalar, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } [Trait("Category", "sse")] [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Sse)] - public void DoomedPartialBufferSSE() + public void DoomedPartialBufferSSEUTF8() { - DoomedPartialBuffer(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); + DoomedPartialBufferUTF8(Base64.DecodeFromBase64SSE, Base64.SafeBase64ToBinaryWithWhiteSpace, Base64.MaximalBinaryLengthFromBase64Scalar); } diff --git a/test/TestHelpers.cs b/test/TestHelpers.cs new file mode 100644 index 0000000..b495a44 --- /dev/null +++ b/test/TestHelpers.cs @@ -0,0 +1,200 @@ +namespace tests; +using System.Text; +using SimdBase64; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; +using System.Runtime.Intrinsics.Arm; +using System.Buffers; +using Newtonsoft.Json; + +public partial class Base64DecodingTests +{ + Random random = new Random(12345680); + + private static readonly char[] SpaceCharacters = { ' ', '\t', '\n', '\r' }; +#pragma warning disable CA1002 + protected static void AddSpace(List list, Random random) + { + ArgumentNullException.ThrowIfNull(random); + ArgumentNullException.ThrowIfNull(list); +#pragma warning disable CA5394 // Do not use insecure randomness + int index = random.Next(list.Count + 1); // Random index to insert at +#pragma warning disable CA5394 // Do not use insecure randomness + int charIndex = random.Next(SpaceCharacters.Length); // Random space character + char spaceChar = SpaceCharacters[charIndex]; + byte[] spaceBytes = Encoding.UTF8.GetBytes(new char[] { spaceChar }); + list.Insert(index, spaceBytes[0]); + } + + protected static void AddSpace(List list, Random random) + { + ArgumentNullException.ThrowIfNull(random); + ArgumentNullException.ThrowIfNull(list); +#pragma warning disable CA5394 // Do not use insecure randomness + int index = random.Next(list.Count + 1); // Random index to insert at +#pragma warning disable CA5394 // Do not use insecure randomness + int charIndex = random.Next(SpaceCharacters.Length); // Random space character + char spaceChar = SpaceCharacters[charIndex]; + list.Insert(index, spaceChar); + } + + public static (byte[] modifiedArray, int location) AddGarbage( + byte[] inputArray, Random gen, int? specificLocation = null, byte? specificGarbage = null) + { + ArgumentNullException.ThrowIfNull(inputArray); + ArgumentNullException.ThrowIfNull(gen); + List v = new List(inputArray); + + int len = v.Count; + int i; + + int equalSignIndex = v.FindIndex(c => c == '='); + if (equalSignIndex != -1) + { + len = equalSignIndex; // Adjust the length to before the '=' + } + + if (specificLocation.HasValue && specificLocation.Value < len) + { + i = specificLocation.Value; + } + else + { + i = gen.Next(len + 1); + } + + byte c; + if (specificGarbage.HasValue) + { + c = specificGarbage.Value; + } + else + { + do + { + c = (byte)gen.Next(256); + } while (c == '=' || SimdBase64.Tables.ToBase64Value[c] != 255); + } + + v.Insert(i, c); + + byte[] modifiedArray = v.ToArray(); + + return (modifiedArray, i); + } + + public static (char[] modifiedArray, int location) AddGarbage( + char[] inputArray, Random gen, int? specificLocation = null, byte? specificGarbage = null) + { + ArgumentNullException.ThrowIfNull(inputArray); + ArgumentNullException.ThrowIfNull(gen); + List v = new List(inputArray); + + int len = v.Count; + int i; + + int equalSignIndex = v.FindIndex(c => c == '='); + if (equalSignIndex != -1) + { + len = equalSignIndex; // Adjust the length to before the '=' + } + + if (specificLocation.HasValue && specificLocation.Value < len) + { + i = specificLocation.Value; + } + else + { + i = gen.Next(len + 1); + } + + char c; + + + do + { + c = (char)gen.Next(256); + } while (c == '=' || SimdBase64.Tables.ToBase64Value[c] != 255); + + v.Insert(i, c); + + char[] modifiedArray = v.ToArray(); + + return (modifiedArray, i); + } + + + + [Flags] + public enum TestSystemRequirements + { + None = 0, + Arm64 = 1, + X64Avx512 = 2, + X64Avx2 = 4, + X64Sse = 8, + } + + protected sealed class FactOnSystemRequirementAttribute : FactAttribute + { + private TestSystemRequirements RequiredSystems; +#pragma warning disable CA1019 + public FactOnSystemRequirementAttribute(TestSystemRequirements requiredSystems) + { + RequiredSystems = requiredSystems; + + if (!IsSystemSupported(requiredSystems)) + { + Skip = "Test is skipped due to not meeting system requirements."; + } + } + + private static bool IsSystemSupported(TestSystemRequirements requiredSystems) + { + switch (RuntimeInformation.ProcessArchitecture) + { + case Architecture.Arm64: + return requiredSystems.HasFlag(TestSystemRequirements.Arm64); + case Architecture.X64: + return (requiredSystems.HasFlag(TestSystemRequirements.X64Avx512) && Vector512.IsHardwareAccelerated && System.Runtime.Intrinsics.X86.Avx512F.IsSupported) || + (requiredSystems.HasFlag(TestSystemRequirements.X64Avx2) && System.Runtime.Intrinsics.X86.Avx2.IsSupported) || + (requiredSystems.HasFlag(TestSystemRequirements.X64Sse) && System.Runtime.Intrinsics.X86.Sse.IsSupported); + default: + return false; + } + } + } + + + protected sealed class TestIfCondition : FactAttribute + { +#pragma warning disable CA1019 + public TestIfCondition(Func condition, string skipReason) + { + ArgumentNullException.ThrowIfNull(condition); + // Only set the Skip property if the condition evaluates to false + if (!condition.Invoke()) + { + Skip = skipReason; + } + } + + } + + + + + + + +} + + + + + + + +