Skip to content

Commit b89941e

Browse files
kunalspathak authored and Ruihan-Yin committed
Arm64/Sve: Implement SVE Math *Multiply* APIs (dotnet#102007)
- Add \*Fused\* APIs
- fix an assert in morph
- Map APIs to instructions
- Add test cases
- handle fused\* instructions
- jit format
- Added MultiplyAdd/MultiplySubtract
- Add mapping of API to instruction
- Add test cases
- Handle mov Z, Z instruction
- Reuse GetResultOpNumForRmwIntrinsic() for arm64
- Reuse HW_Flag_FmaIntrinsic for arm64
- Mark FMA APIs as HW_Flag_FmaIntrinsic
- Handle FMA in LSRA and codegen
- Remove the SpecialCodeGen flag from selectedScalar
- address some more scenarios
- jit format
- Add MultiplyBySelectedScalar
- Map the API to the instruction
- fix a bug where \*Indexed API used with ConditionalSelect were failing: `Sve.ConditionalSelect(op1, Sve.MultiplyBySelectedScalar(op1, op2, 0), op3);` was failing because we were trying to check if `MultiplyBySelectedScalar` is contained and we hit the assert because it is not containable. Added the check.
- unpredicated movprfx should not send opt
- Add the missing flags for Subtract/Multiply
- Added tests for MultiplyBySelectedScalar. Also updated \*SelectedScalar\* tests for ConditionalSelect
- fixes to test cases
- fix the parameter for selectedScalar test
- jit format
- Contain(op3) of CndSel if op1 is AllTrueMask
- Handle FMA properly
- added assert
1 parent b4778ce commit b89941e

18 files changed

+2292
-43
lines changed

src/coreclr/jit/emitarm64.cpp

+14-2
Original file line number | Diff line number | Diff line change
@@ -4250,9 +4250,11 @@ void emitter::emitIns_Mov(
42504250

42514251
case INS_sve_mov:
42524252
{
4253-
if (isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
4253+
// TODO-SVE: Remove check for insOptsNone() when predicate registers
4254+
// are present.
4255+
if (insOptsNone(opt) && isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
42544256
{
4255-
assert(insOptsNone(opt));
4257+
// assert(insOptsNone(opt));
42564258

42574259
opt = INS_OPTS_SCALABLE_B;
42584260
attr = EA_SCALABLE;
@@ -4263,6 +4265,16 @@ void emitter::emitIns_Mov(
42634265
}
42644266
fmt = IF_SVE_CZ_4A_L;
42654267
}
4268+
else if (isVectorRegister(dstReg) && isVectorRegister(srcReg))
4269+
{
4270+
assert(insOptsScalable(opt));
4271+
4272+
if (IsRedundantMov(ins, size, dstReg, srcReg, canSkip))
4273+
{
4274+
return;
4275+
}
4276+
fmt = IF_SVE_AU_3A;
4277+
}
42664278
else
42674279
{
42684280
unreached();

src/coreclr/jit/emitarm64sve.cpp

+31-3
Original file line number | Diff line number | Diff line change
@@ -10374,7 +10374,6 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
1037410374
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
1037510375
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
1037610376
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
10377-
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
1037810377
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
1037910378
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
1038010379
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -10396,6 +10395,17 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
1039610395
dst += emitOutput_Instr(dst, code);
1039710396
break;
1039810397

10398+
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
10399+
code = emitInsCodeSve(ins, fmt);
10400+
code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
10401+
code |= insEncodeReg_V<9, 5>(id->idReg2()); // nnnnn
10402+
if (id->idIns() != INS_sve_mov)
10403+
{
10404+
code |= insEncodeReg_V<20, 16>(id->idReg3()); // mmmmm
10405+
}
10406+
dst += emitOutput_Instr(dst, code);
10407+
break;
10408+
1039910409
case IF_SVE_AV_3A: // ...........mmmmm ......kkkkkddddd -- SVE2 bitwise ternary operations
1040010410
code = emitInsCodeSve(ins, fmt);
1040110411
code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
@@ -12882,7 +12892,6 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
1288212892
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
1288312893
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
1288412894
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
12885-
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
1288612895
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
1288712896
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
1288812897
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -12902,6 +12911,12 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
1290212911
assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
1290312912
assert(isVectorRegister(id->idReg3())); // mmmmm/aaaaa
1290412913
break;
12914+
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
12915+
assert(insOptsScalable(id->idInsOpt()));
12916+
assert(isVectorRegister(id->idReg1())); // ddddd
12917+
assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
12918+
assert((id->idIns() == INS_sve_mov) || isVectorRegister(id->idReg3())); // mmmmm/aaaaa
12919+
break;
1290512920

1290612921
case IF_SVE_HA_3A_F: // ...........mmmmm ......nnnnnddddd -- SVE BFloat16 floating-point dot product
1290712922
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
@@ -14526,7 +14541,6 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
1452614541
case IF_SVE_HD_3A_A: // ...........mmmmm ......nnnnnddddd -- SVE floating point matrix multiply accumulate
1452714542
// <Zd>.D, <Zn>.D, <Zm>.D
1452814543
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
14529-
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
1453014544
// <Zd>.B, <Zn>.B, <Zm>.B
1453114545
case IF_SVE_GF_3A: // ........xx.mmmmm ......nnnnnddddd -- SVE2 histogram generation (segment)
1453214546
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
@@ -14541,6 +14555,20 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
1454114555
emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
1454214556
break;
1454314557

14558+
// <Zd>.D, <Zn>.D, <Zm>.D
14559+
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
14560+
emitDispSveReg(id->idReg1(), id->idInsOpt(), true); // ddddd
14561+
if (id->idIns() == INS_sve_mov)
14562+
{
14563+
emitDispSveReg(id->idReg2(), id->idInsOpt(), false); // nnnnn/mmmmm
14564+
}
14565+
else
14566+
{
14567+
emitDispSveReg(id->idReg2(), id->idInsOpt(), true); // nnnnn/mmmmm
14568+
emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
14569+
}
14570+
break;
14571+
1454414572
// <Zda>.D, <Zn>.D, <Zm>.D
1454514573
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
1454614574
// <Zdn>.D, <Zm>.D, <Za>.D

src/coreclr/jit/gentree.cpp

+5-1
Original file line number | Diff line number | Diff line change
@@ -27955,7 +27955,7 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
2795527955
return comp->lvaGetDesc(GetLclNum())->IsNeverNegative();
2795627956
}
2795727957

27958-
#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
27958+
#if (defined(TARGET_XARCH) || defined(TARGET_ARM64)) && defined(FEATURE_HW_INTRINSICS)
2795927959
//------------------------------------------------------------------------
2796027960
// GetResultOpNumForRmwIntrinsic: check if the result is written into one of the operands.
2796127961
// In the case that none of the operand is overwritten, check if any of them is lastUse.
@@ -27966,7 +27966,11 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
2796627966
//
2796727967
unsigned GenTreeHWIntrinsic::GetResultOpNumForRmwIntrinsic(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
2796827968
{
27969+
#if defined(TARGET_XARCH)
2796927970
assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId) || HWIntrinsicInfo::IsPermuteVar2x(gtHWIntrinsicId));
27971+
#elif defined(TARGET_ARM64)
27972+
assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId));
27973+
#endif
2797027974

2797127975
if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR))
2797227976
{

src/coreclr/jit/hwintrinsic.h

+12-12
Original file line number | Diff line number | Diff line change
@@ -216,27 +216,27 @@ enum HWIntrinsicFlag : unsigned int
216216
// The intrinsic is an RMW intrinsic
217217
HW_Flag_RmwIntrinsic = 0x1000000,
218218

219-
// The intrinsic is a FusedMultiplyAdd intrinsic
220-
HW_Flag_FmaIntrinsic = 0x2000000,
221-
222219
// The intrinsic is a PermuteVar2x intrinsic
223-
HW_Flag_PermuteVar2x = 0x4000000,
220+
HW_Flag_PermuteVar2x = 0x2000000,
224221

225222
// The intrinsic is an embedded broadcast compatible intrinsic
226-
HW_Flag_EmbBroadcastCompatible = 0x8000000,
223+
HW_Flag_EmbBroadcastCompatible = 0x4000000,
227224

228225
// The intrinsic is an embedded rounding compatible intrinsic
229-
HW_Flag_EmbRoundingCompatible = 0x10000000,
226+
HW_Flag_EmbRoundingCompatible = 0x8000000,
230227

231228
// The intrinsic is an embedded masking compatible intrinsic
232-
HW_Flag_EmbMaskingCompatible = 0x20000000,
229+
HW_Flag_EmbMaskingCompatible = 0x10000000,
233230
#elif defined(TARGET_ARM64)
234231

235232
// The intrinsic has an enum operand. Using this implies HW_Flag_HasImmediateOperand.
236233
HW_Flag_HasEnumOperand = 0x1000000,
237234

238235
#endif // TARGET_XARCH
239236

237+
// The intrinsic is a FusedMultiplyAdd intrinsic
238+
HW_Flag_FmaIntrinsic = 0x20000000,
239+
240240
HW_Flag_CanBenefitFromConstantProp = 0x80000000,
241241
};
242242

@@ -935,17 +935,17 @@ struct HWIntrinsicInfo
935935
return (flags & HW_Flag_MaybeNoJmpTableIMM) != 0;
936936
}
937937

938-
#if defined(TARGET_XARCH)
939-
static bool IsRmwIntrinsic(NamedIntrinsic id)
938+
static bool IsFmaIntrinsic(NamedIntrinsic id)
940939
{
941940
HWIntrinsicFlag flags = lookupFlags(id);
942-
return (flags & HW_Flag_RmwIntrinsic) != 0;
941+
return (flags & HW_Flag_FmaIntrinsic) != 0;
943942
}
944943

945-
static bool IsFmaIntrinsic(NamedIntrinsic id)
944+
#if defined(TARGET_XARCH)
945+
static bool IsRmwIntrinsic(NamedIntrinsic id)
946946
{
947947
HWIntrinsicFlag flags = lookupFlags(id);
948-
return (flags & HW_Flag_FmaIntrinsic) != 0;
948+
return (flags & HW_Flag_RmwIntrinsic) != 0;
949949
}
950950

951951
static bool IsPermuteVar2x(NamedIntrinsic id)

src/coreclr/jit/hwintrinsicarm64.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -277,6 +277,8 @@ void HWIntrinsicInfo::lookupImmBounds(
277277
case NI_AdvSimd_Arm64_StoreSelectedScalarVector128x4:
278278
case NI_AdvSimd_Arm64_DuplicateSelectedScalarToVector128:
279279
case NI_AdvSimd_Arm64_InsertSelectedScalar:
280+
case NI_Sve_FusedMultiplyAddBySelectedScalar:
281+
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
280282
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
281283
break;
282284

0 commit comments

Comments
 (0)