Skip to content

Commit b4e1355

Browse files
committed
Improve AddImpl and SubtractImpl with Avx512
1 parent 53d6e10 commit b4e1355

File tree

1 file changed

+59
-45
lines changed

1 file changed

+59
-45
lines changed

src/Nethermind.Int256/UInt256.cs

+59-45
Original file line numberDiff line numberDiff line change
@@ -406,24 +406,32 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
406406
{
407407
if (Avx2.IsSupported)
408408
{
409-
var av = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in a));
410-
var bv = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in b));
409+
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
410+
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));
411411

412-
var result = Avx2.Add(av, bv);
413-
414-
var carryFromBothHighBits = Avx2.And(av, bv);
415-
var eitherHighBit = Avx2.Or(av, bv);
416-
var highBitNotInResult = Avx2.AndNot(result, eitherHighBit);
412+
Vector256<ulong> result = Avx2.Add(av, bv);
413+
Vector256<ulong> vCarry;
414+
if (Avx512F.VL.IsSupported)
415+
{
416+
vCarry = Avx512F.VL.CompareLessThan(result, av);
417+
}
418+
else
419+
{
420+
// Workaround for missing Vector256.CompareLessThan
421+
Vector256<ulong> carryFromBothHighBits = Avx2.And(av, bv);
422+
Vector256<ulong> eitherHighBit = Avx2.Or(av, bv);
423+
Vector256<ulong> highBitNotInResult = Avx2.AndNot(result, eitherHighBit);
417424

418-
// Set high bits where carry occurs
419-
var vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
425+
// Set high bits where carry occurs
426+
vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
427+
}
420428
// Move carry from Vector space to int
421-
var carry = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCarry));
429+
int carry = Avx.MoveMask(vCarry.AsDouble());
422430

423431
// All bits set will cascade another carry when carry is added to it
424-
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
432+
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
425433
// Move cascade from Vector space to int
426-
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
434+
int cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
427435

428436
// Use ints to work out the Vector cross lane cascades
429437
// Move carry to next bit and add cascade
@@ -434,12 +442,12 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
434442
cascade &= 0x0f;
435443

436444
// Lookup the carries to broadcast to the Vectors
437-
var cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
445+
Vector256<ulong> cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
438446

439-
// Mark res as initalized so we can use it as left said of ref assignment
447+
// Mark res as initialized so we can use it as the left side of a ref assignment
440448
Unsafe.SkipInit(out res);
441449
// Add the cascadedCarries to the result
442-
Unsafe.As<UInt256,Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);
450+
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);
443451

444452
return (carry & 0b1_0000) != 0;
445453
}
@@ -458,7 +466,6 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
458466
// Debug.Assert((BigInteger)res == ((BigInteger)a + (BigInteger)b) % ((BigInteger)1 << 256));
459467
// #endif
460468
}
461-
462469
public void Add(in UInt256 a, out UInt256 res) => Add(this, a, out res);
463470

464471
/// <summary>
@@ -665,7 +672,7 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
665672
int uLen = 0;
666673
for (int i = length - 1; i >= 0; i--)
667674
{
668-
if (Unsafe.Add(ref u,i) != 0)
675+
if (Unsafe.Add(ref u, i) != 0)
669676
{
670677
uLen = i + 1;
671678
break;
@@ -730,13 +737,13 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
730737
goto r3;
731738
}
732739

733-
r3:
740+
r3:
734741
rem2 = Rsh(un[2], shift) | Lsh(un[3], 64 - shift);
735-
r2:
742+
r2:
736743
rem1 = Rsh(un[1], shift) | Lsh(un[2], 64 - shift);
737-
r1:
744+
r1:
738745
rem0 = Rsh(un[0], shift) | Lsh(un[1], 64 - shift);
739-
r0:
746+
r0:
740747

741748
rem = new UInt256(rem0, rem1, rem2, rem3);
742749
}
@@ -879,25 +886,32 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
879886
{
880887
if (Avx2.IsSupported)
881888
{
882-
var av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
883-
var bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));
884-
885-
var result = Avx2.Subtract(av, bv);
886-
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
887-
var resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
888-
var avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));
889+
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
890+
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));
889891

890-
// Which vectors need to borrow from the next
891-
var vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
892-
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned));
892+
Vector256<ulong> result = Avx2.Subtract(av, bv);
893+
Vector256<ulong> vBorrow;
894+
if (Avx512F.VL.IsSupported)
895+
{
896+
vBorrow = Avx512F.VL.CompareGreaterThan(result, av);
897+
}
898+
else
899+
{
900+
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
901+
Vector256<ulong> resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
902+
Vector256<ulong> avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));
893903

904+
// Which vectors need to borrow from the next
905+
vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
906+
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned)).AsUInt64();
907+
}
894908
// Move borrow from Vector space to int
895-
var borrow = Avx.MoveMask(Unsafe.As<Vector256<long>, Vector256<double>>(ref vBorrow));
909+
int borrow = Avx.MoveMask(vBorrow.AsDouble());
896910

897911
// All zeros will cascade another borrow when borrow is subtracted from it
898-
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
912+
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
899913
// Move cascade from Vector space to int
900-
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
914+
int cascade = Avx.MoveMask(vCascade.AsDouble());
901915

902916
// Use ints to work out the Vector cross lane cascades
903917
// Move borrow to next bit and add cascade
@@ -908,9 +922,9 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
908922
cascade &= 0x0f;
909923

910924
// Lookup the borrows to broadcast to the Vectors
911-
var cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
925+
Vector256<ulong> cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
912926

913-
// Mark res as initalized so we can use it as left said of ref assignment
927+
// Mark res as initialized so we can use it as the left side of a ref assignment
914928
Unsafe.SkipInit(out res);
915929
// Subtract the cascadedBorrows from the result
916930
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Subtract(result, cascadedBorrows);
@@ -1315,15 +1329,15 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res)
13151329
a = Rsh(res.u0, 64 - n);
13161330
z0 = Lsh(res.u0, n);
13171331

1318-
sh64:
1332+
sh64:
13191333
b = Rsh(res.u1, 64 - n);
13201334
z1 = Lsh(res.u1, n) | a;
13211335

1322-
sh128:
1336+
sh128:
13231337
a = Rsh(res.u2, 64 - n);
13241338
z2 = Lsh(res.u2, n) | b;
13251339

1326-
sh192:
1340+
sh192:
13271341
z3 = Lsh(res.u3, n) | a;
13281342

13291343
res = new UInt256(z0, z1, z2, z3);
@@ -1425,15 +1439,15 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res)
14251439
a = Lsh(res.u3, 64 - n);
14261440
z3 = Rsh(res.u3, n);
14271441

1428-
sh64:
1442+
sh64:
14291443
b = Lsh(res.u2, 64 - n);
14301444
z2 = Rsh(res.u2, n) | a;
14311445

1432-
sh128:
1446+
sh128:
14331447
a = Lsh(res.u1, 64 - n);
14341448
z1 = Rsh(res.u1, n) | b;
14351449

1436-
sh192:
1450+
sh192:
14371451
z0 = Rsh(res.u0, n) | a;
14381452

14391453
res = new UInt256(z0, z1, z2, z3);
@@ -1923,13 +1937,13 @@ public static bool TryParse(in ReadOnlySpan<char> value, NumberStyles style, IFo
19231937
public TypeCode GetTypeCode() => TypeCode.Object;
19241938
public bool ToBoolean(IFormatProvider? provider) => !IsZero;
19251939
public byte ToByte(IFormatProvider? provider) => System.Convert.ToByte(ToDecimal(provider), provider);
1926-
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
1927-
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
1940+
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
1941+
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
19281942
public decimal ToDecimal(IFormatProvider? provider) => (decimal)this;
19291943
public double ToDouble(IFormatProvider? provider) => (double)this;
19301944
public short ToInt16(IFormatProvider? provider) => System.Convert.ToInt16(ToDecimal(provider), provider);
19311945
public int ToInt32(IFormatProvider? provider) => System.Convert.ToInt32(ToDecimal(provider), provider);
1932-
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
1946+
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
19331947
public sbyte ToSByte(IFormatProvider? provider) => System.Convert.ToSByte(ToDecimal(provider), provider);
19341948
public float ToSingle(IFormatProvider? provider) => System.Convert.ToSingle(ToDouble(provider), provider);
19351949
public string ToString(IFormatProvider? provider) => ((BigInteger)this).ToString(provider);

0 commit comments

Comments
 (0)