Skip to content

Commit 3f550a0

Browse files
committed
update jit_gemm primitive for small K
1 parent 05ba78e commit 3f550a0

File tree

1 file changed

+93
-19
lines changed
  • src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64

1 file changed

+93
-19
lines changed

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/matmul.cpp

Lines changed: 93 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ std::vector<ov::AnyMap> filterAdditionalConfig_Brgemm() {
164164
return additionalConfig;
165165
}
166166

167+
std::vector<ov::AnyMap> filterAdditionalConfig_Brgemm_Small_K() {
168+
#ifndef OV_CPU_WITH_MLAS
169+
// FP32 precision is covered by MLAS
170+
std::vector<ov::AnyMap> additionalConfig = {
171+
ov::AnyMap{/* empty config */}
172+
};
173+
#else
174+
std::vector<ov::AnyMap> additionalConfig = {{}};
175+
#endif
176+
177+
return additionalConfig;
178+
}
179+
167180
//For FP32 precision, FC has brgemm avx2 support but Matmul doen't have brgemm avx2.
168181
//Need to specify tryBrgAVX2 based on test case.
169182
std::vector<CPUSpecificParams> filterSpecificParams_Brgemm(bool tryBrgAVX2 = false) {
@@ -177,6 +190,15 @@ std::vector<CPUSpecificParams> filterSpecificParams_Brgemm(bool tryBrgAVX2 = fal
177190
return specificParams;
178191
}
179192

193+
std::vector<CPUSpecificParams> filterSpecificParams_Brgemm_Small_K() {
194+
std::vector<CPUSpecificParams> specificParams;
195+
if (with_cpu_x86_avx512_core()) {
196+
specificParams.push_back(CPUSpecificParams{{}, {}, {"jit_gemm"}, "jit_gemm"});
197+
}
198+
199+
return specificParams;
200+
}
201+
180202
const std::vector<ShapeRelatedParams> IS2D_Brgemm_smoke = {
181203
// needed by 'IS2D_Brgconv1x1_smoke'
182204
{static_shapes_to_test_representation({{1, 120}, {120, 120}}), {true, false}},
@@ -278,6 +300,25 @@ const auto testParams2D_Brgemm_FP16_smoke = ::testing::Combine(fullyConnectedPar
278300
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_Brgemm, MatMulLayerCPUTest, testParams2D_Brgemm_smoke, MatMulLayerCPUTest::getTestCaseName);
279301
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_Brgemm_FP16, MatMulLayerCPUTest, testParams2D_Brgemm_FP16_smoke, MatMulLayerCPUTest::getTestCaseName);
280302

303+
const std::vector<ShapeRelatedParams> IS_brgemm_small_k_smoke = {
304+
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, true}},
305+
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, true}},
306+
};
307+
const auto matMulBrgemmSmallKParams_smoke = ::testing::Combine(::testing::ValuesIn(IS_brgemm_small_k_smoke),
308+
::testing::Values(ElementType::f32),
309+
::testing::Values(ElementType::dynamic),
310+
::testing::Values(ElementType::dynamic),
311+
::testing::Values(utils::InputLayerType::PARAMETER),
312+
::testing::Values(ov::test::utils::DEVICE_CPU),
313+
::testing::ValuesIn(filterAdditionalConfig_Brgemm_Small_K()));
314+
315+
const auto testBrgemmSmallKParams_smoke = ::testing::Combine(matMulBrgemmSmallKParams_smoke,
316+
::testing::Values(MatMulNodeType::MatMul),
317+
::testing::ValuesIn(matmulFusingParams()),
318+
::testing::ValuesIn(filterSpecificParams_Brgemm_Small_K()));
319+
320+
INSTANTIATE_TEST_SUITE_P(smoke_MM_Brgemm_Small_K_Static, MatMulLayerCPUTest, testBrgemmSmallKParams_smoke, MatMulLayerCPUTest::getTestCaseName);
321+
281322
const std::vector<ShapeRelatedParams> IS_brgemm_smoke = {
282323
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {false, false}},
283324
{static_shapes_to_test_representation({{1, 2, 32, 120}, {120, 5}}), {true, false}},
@@ -287,9 +328,6 @@ const std::vector<ShapeRelatedParams> IS_brgemm_smoke = {
287328

288329
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {false, false}},
289330
{static_shapes_to_test_representation({{10, 10, 10}, {10, 10, 10}}), {true, false}},
290-
291-
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {false, true}},
292-
{static_shapes_to_test_representation({{55, 12}, {12, 55}}), {true, true}},
293331
};
294332

295333
const auto matMulBrgemmParams_smoke = ::testing::Combine(::testing::ValuesIn(IS_brgemm_smoke),
@@ -366,21 +404,39 @@ const auto testBrgemmParams_FP16_nightly = ::testing::Combine(matMulBrgemmParams
366404

367405
INSTANTIATE_TEST_SUITE_P(nightly_MM_Brgemm_Static_FP16, MatMulLayerCPUTest, testBrgemmParams_FP16_nightly, MatMulLayerCPUTest::getTestCaseName);
368406

407+
const std::vector<ShapeRelatedParams> IS_Brgemm_Small_K_Dynamic = {
408+
{
409+
{
410+
{{-1, 256}, {{1, 256}}},
411+
{{256, 384}, {{256, 384}}}
412+
},
413+
{false, false}
414+
},
415+
{
416+
{
417+
{{-1, -1}, {{55, 12}, {33, 7}}},
418+
{{-1, -1}, {{12, 55}, {7, 33}}}
419+
},
420+
{false, false}
421+
},
422+
};
423+
424+
const auto matMulBrgemmSmallKParamsDynamic = ::testing::Combine(::testing::ValuesIn(IS_Brgemm_Small_K_Dynamic),
425+
::testing::Values(ElementType::f32),
426+
::testing::Values(ElementType::dynamic),
427+
::testing::Values(ElementType::dynamic),
428+
::testing::Values(utils::InputLayerType::PARAMETER),
429+
::testing::Values(ov::test::utils::DEVICE_CPU),
430+
::testing::ValuesIn(filterAdditionalConfig_Brgemm_Small_K()));
431+
432+
const auto testBrgemmSmallKParamsDynamic = ::testing::Combine(matMulBrgemmSmallKParamsDynamic,
433+
::testing::Values(MatMulNodeType::MatMul),
434+
::testing::Values(emptyFusingSpec),
435+
::testing::ValuesIn(filterSpecificParams_Brgemm_Small_K()));
436+
437+
INSTANTIATE_TEST_SUITE_P(smoke_MM_Brgemm_Small_K_Dynamic, MatMulLayerCPUTest, testBrgemmSmallKParamsDynamic, MatMulLayerCPUTest::getTestCaseName);
438+
369439
const std::vector<ShapeRelatedParams> IS_Brgemm_Dynamic = {
370-
{
371-
{
372-
{{-1, 256}, {{1, 256}}},
373-
{{256, 384}, {{256, 384}}}
374-
},
375-
{false, false}
376-
},
377-
{
378-
{
379-
{{-1, -1}, {{55, 12}, {33, 7}}},
380-
{{-1, -1}, {{12, 55}, {7, 33}}}
381-
},
382-
{false, false}
383-
},
384440
{
385441
{
386442
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}},
@@ -431,7 +487,7 @@ const auto matMulBrgemmParamsDynamic = ::testing::Combine(::testing::ValuesIn(IS
431487
::testing::Values(ElementType::dynamic),
432488
::testing::Values(utils::InputLayerType::PARAMETER),
433489
::testing::Values(ov::test::utils::DEVICE_CPU),
434-
::testing::ValuesIn(filterAdditionalConfig_Brgemm()));
490+
::testing::ValuesIn(filterAdditionalConfig_Brgemm_Small_K()));
435491

436492
const auto testBrgemmParamsDynamic = ::testing::Combine(matMulBrgemmParamsDynamic,
437493
::testing::Values(MatMulNodeType::MatMul),
@@ -455,14 +511,17 @@ const auto testBrgemmParamsDynamic_FP16 = ::testing::Combine(matMulBrgemmParamsD
455511

456512
INSTANTIATE_TEST_SUITE_P(smoke_MM_Brgemm_Dynamic_FP16, MatMulLayerCPUTest, testBrgemmParamsDynamic_FP16, MatMulLayerCPUTest::getTestCaseName);
457513

458-
const std::vector<ShapeRelatedParams> IS_Dynamic_Fusing = {
514+
const std::vector<ShapeRelatedParams> IS_Dynamic_Fusing_Small_K = {
459515
{
460516
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
461517
{{-1, -1}, {{16, 12}, {33, 7}, {16, 12}}}, // input 0
462518
{{-1, 33}, {{12, 33}, {7, 33}, {12, 33}}} // input 1
463519
},
464520
{false, false}
465521
},
522+
};
523+
524+
const std::vector<ShapeRelatedParams> IS_Dynamic_Fusing = {
466525
{
467526
{ //dynamic case description each pair per each input has {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}
468527
{{-1, -1, -1, -1}, {{1, 2, 32, 60}, {1, 2, 32, 30}}}, // input 0
@@ -534,6 +593,21 @@ const auto testParamsDynamicFusing_FP16 = ::testing::Combine(matMulParamsDynamic
534593

535594
INSTANTIATE_TEST_SUITE_P(smoke_MM_Dynamic_Fusing_FP16, MatMulLayerCPUTest, testParamsDynamicFusing_FP16, MatMulLayerCPUTest::getTestCaseName);
536595

596+
const auto matMulParamsBrgemmSmallKDynamicFusing = ::testing::Combine(::testing::ValuesIn(IS_Dynamic_Fusing_Small_K),
597+
::testing::Values(ElementType::f32),
598+
::testing::Values(ElementType::dynamic),
599+
::testing::Values(ElementType::dynamic),
600+
::testing::Values(utils::InputLayerType::PARAMETER),
601+
::testing::Values(ov::test::utils::DEVICE_CPU),
602+
::testing::ValuesIn(filterAdditionalConfig_Brgemm()));
603+
604+
const auto testParamsBrgemmSmallKDynamicFusing = ::testing::Combine(matMulParamsBrgemmSmallKDynamicFusing,
605+
::testing::Values(MatMulNodeType::MatMul),
606+
::testing::ValuesIn(matmulFusingParams()),
607+
::testing::ValuesIn(filterSpecificParams_Brgemm_Small_K()));
608+
609+
INSTANTIATE_TEST_SUITE_P(smoke_MM_Brgemm_Small_K_Dynamic_Fusing, MatMulLayerCPUTest, testParamsBrgemmSmallKDynamicFusing, MatMulLayerCPUTest::getTestCaseName);
610+
537611
const auto matMulParamsBrgemmDynamicFusing = ::testing::Combine(::testing::ValuesIn(IS_Dynamic_Fusing),
538612
::testing::Values(ElementType::f32),
539613
::testing::Values(ElementType::dynamic),

0 commit comments

Comments
 (0)