diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py index cae73ba079c6..2a68ec599f71 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py @@ -155,6 +155,7 @@ def getRequiredParametersMin() -> set: 'WaveSeparateGlobalReadB', 'WavefrontSize', 'WorkGroup', + 'WorkGroupMappingXCC', 'DtlPlusLdsBuf', 'MinGRIncPerMfma', 'UsePLRPack', diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py index 47d8bff28c41..6ba34cc2e479 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py @@ -34,7 +34,7 @@ # fields compile to identical code objects. _INTERNAL_ARGS = ( "WorkGroupMapping", - "WorkGroupMappingXCC", + # WorkGroupMappingXCC is deliberately not listed: it affects asm code gen "WorkGroupMappingXCCGroup", "StaggerU", "StaggerUStride", @@ -69,6 +69,7 @@ def getKeyNoInternalArgs(state, splitGSU: bool) -> str: backups = {k: s[k] for k in _INTERNAL_ARGS} gsu_backup = s["GlobalSplitU"] gg_backup = pt["GroupedGemm"] + wgmxcc_backup = s["WorkGroupMappingXCC"] # Mask internal args pt["GroupedGemm"] = False @@ -76,6 +77,8 @@ def getKeyNoInternalArgs(state, splitGSU: bool) -> str: s["GlobalSplitU"] = "M" if (gsu_backup > 1 or gsu_backup == -1) else gsu_backup elif gsu_backup != 0: s["GlobalSplitU"] = "M" + if wgmxcc_backup != -1:  # any fixed value keys as 1; auto (-1) keeps its own key + s["WorkGroupMappingXCC"] = 1 for k in _INTERNAL_ARGS: s[k] = "M" @@ -85,6 +88,7 @@ def getKeyNoInternalArgs(state, splitGSU: bool) -> str: # Restore pt["GroupedGemm"] = gg_backup s["GlobalSplitU"] = gsu_backup + s["WorkGroupMappingXCC"] = wgmxcc_backup for k in _INTERNAL_ARGS: s[k] = backups[k] @@ -148,6 +152,13 @@ def _getName(state, requiredParameters: 
frozenset, splitGSU: bool, ignoreInterna gsuBackup = state["GlobalSplitU"] ggBackup = state["ProblemType"]["GroupedGemm"] + wgmxccBackup = state["WorkGroupMappingXCC"] + + # Normalize WGMXCC for naming: auto (-1) is kept and rendered as n1, any fixed value becomes 1 + # Auto and fixed settings produce different assembly code, so the name must tell them apart + # If the key is missing from the name, kernels are dropped as duplicates when they should be kept + if wgmxccBackup != -1: + state["WorkGroupMappingXCC"] = 1 if ignoreInternalArgs: state["ProblemType"]["GroupedGemm"] = False @@ -161,7 +172,7 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna requiredParametersTemp.discard("GlobalSplitU") else: requiredParametersTemp = requiredParametersTemp.union(["WorkGroupMapping", - "WorkGroupMappingXCC", + # WorkGroupMappingXCC is deliberately not listed: it affects asm code gen "WorkGroupMappingXCCGroup", "StaggerU", "StaggerUStride", @@ -201,6 +212,7 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna state["GlobalSplitU"] = gsuBackup state["ProblemType"]["GroupedGemm"] = ggBackup + state["WorkGroupMappingXCC"] = wgmxccBackup return '_'.join(components) diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_wgmxcc_kernel_name.py b/projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_wgmxcc_kernel_name.py new file mode 100644 index 000000000000..0f179980f41b --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_wgmxcc_kernel_name.py @@ -0,0 +1,85 @@ +# Copyright Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +""" +Unit tests for WorkGroupMappingXCC (WGMXCC) kernel naming. + +WorkGroupMappingXCC must be included in the kernel name because it generates different +kernel code for -1 (chunking) and other values (regular XCC mapping). These tests check +that kernel names include WGMXCC so kernels are not rejected as duplicates. 
+""" + +from copy import deepcopy + +import pytest + +from Tensile.Common.GlobalParameters import defaultSolution +from Tensile.SolutionStructs.Naming import getKernelNameMin, getKeyNoInternalArgs + +pytestmark = pytest.mark.unit + + +def _minimal_kernel(*, work_group_mapping_xcc=-1): + """Return a deep copy of defaultSolution with the fields below overridden to exercise Naming._getName; the keyword-only argument is stored under 'WorkGroupMappingXCC'.""" + kernel = deepcopy(defaultSolution) + kernel["ProblemType"] = { + "OperationIdentifier": "Cijk_Ailk_Bljk", + "DataType": 0, + "DestDataType": 0, + "ComputeDataType": 0, + "GroupedGemm": False, + "UseBeta": False, + "UseBias": False, + } + kernel["WorkGroupMappingXCC"] = work_group_mapping_xcc + kernel["MacroTile0"] = 64 + kernel["MacroTile1"] = 32 + kernel["DepthU"] = 256 + kernel["MatrixInstM"] = 16 + kernel["MatrixInstN"] = 16 + kernel["MatrixInstB"] = 1 + kernel["MatrixInstruction"] = [16, 16, 1, 1] + kernel["MIWaveTile"] = [2, 2] + kernel["WorkGroup"] = [32, 4, 2] + kernel["ISA"] = (9, 5, 0) + return kernel + + +@pytest.mark.parametrize("work_group_mapping_xcc", [-1, 0, 2, 8]) +def test_kernel_name_includes_wgmxcc(work_group_mapping_xcc): + """Kernel names must carry a WGMXCC tag for auto (-1) and fixed (0, 2, 8) settings alike. NOTE(review): substring match only; confirm no other tag also contains 'WGMXCC'.""" + name = getKernelNameMin(_minimal_kernel(work_group_mapping_xcc=work_group_mapping_xcc), False) + + assert "WGMXCC" in name + + +def test_auto_wgmxcc_encodes_as_n1(): + """Auto WGMXCC (-1) must appear as the WGMXCCn1 tag in the kernel name.""" + name = getKernelNameMin(_minimal_kernel(work_group_mapping_xcc=-1), False) + + assert "WGMXCCn1" in name + + +def test_fixed_wgmxcc_encodes_as_one(): + """A fixed WGMXCC value (8 here) must appear normalized as WGMXCC1, never as WGMXCCn1.""" + name = getKernelNameMin(_minimal_kernel(work_group_mapping_xcc=8), False) + + assert "WGMXCC1" in name + assert "WGMXCCn1" not in name + + +def test_auto_and_fixed_wgmxcc_produce_distinct_names(): + """Auto (-1) and fixed (8) WGMXCC must yield different kernel names.""" + auto_name = 
getKernelNameMin(_minimal_kernel(work_group_mapping_xcc=-1), False) + fixed_name = getKernelNameMin(_minimal_kernel(work_group_mapping_xcc=8), False) + + assert auto_name != fixed_name + + +def test_auto_and_fixed_wgmxcc_produce_distinct_dedup_keys(): + """getKeyNoInternalArgs must tag both keys with WGMXCC and still tell auto (-1) from fixed (8). NOTE(review): the 'WGMXCC' checks are substring matches; confirm no other masked tag contains it.""" + auto_key = getKeyNoInternalArgs(_minimal_kernel(work_group_mapping_xcc=-1), False) + fixed_key = getKeyNoInternalArgs(_minimal_kernel(work_group_mapping_xcc=8), False) + + assert "WGMXCC" in auto_key + assert "WGMXCC" in fixed_key + assert auto_key != fixed_key