diff --git a/crypto/src/crypto/engines/AesEngine_X86.cs b/crypto/src/crypto/engines/AesEngine_X86.cs index bd7143cd1c..91d64fccb9 100644 --- a/crypto/src/crypto/engines/AesEngine_X86.cs +++ b/crypto/src/crypto/engines/AesEngine_X86.cs @@ -1,41 +1,105 @@ #if NETCOREAPP3_0_OR_GREATER using System; -using System.Buffers.Binary; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; using Org.BouncyCastle.Crypto.Parameters; -using Org.BouncyCastle.Utilities; namespace Org.BouncyCastle.Crypto.Engines { using Aes = System.Runtime.Intrinsics.X86.Aes; using Sse2 = System.Runtime.Intrinsics.X86.Sse2; - public struct AesEngine_X86 - : IBlockCipher + public sealed class AesEngine_X86 : IBlockCipher { public static bool IsSupported => Aes.IsSupported; - private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) + public AesEngine_X86() + { + if (!IsSupported) + throw new PlatformNotSupportedException(nameof(AesEngine_X86)); + } + + public string AlgorithmName => "AES"; + + public int GetBlockSize() => 16; + + private AesEncoderDecoder _implementation; + + public void Init(bool forEncryption, ICipherParameters parameters) + { + if (parameters is not KeyParameter keyParameter) + { + ArgumentNullException.ThrowIfNull(parameters, nameof(parameters)); + throw new ArgumentException("invalid type: " + parameters.GetType(), nameof(parameters)); + } + + Vector128[] roundKeys = CreateRoundKeys(keyParameter.GetKey(), forEncryption); + _implementation = AesEncoderDecoder.Init(forEncryption, roundKeys); + } + + public int ProcessBlock(byte[] inBuf, int inOff, byte[] outBuf, int outOff) { - Vector128[] K; + Check.DataLength(inBuf, inOff, 16); + Check.OutputLength(outBuf, outOff, 16); - switch (key.Length) + Vector128 state = Unsafe.As>(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(inBuf), inOff)); + + _implementation.ProcessRounds(ref state); + + Unsafe.As>(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(outBuf), outOff)) = state; + + return 16; + } + + public int ProcessBlock(ReadOnlySpan input, Span output) + { + Check.DataLength(input, 16); + Check.OutputLength(output, 16); + + Vector128 state = Unsafe.As>(ref MemoryMarshal.GetReference(input)); + + _implementation.ProcessRounds(ref state); + + Unsafe.As>(ref MemoryMarshal.GetReference(output)) = state; + + return 16;; + } + + private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) + { + Vector128[] K = key.Length switch { - case 16: + 16 => KeyLength16(key), + 24 => KeyLength24(key), + 32 => KeyLength32(key), + _ => throw new ArgumentException("Key length not 128/192/256 bits.") + }; + + if (!forEncryption) { - ReadOnlySpan rcon = stackalloc byte[]{ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 }; + for (int i = 1, last = K.Length - 1; i < last; ++i) + { + K[i] = Aes.InverseMixColumns(K[i]); + } + + Array.Reverse(K); + } - K = new Vector128[11]; + return K; - var s = Load128(key.AsSpan(0, 16)); + static Vector128[] KeyLength16(byte[] key) + { + ReadOnlySpan rcon = stackalloc byte[] { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 }; + + Vector128 s = MemoryMarshal.Read>(key.AsSpan(0, 16)); + Vector128[] K = new Vector128[11]; K[0] = s; for (int round = 0; round < 10;) { - var t = Aes.KeygenAssist(s, rcon[round++]); + Vector128 t = Aes.KeygenAssist(s, rcon[round++]); t = Sse2.Shuffle(t.AsInt32(), 0xFF).AsByte(); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); @@ -43,20 +107,20 @@ private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) K[round] = s; } - break; + return K; } - case 24: - { - K = new Vector128[13]; - var s1 = Load128(key.AsSpan(0, 16)); - var s2 = Load64(key.AsSpan(16, 8)).ToVector128(); + static Vector128[] KeyLength24(byte[] key) + { + Vector128 s1 = MemoryMarshal.Read>(key.AsSpan(0, 16)); + Vector128 s2 = MemoryMarshal.Read>(key.AsSpan(16, 8)).ToVector128(); + Vector128[] K = new Vector128[13]; K[0] = s1; byte rcon = 0x01; - for (int round = 0;;) + for (int round = 0; ;) { - var t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1; + Vector128 t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1; t1 = Sse2.Shuffle(t1.AsInt32(), 0x55).AsByte(); s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8)); @@ -65,14 +129,14 @@ private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) K[++round] = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s1, 8)); - var s3 = Sse2.Xor(s2, Sse2.ShiftRightLogical128BitLane(s1, 12)); + Vector128 s3 = Sse2.Xor(s2, Sse2.ShiftRightLogical128BitLane(s1, 12)); s3 = Sse2.Xor(s3, Sse2.ShiftLeftLogical128BitLane(s3, 4)); K[++round] = Sse2.Xor( Sse2.ShiftRightLogical128BitLane(s1, 8), Sse2.ShiftLeftLogical128BitLane(s3, 8)); - var t2 = Aes.KeygenAssist(s3, rcon); rcon <<= 1; + Vector128 t2 = Aes.KeygenAssist(s3, rcon); rcon <<= 1; t2 = Sse2.Shuffle(t2.AsInt32(), 0x55).AsByte(); s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8)); @@ -89,21 +153,21 @@ private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) s2 = s2.WithUpper(Vector64.Zero); } - break; + return K; } - case 32: - { - K = new Vector128[15]; - var s1 = Load128(key.AsSpan(0, 16)); - var s2 = Load128(key.AsSpan(16, 16)); + static Vector128[] KeyLength32(byte[] key) + { + Vector128 s1 = MemoryMarshal.Read>(key.AsSpan(0, 16)); + Vector128 s2 = MemoryMarshal.Read>(key.AsSpan(16, 16)); + Vector128[] K = new Vector128[15]; K[0] = s1; K[1] = s2; byte rcon = 0x01; - for (int round = 1;;) + for (int round = 1; ;) { - var t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1; + Vector128 t1 = Aes.KeygenAssist(s2, rcon); rcon <<= 1; t1 = Sse2.Shuffle(t1.AsInt32(), 0xFF).AsByte(); s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 8)); s1 = Sse2.Xor(s1, Sse2.ShiftLeftLogical128BitLane(s1, 4)); @@ -113,7 +177,7 @@ private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) if (round == 14) break; - var t2 = Aes.KeygenAssist(s1, 0x00); + Vector128 t2 = Aes.KeygenAssist(s1, 0x00); t2 = Sse2.Shuffle(t2.AsInt32(), 0xAA).AsByte(); s2 = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s2, 8)); s2 = Sse2.Xor(s2, Sse2.ShiftLeftLogical128BitLane(s2, 4)); @@ -121,698 +185,256 @@ private static Vector128[] CreateRoundKeys(byte[] key, bool forEncryption) K[++round] = s2; } - break; + return K; } - default: - throw new ArgumentException("Key length not 128/192/256 bits."); + } + + private abstract class AesEncoderDecoder + { + protected readonly Vector128[] _roundKeys; + + public AesEncoderDecoder(Vector128[] roundKeys) + { + _roundKeys = roundKeys; } - if (!forEncryption) + public static AesEncoderDecoder Init(bool forEncryption, Vector128[] roundKeys) { - for (int i = 1, last = K.Length - 1; i < last; ++i) + if (roundKeys.Length == 11) { - K[i] = Aes.InverseMixColumns(K[i]); + return forEncryption ? new Encode128(roundKeys) : new Decode128(roundKeys); + } + else if (roundKeys.Length == 13) + { + return forEncryption ? new Encode192(roundKeys) : new Decode192(roundKeys); + } + else + { + return forEncryption ? new Encode256(roundKeys) : new Decode256(roundKeys); } - - Array.Reverse(K); } - return K; - } + public abstract void ProcessRounds(ref Vector128 state); - private enum Mode { DEC_128, DEC_192, DEC_256, ENC_128, ENC_192, ENC_256, UNINITIALIZED }; - - private Vector128[] m_roundKeys = null; - private Mode m_mode = Mode.UNINITIALIZED; - - public AesEngine_X86() - { - if (!IsSupported) - throw new PlatformNotSupportedException(nameof(AesEngine_X86)); - } - - public string AlgorithmName => "AES"; - - public int GetBlockSize() => 16; - - public void Init(bool forEncryption, ICipherParameters parameters) - { - if (!(parameters is KeyParameter keyParameter)) + private sealed class Encode128 : AesEncoderDecoder { - ArgumentNullException.ThrowIfNull(parameters, nameof(parameters)); - throw new ArgumentException("invalid type: " + Platform.GetTypeName(parameters), nameof(parameters)); - } + public Encode128(Vector128[] roundKeys) : base(roundKeys) { } - m_roundKeys = CreateRoundKeys(keyParameter.GetKey(), forEncryption); + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[10]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Encrypt(state2, roundKeys[1]); + state2 = Aes.Encrypt(state2, roundKeys[2]); + state2 = Aes.Encrypt(state2, roundKeys[3]); + state2 = Aes.Encrypt(state2, roundKeys[4]); + state2 = Aes.Encrypt(state2, roundKeys[5]); + state2 = Aes.Encrypt(state2, roundKeys[6]); + state2 = Aes.Encrypt(state2, roundKeys[7]); + state2 = Aes.Encrypt(state2, roundKeys[8]); + state2 = Aes.Encrypt(state2, roundKeys[9]); + state2 = Aes.EncryptLast(state2, roundKeys[10]); + // Copy back to ref + state = state2; + } + } - if (m_roundKeys.Length == 11) + private sealed class Decode128 : AesEncoderDecoder { - m_mode = forEncryption ? Mode.ENC_128 : Mode.DEC_128; + public Decode128(Vector128[] roundKeys) : base(roundKeys) { } + + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[10]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Decrypt(state2, roundKeys[1]); + state2 = Aes.Decrypt(state2, roundKeys[2]); + state2 = Aes.Decrypt(state2, roundKeys[3]); + state2 = Aes.Decrypt(state2, roundKeys[4]); + state2 = Aes.Decrypt(state2, roundKeys[5]); + state2 = Aes.Decrypt(state2, roundKeys[6]); + state2 = Aes.Decrypt(state2, roundKeys[7]); + state2 = Aes.Decrypt(state2, roundKeys[8]); + state2 = Aes.Decrypt(state2, roundKeys[9]); + state2 = Aes.DecryptLast(state2, roundKeys[10]); + // Copy back to ref + state = state2; + } } - else if (m_roundKeys.Length == 13) + + private sealed class Encode192 : AesEncoderDecoder { - m_mode = forEncryption ? Mode.ENC_192 : Mode.DEC_192; + public Encode192(Vector128[] roundKeys) : base(roundKeys) { } + + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[12]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Encrypt(state2, roundKeys[1]); + state2 = Aes.Encrypt(state2, roundKeys[2]); + state2 = Aes.Encrypt(state2, roundKeys[3]); + state2 = Aes.Encrypt(state2, roundKeys[4]); + state2 = Aes.Encrypt(state2, roundKeys[5]); + state2 = Aes.Encrypt(state2, roundKeys[6]); + state2 = Aes.Encrypt(state2, roundKeys[7]); + state2 = Aes.Encrypt(state2, roundKeys[8]); + state2 = Aes.Encrypt(state2, roundKeys[9]); + state2 = Aes.Encrypt(state2, roundKeys[10]); + state2 = Aes.Encrypt(state2, roundKeys[11]); + state2 = Aes.EncryptLast(state2, roundKeys[12]); + // Copy back to ref + state = state2; + } } - else + + private sealed class Decode192 : AesEncoderDecoder { - m_mode = forEncryption ? Mode.ENC_256 : Mode.DEC_256; - } - } + public Decode192(Vector128[] roundKeys) : base(roundKeys) { } - public int ProcessBlock(byte[] inBuf, int inOff, byte[] outBuf, int outOff) - { - Check.DataLength(inBuf, inOff, 16, "input buffer too short"); - Check.OutputLength(outBuf, outOff, 16, "output buffer too short"); + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[12]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Decrypt(state2, roundKeys[1]); + state2 = Aes.Decrypt(state2, roundKeys[2]); + state2 = Aes.Decrypt(state2, roundKeys[3]); + state2 = Aes.Decrypt(state2, roundKeys[4]); + state2 = Aes.Decrypt(state2, roundKeys[5]); + state2 = Aes.Decrypt(state2, roundKeys[6]); + state2 = Aes.Decrypt(state2, roundKeys[7]); + state2 = Aes.Decrypt(state2, roundKeys[8]); + state2 = Aes.Decrypt(state2, roundKeys[9]); + state2 = Aes.Decrypt(state2, roundKeys[10]); + state2 = Aes.Decrypt(state2, roundKeys[11]); + state2 = Aes.DecryptLast(state2, roundKeys[12]); + // Copy back to ref + state = state2; + } + } - var state = Load128(inBuf.AsSpan(inOff, 16)); - ImplRounds(ref state); - Store128(state, outBuf.AsSpan(outOff, 16)); - return 16; - } + private sealed class Encode256 : AesEncoderDecoder + { + public Encode256(Vector128[] roundKeys) : base(roundKeys) { } - public int ProcessBlock(ReadOnlySpan input, Span output) - { - Check.DataLength(input, 16, "input buffer too short"); - Check.OutputLength(output, 16, "output buffer too short"); + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[14]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Encrypt(state2, roundKeys[1]); + state2 = Aes.Encrypt(state2, roundKeys[2]); + state2 = Aes.Encrypt(state2, roundKeys[3]); + state2 = Aes.Encrypt(state2, roundKeys[4]); + state2 = Aes.Encrypt(state2, roundKeys[5]); + state2 = Aes.Encrypt(state2, roundKeys[6]); + state2 = Aes.Encrypt(state2, roundKeys[7]); + state2 = Aes.Encrypt(state2, roundKeys[8]); + state2 = Aes.Encrypt(state2, roundKeys[9]); + state2 = Aes.Encrypt(state2, roundKeys[10]); + state2 = Aes.Encrypt(state2, roundKeys[11]); + state2 = Aes.Encrypt(state2, roundKeys[12]); + state2 = Aes.Encrypt(state2, roundKeys[13]); + state2 = Aes.EncryptLast(state2, roundKeys[14]); + // Copy back to ref + state = state2; + } + } - var state = Load128(input[..16]); - ImplRounds(ref state); - Store128(state, output[..16]); - return 16; - } + private sealed class Decode256 : AesEncoderDecoder + { + public Decode256(Vector128[] roundKeys) : base(roundKeys) { } - public int ProcessFourBlocks(ReadOnlySpan input, Span output) - { - Check.DataLength(input, 64, "input buffer too short"); - Check.OutputLength(output, 64, "output buffer too short"); - - var s1 = Load128(input[..16]); - var s2 = Load128(input[16..32]); - var s3 = Load128(input[32..48]); - var s4 = Load128(input[48..64]); - ImplRounds(ref s1, ref s2, ref s3, ref s4); - Store128(s1, output[..16]); - Store128(s2, output[16..32]); - Store128(s3, output[32..48]); - Store128(s4, output[48..64]); - return 64; + public override void ProcessRounds(ref Vector128 state) + { + // Take local refence to array so Jit can reason length doesn't change in method + Vector128[] roundKeys = _roundKeys; + { + // Get the Jit to bounds check once rather than each increasing array access + Vector128 temp = roundKeys[14]; + } + + // Operate on non-ref local so it remains in register rather than operating on memory + Vector128 state2 = Sse2.Xor(state, roundKeys[0]); + state2 = Aes.Decrypt(state2, roundKeys[1]); + state2 = Aes.Decrypt(state2, roundKeys[2]); + state2 = Aes.Decrypt(state2, roundKeys[3]); + state2 = Aes.Decrypt(state2, roundKeys[4]); + state2 = Aes.Decrypt(state2, roundKeys[5]); + state2 = Aes.Decrypt(state2, roundKeys[6]); + state2 = Aes.Decrypt(state2, roundKeys[7]); + state2 = Aes.Decrypt(state2, roundKeys[8]); + state2 = Aes.Decrypt(state2, roundKeys[9]); + state2 = Aes.Decrypt(state2, roundKeys[10]); + state2 = Aes.Decrypt(state2, roundKeys[11]); + state2 = Aes.Decrypt(state2, roundKeys[12]); + state2 = Aes.Decrypt(state2, roundKeys[13]); + state2 = Aes.DecryptLast(state2, roundKeys[14]); + // Copy back to ref + state = state2; + } + } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void ImplRounds(ref Vector128 state) + private static class Check { - switch (m_mode) + public static void DataLength(byte[] buf, int off, int len) { - case Mode.DEC_128: Decrypt128(m_roundKeys, ref state); break; - case Mode.DEC_192: Decrypt192(m_roundKeys, ref state); break; - case Mode.DEC_256: Decrypt256(m_roundKeys, ref state); break; - case Mode.ENC_128: Encrypt128(m_roundKeys, ref state); break; - case Mode.ENC_192: Encrypt192(m_roundKeys, ref state); break; - case Mode.ENC_256: Encrypt256(m_roundKeys, ref state); break; - default: throw new InvalidOperationException(nameof(AesEngine_X86) + " not initialised"); + if (off > (buf.Length - len)) ThrowDataLengthException(); } - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void ImplRounds( - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - switch (m_mode) + public static void DataLength(ReadOnlySpan buf, int len) { - case Mode.DEC_128: DecryptFour128(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - case Mode.DEC_192: DecryptFour192(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - case Mode.DEC_256: DecryptFour256(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - case Mode.ENC_128: EncryptFour128(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - case Mode.ENC_192: EncryptFour192(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - case Mode.ENC_256: EncryptFour256(m_roundKeys, ref s1, ref s2, ref s3, ref s4); break; - default: throw new InvalidOperationException(nameof(AesEngine_X86) + " not initialised"); + if (buf.Length < len) ThrowDataLengthException(); } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt128(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Decrypt(state, roundKeys[1]); - state = Aes.Decrypt(state, roundKeys[2]); - state = Aes.Decrypt(state, roundKeys[3]); - state = Aes.Decrypt(state, roundKeys[4]); - state = Aes.Decrypt(state, roundKeys[5]); - state = Aes.Decrypt(state, roundKeys[6]); - state = Aes.Decrypt(state, roundKeys[7]); - state = Aes.Decrypt(state, roundKeys[8]); - state = Aes.Decrypt(state, roundKeys[9]); - state = Aes.DecryptLast(state, roundKeys[10]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt192(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Decrypt(state, roundKeys[1]); - state = Aes.Decrypt(state, roundKeys[2]); - state = Aes.Decrypt(state, roundKeys[3]); - state = Aes.Decrypt(state, roundKeys[4]); - state = Aes.Decrypt(state, roundKeys[5]); - state = Aes.Decrypt(state, roundKeys[6]); - state = Aes.Decrypt(state, roundKeys[7]); - state = Aes.Decrypt(state, roundKeys[8]); - state = Aes.Decrypt(state, roundKeys[9]); - state = Aes.Decrypt(state, roundKeys[10]); - state = Aes.Decrypt(state, roundKeys[11]); - state = Aes.DecryptLast(state, roundKeys[12]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt256(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Decrypt(state, roundKeys[1]); - state = Aes.Decrypt(state, roundKeys[2]); - state = Aes.Decrypt(state, roundKeys[3]); - state = Aes.Decrypt(state, roundKeys[4]); - state = Aes.Decrypt(state, roundKeys[5]); - state = Aes.Decrypt(state, roundKeys[6]); - state = Aes.Decrypt(state, roundKeys[7]); - state = Aes.Decrypt(state, roundKeys[8]); - state = Aes.Decrypt(state, roundKeys[9]); - state = Aes.Decrypt(state, roundKeys[10]); - state = Aes.Decrypt(state, roundKeys[11]); - state = Aes.Decrypt(state, roundKeys[12]); - state = Aes.Decrypt(state, roundKeys[13]); - state = Aes.DecryptLast(state, roundKeys[14]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour128(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Decrypt(s1, rk[1]); - s2 = Aes.Decrypt(s2, rk[1]); - s3 = Aes.Decrypt(s3, rk[1]); - s4 = Aes.Decrypt(s4, rk[1]); - - s1 = Aes.Decrypt(s1, rk[2]); - s2 = Aes.Decrypt(s2, rk[2]); - s3 = Aes.Decrypt(s3, rk[2]); - s4 = Aes.Decrypt(s4, rk[2]); - - s1 = Aes.Decrypt(s1, rk[3]); - s2 = Aes.Decrypt(s2, rk[3]); - s3 = Aes.Decrypt(s3, rk[3]); - s4 = Aes.Decrypt(s4, rk[3]); - - s1 = Aes.Decrypt(s1, rk[4]); - s2 = Aes.Decrypt(s2, rk[4]); - s3 = Aes.Decrypt(s3, rk[4]); - s4 = Aes.Decrypt(s4, rk[4]); - - s1 = Aes.Decrypt(s1, rk[5]); - s2 = Aes.Decrypt(s2, rk[5]); - s3 = Aes.Decrypt(s3, rk[5]); - s4 = Aes.Decrypt(s4, rk[5]); - - s1 = Aes.Decrypt(s1, rk[6]); - s2 = Aes.Decrypt(s2, rk[6]); - s3 = Aes.Decrypt(s3, rk[6]); - s4 = Aes.Decrypt(s4, rk[6]); - - s1 = Aes.Decrypt(s1, rk[7]); - s2 = Aes.Decrypt(s2, rk[7]); - s3 = Aes.Decrypt(s3, rk[7]); - s4 = Aes.Decrypt(s4, rk[7]); - - s1 = Aes.Decrypt(s1, rk[8]); - s2 = Aes.Decrypt(s2, rk[8]); - s3 = Aes.Decrypt(s3, rk[8]); - s4 = Aes.Decrypt(s4, rk[8]); - - s1 = Aes.Decrypt(s1, rk[9]); - s2 = Aes.Decrypt(s2, rk[9]); - s3 = Aes.Decrypt(s3, rk[9]); - s4 = Aes.Decrypt(s4, rk[9]); - - s1 = Aes.DecryptLast(s1, rk[10]); - s2 = Aes.DecryptLast(s2, rk[10]); - s3 = Aes.DecryptLast(s3, rk[10]); - s4 = Aes.DecryptLast(s4, rk[10]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour192(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Decrypt(s1, rk[1]); - s2 = Aes.Decrypt(s2, rk[1]); - s3 = Aes.Decrypt(s3, rk[1]); - s4 = Aes.Decrypt(s4, rk[1]); - - s1 = Aes.Decrypt(s1, rk[2]); - s2 = Aes.Decrypt(s2, rk[2]); - s3 = Aes.Decrypt(s3, rk[2]); - s4 = Aes.Decrypt(s4, rk[2]); - - s1 = Aes.Decrypt(s1, rk[3]); - s2 = Aes.Decrypt(s2, rk[3]); - s3 = Aes.Decrypt(s3, rk[3]); - s4 = Aes.Decrypt(s4, rk[3]); - - s1 = Aes.Decrypt(s1, rk[4]); - s2 = Aes.Decrypt(s2, rk[4]); - s3 = Aes.Decrypt(s3, rk[4]); - s4 = Aes.Decrypt(s4, rk[4]); - - s1 = Aes.Decrypt(s1, rk[5]); - s2 = Aes.Decrypt(s2, rk[5]); - s3 = Aes.Decrypt(s3, rk[5]); - s4 = Aes.Decrypt(s4, rk[5]); - - s1 = Aes.Decrypt(s1, rk[6]); - s2 = Aes.Decrypt(s2, rk[6]); - s3 = Aes.Decrypt(s3, rk[6]); - s4 = Aes.Decrypt(s4, rk[6]); - - s1 = Aes.Decrypt(s1, rk[7]); - s2 = Aes.Decrypt(s2, rk[7]); - s3 = Aes.Decrypt(s3, rk[7]); - s4 = Aes.Decrypt(s4, rk[7]); - - s1 = Aes.Decrypt(s1, rk[8]); - s2 = Aes.Decrypt(s2, rk[8]); - s3 = Aes.Decrypt(s3, rk[8]); - s4 = Aes.Decrypt(s4, rk[8]); - - s1 = Aes.Decrypt(s1, rk[9]); - s2 = Aes.Decrypt(s2, rk[9]); - s3 = Aes.Decrypt(s3, rk[9]); - s4 = Aes.Decrypt(s4, rk[9]); - - s1 = Aes.Decrypt(s1, rk[10]); - s2 = Aes.Decrypt(s2, rk[10]); - s3 = Aes.Decrypt(s3, rk[10]); - s4 = Aes.Decrypt(s4, rk[10]); - - s1 = Aes.Decrypt(s1, rk[11]); - s2 = Aes.Decrypt(s2, rk[11]); - s3 = Aes.Decrypt(s3, rk[11]); - s4 = Aes.Decrypt(s4, rk[11]); - - s1 = Aes.DecryptLast(s1, rk[12]); - s2 = Aes.DecryptLast(s2, rk[12]); - s3 = Aes.DecryptLast(s3, rk[12]); - s4 = Aes.DecryptLast(s4, rk[12]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour256(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Decrypt(s1, rk[1]); - s2 = Aes.Decrypt(s2, rk[1]); - s3 = Aes.Decrypt(s3, rk[1]); - s4 = Aes.Decrypt(s4, rk[1]); - - s1 = Aes.Decrypt(s1, rk[2]); - s2 = Aes.Decrypt(s2, rk[2]); - s3 = Aes.Decrypt(s3, rk[2]); - s4 = Aes.Decrypt(s4, rk[2]); - - s1 = Aes.Decrypt(s1, rk[3]); - s2 = Aes.Decrypt(s2, rk[3]); - s3 = Aes.Decrypt(s3, rk[3]); - s4 = Aes.Decrypt(s4, rk[3]); - - s1 = Aes.Decrypt(s1, rk[4]); - s2 = Aes.Decrypt(s2, rk[4]); - s3 = Aes.Decrypt(s3, rk[4]); - s4 = Aes.Decrypt(s4, rk[4]); - - s1 = Aes.Decrypt(s1, rk[5]); - s2 = Aes.Decrypt(s2, rk[5]); - s3 = Aes.Decrypt(s3, rk[5]); - s4 = Aes.Decrypt(s4, rk[5]); - - s1 = Aes.Decrypt(s1, rk[6]); - s2 = Aes.Decrypt(s2, rk[6]); - s3 = Aes.Decrypt(s3, rk[6]); - s4 = Aes.Decrypt(s4, rk[6]); - - s1 = Aes.Decrypt(s1, rk[7]); - s2 = Aes.Decrypt(s2, rk[7]); - s3 = Aes.Decrypt(s3, rk[7]); - s4 = Aes.Decrypt(s4, rk[7]); - - s1 = Aes.Decrypt(s1, rk[8]); - s2 = Aes.Decrypt(s2, rk[8]); - s3 = Aes.Decrypt(s3, rk[8]); - s4 = Aes.Decrypt(s4, rk[8]); - - s1 = Aes.Decrypt(s1, rk[9]); - s2 = Aes.Decrypt(s2, rk[9]); - s3 = Aes.Decrypt(s3, rk[9]); - s4 = Aes.Decrypt(s4, rk[9]); - - s1 = Aes.Decrypt(s1, rk[10]); - s2 = Aes.Decrypt(s2, rk[10]); - s3 = Aes.Decrypt(s3, rk[10]); - s4 = Aes.Decrypt(s4, rk[10]); - - s1 = Aes.Decrypt(s1, rk[11]); - s2 = Aes.Decrypt(s2, rk[11]); - s3 = Aes.Decrypt(s3, rk[11]); - s4 = Aes.Decrypt(s4, rk[11]); - - s1 = Aes.Decrypt(s1, rk[12]); - s2 = Aes.Decrypt(s2, rk[12]); - s3 = Aes.Decrypt(s3, rk[12]); - s4 = Aes.Decrypt(s4, rk[12]); - - s1 = Aes.Decrypt(s1, rk[13]); - s2 = Aes.Decrypt(s2, rk[13]); - s3 = Aes.Decrypt(s3, rk[13]); - s4 = Aes.Decrypt(s4, rk[13]); - - s1 = Aes.DecryptLast(s1, rk[14]); - s2 = Aes.DecryptLast(s2, rk[14]); - s3 = Aes.DecryptLast(s3, rk[14]); - s4 = Aes.DecryptLast(s4, rk[14]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt128(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Encrypt(state, roundKeys[1]); - state = Aes.Encrypt(state, roundKeys[2]); - state = Aes.Encrypt(state, roundKeys[3]); - state = Aes.Encrypt(state, roundKeys[4]); - state = Aes.Encrypt(state, roundKeys[5]); - state = Aes.Encrypt(state, roundKeys[6]); - state = Aes.Encrypt(state, roundKeys[7]); - state = Aes.Encrypt(state, roundKeys[8]); - state = Aes.Encrypt(state, roundKeys[9]); - state = Aes.EncryptLast(state, roundKeys[10]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt192(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Encrypt(state, roundKeys[1]); - state = Aes.Encrypt(state, roundKeys[2]); - state = Aes.Encrypt(state, roundKeys[3]); - state = Aes.Encrypt(state, roundKeys[4]); - state = Aes.Encrypt(state, roundKeys[5]); - state = Aes.Encrypt(state, roundKeys[6]); - state = Aes.Encrypt(state, roundKeys[7]); - state = Aes.Encrypt(state, roundKeys[8]); - state = Aes.Encrypt(state, roundKeys[9]); - state = Aes.Encrypt(state, roundKeys[10]); - state = Aes.Encrypt(state, roundKeys[11]); - state = Aes.EncryptLast(state, roundKeys[12]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt256(Vector128[] roundKeys, ref Vector128 state) - { - state = Sse2.Xor(state, roundKeys[0]); - state = Aes.Encrypt(state, roundKeys[1]); - state = Aes.Encrypt(state, roundKeys[2]); - state = Aes.Encrypt(state, roundKeys[3]); - state = Aes.Encrypt(state, roundKeys[4]); - state = Aes.Encrypt(state, roundKeys[5]); - state = Aes.Encrypt(state, roundKeys[6]); - state = Aes.Encrypt(state, roundKeys[7]); - state = Aes.Encrypt(state, roundKeys[8]); - state = Aes.Encrypt(state, roundKeys[9]); - state = Aes.Encrypt(state, roundKeys[10]); - state = Aes.Encrypt(state, roundKeys[11]); - state = Aes.Encrypt(state, roundKeys[12]); - state = Aes.Encrypt(state, roundKeys[13]); - state = Aes.EncryptLast(state, roundKeys[14]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour128(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Encrypt(s1, rk[1]); - s2 = Aes.Encrypt(s2, rk[1]); - s3 = Aes.Encrypt(s3, rk[1]); - s4 = Aes.Encrypt(s4, rk[1]); - - s1 = Aes.Encrypt(s1, rk[2]); - s2 = Aes.Encrypt(s2, rk[2]); - s3 = Aes.Encrypt(s3, rk[2]); - s4 = Aes.Encrypt(s4, rk[2]); - - s1 = Aes.Encrypt(s1, rk[3]); - s2 = Aes.Encrypt(s2, rk[3]); - s3 = Aes.Encrypt(s3, rk[3]); - s4 = Aes.Encrypt(s4, rk[3]); - - s1 = Aes.Encrypt(s1, rk[4]); - s2 = Aes.Encrypt(s2, rk[4]); - s3 = Aes.Encrypt(s3, rk[4]); - s4 = Aes.Encrypt(s4, rk[4]); - - s1 = Aes.Encrypt(s1, rk[5]); - s2 = Aes.Encrypt(s2, rk[5]); - s3 = Aes.Encrypt(s3, rk[5]); - s4 = Aes.Encrypt(s4, rk[5]); - - s1 = Aes.Encrypt(s1, rk[6]); - s2 = Aes.Encrypt(s2, rk[6]); - s3 = Aes.Encrypt(s3, rk[6]); - s4 = Aes.Encrypt(s4, rk[6]); - - s1 = Aes.Encrypt(s1, rk[7]); - s2 = Aes.Encrypt(s2, rk[7]); - s3 = Aes.Encrypt(s3, rk[7]); - s4 = Aes.Encrypt(s4, rk[7]); - - s1 = Aes.Encrypt(s1, rk[8]); - s2 = Aes.Encrypt(s2, rk[8]); - s3 = Aes.Encrypt(s3, rk[8]); - s4 = Aes.Encrypt(s4, rk[8]); - - s1 = Aes.Encrypt(s1, rk[9]); - s2 = Aes.Encrypt(s2, rk[9]); - s3 = Aes.Encrypt(s3, rk[9]); - s4 = Aes.Encrypt(s4, rk[9]); - - s1 = Aes.EncryptLast(s1, rk[10]); - s2 = Aes.EncryptLast(s2, rk[10]); - s3 = Aes.EncryptLast(s3, rk[10]); - s4 = Aes.EncryptLast(s4, rk[10]); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour192(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Encrypt(s1, rk[1]); - s2 = Aes.Encrypt(s2, rk[1]); - s3 = Aes.Encrypt(s3, rk[1]); - s4 = Aes.Encrypt(s4, rk[1]); - - s1 = Aes.Encrypt(s1, rk[2]); - s2 = Aes.Encrypt(s2, rk[2]); - s3 = Aes.Encrypt(s3, rk[2]); - s4 = Aes.Encrypt(s4, rk[2]); - - s1 = Aes.Encrypt(s1, rk[3]); - s2 = Aes.Encrypt(s2, rk[3]); - s3 = Aes.Encrypt(s3, rk[3]); - s4 = Aes.Encrypt(s4, rk[3]); - - s1 = Aes.Encrypt(s1, rk[4]); - s2 = Aes.Encrypt(s2, rk[4]); - s3 = Aes.Encrypt(s3, rk[4]); - s4 = Aes.Encrypt(s4, rk[4]); - - s1 = Aes.Encrypt(s1, rk[5]); - s2 = Aes.Encrypt(s2, rk[5]); - s3 = Aes.Encrypt(s3, rk[5]); - s4 = Aes.Encrypt(s4, rk[5]); - - s1 = Aes.Encrypt(s1, rk[6]); - s2 = Aes.Encrypt(s2, rk[6]); - s3 = Aes.Encrypt(s3, rk[6]); - s4 = Aes.Encrypt(s4, rk[6]); - - s1 = Aes.Encrypt(s1, rk[7]); - s2 = Aes.Encrypt(s2, rk[7]); - s3 = Aes.Encrypt(s3, rk[7]); - s4 = Aes.Encrypt(s4, rk[7]); - - s1 = Aes.Encrypt(s1, rk[8]); - s2 = Aes.Encrypt(s2, rk[8]); - s3 = Aes.Encrypt(s3, rk[8]); - s4 = Aes.Encrypt(s4, rk[8]); - - s1 = Aes.Encrypt(s1, rk[9]); - s2 = Aes.Encrypt(s2, rk[9]); - s3 = Aes.Encrypt(s3, rk[9]); - s4 = Aes.Encrypt(s4, rk[9]); - - s1 = Aes.Encrypt(s1, rk[10]); - s2 = Aes.Encrypt(s2, rk[10]); - s3 = Aes.Encrypt(s3, rk[10]); - s4 = Aes.Encrypt(s4, rk[10]); - - s1 = Aes.Encrypt(s1, rk[11]); - s2 = Aes.Encrypt(s2, rk[11]); - s3 = Aes.Encrypt(s3, rk[11]); - s4 = Aes.Encrypt(s4, rk[11]); - - s1 = Aes.EncryptLast(s1, rk[12]); - s2 = Aes.EncryptLast(s2, rk[12]); - s3 = Aes.EncryptLast(s3, rk[12]); - s4 = Aes.EncryptLast(s4, rk[12]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour256(Vector128[] rk, - ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) - { - s1 = Sse2.Xor(s1, rk[0]); - s2 = Sse2.Xor(s2, rk[0]); - s3 = Sse2.Xor(s3, rk[0]); - s4 = Sse2.Xor(s4, rk[0]); - - s1 = Aes.Encrypt(s1, rk[1]); - s2 = Aes.Encrypt(s2, rk[1]); - s3 = Aes.Encrypt(s3, rk[1]); - s4 = Aes.Encrypt(s4, rk[1]); - - s1 = Aes.Encrypt(s1, rk[2]); - s2 = Aes.Encrypt(s2, rk[2]); - s3 = Aes.Encrypt(s3, rk[2]); - s4 = Aes.Encrypt(s4, rk[2]); - - s1 = Aes.Encrypt(s1, rk[3]); - s2 = Aes.Encrypt(s2, rk[3]); - s3 = Aes.Encrypt(s3, rk[3]); - s4 = Aes.Encrypt(s4, rk[3]); - - s1 = Aes.Encrypt(s1, rk[4]); - s2 = Aes.Encrypt(s2, rk[4]); - s3 = Aes.Encrypt(s3, rk[4]); - s4 = Aes.Encrypt(s4, rk[4]); - - s1 = Aes.Encrypt(s1, rk[5]); - s2 = Aes.Encrypt(s2, rk[5]); - s3 = Aes.Encrypt(s3, rk[5]); - s4 = Aes.Encrypt(s4, rk[5]); - - s1 = Aes.Encrypt(s1, rk[6]); - s2 = Aes.Encrypt(s2, rk[6]); - s3 = Aes.Encrypt(s3, rk[6]); - s4 = Aes.Encrypt(s4, rk[6]); - - s1 = Aes.Encrypt(s1, rk[7]); - s2 = Aes.Encrypt(s2, rk[7]); - s3 = Aes.Encrypt(s3, rk[7]); - s4 = Aes.Encrypt(s4, rk[7]); - - s1 = Aes.Encrypt(s1, rk[8]); - s2 = Aes.Encrypt(s2, rk[8]); - s3 = Aes.Encrypt(s3, rk[8]); - s4 = Aes.Encrypt(s4, rk[8]); - - s1 = Aes.Encrypt(s1, rk[9]); - s2 = Aes.Encrypt(s2, rk[9]); - s3 = Aes.Encrypt(s3, rk[9]); - s4 = Aes.Encrypt(s4, rk[9]); - - s1 = Aes.Encrypt(s1, rk[10]); - s2 = Aes.Encrypt(s2, rk[10]); - s3 = Aes.Encrypt(s3, rk[10]); - s4 = Aes.Encrypt(s4, rk[10]); - - s1 = Aes.Encrypt(s1, rk[11]); - s2 = Aes.Encrypt(s2, rk[11]); - s3 = Aes.Encrypt(s3, rk[11]); - s4 = Aes.Encrypt(s4, rk[11]); - - s1 = Aes.Encrypt(s1, rk[12]); - s2 = Aes.Encrypt(s2, rk[12]); - s3 = Aes.Encrypt(s3, rk[12]); - s4 = Aes.Encrypt(s4, rk[12]); - - s1 = Aes.Encrypt(s1, rk[13]); - s2 = Aes.Encrypt(s2, rk[13]); - s3 = Aes.Encrypt(s3, rk[13]); - s4 = Aes.Encrypt(s4, rk[13]); - - s1 = Aes.EncryptLast(s1, rk[14]); - s2 = Aes.EncryptLast(s2, rk[14]); - s3 = Aes.EncryptLast(s3, rk[14]); - s4 = Aes.EncryptLast(s4, rk[14]); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector128 Load128(ReadOnlySpan t) - { - if (BitConverter.IsLittleEndian && Unsafe.SizeOf>() == 16) - return MemoryMarshal.Read>(t); - - return Vector128.Create( - BinaryPrimitives.ReadUInt64LittleEndian(t[..8]), - BinaryPrimitives.ReadUInt64LittleEndian(t[8..]) - ).AsByte(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static Vector64 Load64(ReadOnlySpan t) - { - if (BitConverter.IsLittleEndian && Unsafe.SizeOf>() == 8) - return MemoryMarshal.Read>(t); - - return Vector64.Create( - BinaryPrimitives.ReadUInt64LittleEndian(t[..8]) - ).AsByte(); - } + public static void OutputLength(byte[] buf, int off, int len) + { + if (off > (buf.Length - len)) ThrowOutputLengthException(); + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Store128(Vector128 s, Span t) - { - if (BitConverter.IsLittleEndian && Unsafe.SizeOf>() == 16) + public static void OutputLength(Span buf, int len) { - MemoryMarshal.Write(t, ref s); - return; + if (buf.Length < len) ThrowOutputLengthException(); } - var u = s.AsUInt64(); - BinaryPrimitives.WriteUInt64LittleEndian(t[..8], u.GetElement(0)); - BinaryPrimitives.WriteUInt64LittleEndian(t[8..], u.GetElement(1)); + private static void ThrowDataLengthException() => throw new DataLengthException("input buffer too short"); + private static void ThrowOutputLengthException() => throw new OutputLengthException("output buffer too short"); } } + } #endif