From 1f2b3bd085e50d3edf5937a2d84ee8d7ccc28f43 Mon Sep 17 00:00:00 2001
From: "Dimitar (Mitko) Asenov" 
Date: Wed, 1 Oct 2025 09:21:42 -0700
Subject: [PATCH] [Mosaic GPU] Add an SMEM inference rule for `memref.subview`.

PiperOrigin-RevId: 813783302
---
 .../mosaic/gpu/layout_inference.py        | 562 ++++++++++++++++--
 tests/mosaic/gpu_layout_inference_test.py | 519 +++++++++++++++-
 2 files changed, 1025 insertions(+), 56 deletions(-)

diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index 002a6e417754..51766f003fec 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -21,13 +21,16 @@
 import enum
 from functools import partial
 import itertools
+import math
 import re
 from typing import Any, assert_never, cast
 
 from jax._src.lib import mosaic_gpu_dialect as mgpu  # noqa: F401
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
+from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import math as mlir_math
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 
@@ -252,8 +255,76 @@ def _strided_layout_for_variable(
   return fa.WGStridedFragLayout.from_shaped_type(type)
 
 
-def _extract_variable_assignments_from_constraint(
+def _extract_tiling_candidates_from_transposes(
     constraint: eqns.Constraint,
+) -> Iterator[tuple[eqns.Variable, eqns.Constant]]:
+  """Attempts to extract variable assignments from `Transposed` constraints."""
+  if not isinstance(constraint, eqns.Transposed):
+    return
+
+  lhs, rhs = constraint.lhs, constraint.rhs
+  match lhs, rhs:
+    case eqns.Variable(), eqns.SMEMTiling():
+      variable, constant = lhs, rhs
+    case eqns.SMEMTiling(), eqns.Variable():
+      variable, constant = rhs, lhs
+    case _:
+      return
+
+  if constant.value is not None:
+    tiling = constant.value.tiling
+    yield variable, eqns.SMEMTiling(lc.TileTransform(tiling[::-1]))
+  # Note that in case constant.value is None, yielding SMEMTiling(None) is
+  # valid. However, the calling function already yields that, so we don't
+  # duplicate it here.
+
+
+def _extract_tiling_candidate(
+    divides: list[eqns.Divides],
+    num_tiled_dims: int,
+) -> Iterator[tuple[eqns.Variable, eqns.Constant]]:
+  if not divides:
+    return
+
+  [divides] = eqns.merge_divides_constraints(divides)
+  assert isinstance(divides, eqns.Divides)
+  if not isinstance(divides.expr, eqns.Variable):
+    return
+  if num_tiled_dims > len(divides.dimensions_to_tile):
+    return
+
+  if num_tiled_dims == 0:
+    yield divides.expr, eqns.SMEMTiling(None)
+    return
+
+  # Below we first ignore dynamic values, as this can give better results. If
+  # that doesn't work, we try again using 1 for dynamic values, but only if
+  # there is at least one dynamic value.
+ has_dynamic = False + for ignore_dynamic in (True, False): + tiling: list[int] = [] + for dim in divides.dimensions_to_tile: + has_dynamic = has_dynamic or any(map(lambda x: not isinstance(x, int), dim)) + if ignore_dynamic: + dims = [x for x in dim if isinstance(x, int)] + else: + sanitize = lambda x: x if isinstance(x, int) else 1 + dims = [sanitize(x) for x in dim] + tiling.append(math.gcd(*dims)) + + non_tiled_dims = len(divides.dimensions_to_tile) - num_tiled_dims + tiling = tiling[non_tiled_dims:] + + const = eqns.SMEMTiling(lc.TileTransform(tuple(tiling))) + yield divides.expr, const + + if not has_dynamic: + break + + +def _extract_layout_candidates_from_memory_space_transfers( + constraint: eqns.Constraint, + divides_per_var: dict[eqns.Variable, list[eqns.Divides]], ) -> Iterator[tuple[eqns.Variable, eqns.Constant]]: """Attempts to extract variable assignments from a `Constraint`.""" if not isinstance(constraint, eqns.IsTransferable): @@ -271,31 +342,70 @@ def _extract_variable_assignments_from_constraint( return if isinstance(constant, eqns.RegisterLayout): - for packing in (1, 2, 4, 8): - for tmem_layout, reg_layout in constraint.supported_tmem_transfers( - packing - ): - if constant.value == reg_layout: - yield variable, eqns.TMEMLayout(tmem_layout) + layout = constant.value + if variable.key.memory_space == MemorySpace.TMEM: + for packing in (1, 2, 4, 8): + for tmem_layout, reg_layout in constraint.supported_tmem_transfers( + packing + ): + if layout == reg_layout: + yield variable, eqns.TMEMLayout(tmem_layout) + elif variable.key.memory_space == MemorySpace.SMEM: + if inference_utils.is_mma_layout(layout): + tiling = _infer_tiling_for_mma_ref( + variable.key.value.type, + max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) + assert len(tiling) == 2 + divides = divides_per_var.get(variable, []) + divides.append(eqns.Divides(variable, ((tiling[0],), (tiling[1],)))) + yield from _extract_tiling_candidate(divides, len(tiling)) + else: + # An empty tiling is valid here but we don't yield it in order to + # avoid duplicating the empty tiling yielded by the caller. 
+ return elif isinstance(constant, eqns.TMEMLayout): - packing = constant.value.vector_length + layout = constant.value + packing = layout.vector_length for tmem_layout, reg_layout in constraint.supported_tmem_transfers(packing): - if constant.value == tmem_layout: + if layout == tmem_layout: yield variable, eqns.RegisterLayout(reg_layout) +def _divides_per_var( + constraints: Sequence[eqns.Constraint], +) -> dict[eqns.Variable, list[eqns.Divides]]: + """Returns all Divides constraints per variable.""" + result: dict[eqns.Variable, list[eqns.Divides]] = {} + for constraint in constraints: + match constraint: + case eqns.Divides(expr=expr) if isinstance(expr, eqns.Variable): + result.setdefault(expr, []).append(constraint) + return result + + +def _extract_variable_assignments_from_constraints( + constraints: Sequence[eqns.Constraint], +) -> Iterator[tuple[eqns.Variable, eqns.Constant]]: + """Attempts to extract variable assignments from all constraints.""" + dpv = _divides_per_var(constraints) + for c in constraints: + yield from _extract_layout_candidates_from_memory_space_transfers(c, dpv) + yield from _extract_tiling_candidates_from_transposes(c) + + def conjure_assignment( unknowns: Set[eqns.Variable], equation_system: eqns.EquationSystem, hints: Sequence[Hint], ) -> Iterator[tuple[eqns.Variable, eqns.Constant]]: """Attempts to conjure an assignment for an unknown variable.""" - for constraint in equation_system.constraints: - # TODO(allanrenucci): We should be able to short-circuit the search here if - # the constraint is not satisfiable. - for assg in _extract_variable_assignments_from_constraint(constraint): - yield assg + # TODO(allanrenucci): We should be able to short-circuit the search here if + # the constraint is not satisfiable. + yield from _extract_variable_assignments_from_constraints( + equation_system.constraints + ) for hint in hints: if (assignment := extract_variable_assignment_from_hint(hint)) is not None: @@ -312,11 +422,9 @@ def conjure_assignment( # reduces the system. if variable.key.memory_space == MemorySpace.REG: layout = _strided_layout_for_variable(variable) - else: - layout = None - if layout is None: - continue - yield variable, eqns.RegisterLayout(layout) + yield variable, eqns.RegisterLayout(layout) + elif variable.key.memory_space == MemorySpace.SMEM: + yield variable, eqns.SMEMTiling(None) def find_assignments_for( @@ -494,7 +602,6 @@ def _pointwise_op_equation_system( mlir_math.LogOp, mlir_math.RsqrtOp, mlir_math.TanhOp, - vector.StoreOp, ]: _add_equation_system_derivation_rule(op)(_pointwise_op_equation_system) @@ -504,20 +611,86 @@ def _vector_load_equation_system( ctx: DerivationContext, op: vector.LoadOp, ) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: - equation_system: eqns.EquationSystem | eqns.Unsatisfiable - equation_system, operand_or_results_for_variable, hints = ( - _pointwise_op_equation_system(ctx, op) - ) - [result_variable] = operand_or_results_for_variable.keys() - result_is_not_splat = eqns.Distinct( - result_variable, - eqns.RegisterLayout( - fa.WGSplatFragLayout(shape=tuple(op.result.type.shape)) + # Checks + for i in op.indices: + index_defining_op = i.owner.opview + if ( + not isinstance(index_defining_op, arith.ConstantOp) + or index_defining_op.literal_value != 0 + ): + # TODO(bchetioui): handle slicing. 
+ raise NotImplementedError( + f"Only constants with value 0 are supported as indices for {op}" + ) + + # Registers + dest = OperandOrResult(op, VariableType.RESULT, 0) + dest_var = eqns.Variable(dest) + operand_or_results_for_variable = {dest_var: [dest]} + constraints = [ + eqns.Distinct( + dest_var, + eqns.RegisterLayout( + fa.WGSplatFragLayout(shape=tuple(op.result.type.shape)) + ), ), - ) - equation_system &= eqns.EquationSystem(constraints=[result_is_not_splat]) - assert not isinstance(equation_system, eqns.Unsatisfiable) - return equation_system, operand_or_results_for_variable, hints + ] + + # SMEM + if ctx.enable_smem_inference and utils.is_smem_ref(op.base): + source = OperandOrResult(op, VariableType.OPERAND, 0) + source_var = ctx.producer_ref(source) + operand_or_results_for_variable[source_var] = [source] + constraints.append( + eqns.IsTransferable( + source=source_var, + target=dest_var, + shape=tuple(op.result.type.shape), + ), + ) + + system = eqns.EquationSystem(constraints=constraints) + return system, operand_or_results_for_variable, [] + + +@_add_equation_system_derivation_rule(vector.StoreOp) +def _vector_store_equation_system( + ctx: DerivationContext, + op: vector.StoreOp, +) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: + # Checks + for i in op.indices: + index_defining_op = i.owner.opview + if ( + not isinstance(index_defining_op, arith.ConstantOp) + or index_defining_op.literal_value != 0 + ): + # TODO(bchetioui): handle slicing. + raise NotImplementedError( + f"Only constants with value 0 are supported as indices for {op}" + ) + + # Registers + value = OperandOrResult(op, VariableType.OPERAND, 0) + value_var = eqns.Variable(value) + operand_or_results_for_variable = {value_var: [value]} + + # SMEM + constraints = [] + if ctx.enable_smem_inference and utils.is_smem_ref(op.base): + dest = OperandOrResult(op, VariableType.OPERAND, 1) + dest_var = ctx.producer_ref(dest) + operand_or_results_for_variable[dest_var] = [dest] + constraints = [ + eqns.IsTransferable( + source=value_var, + target=dest_var, + shape=tuple(op.base.type.shape), + ) + ] + + system = eqns.EquationSystem(constraints=constraints) + return system, operand_or_results_for_variable, [] @_add_equation_system_derivation_rule(mgpu.OptimizationBarrierOp) @@ -600,7 +773,6 @@ def _for_equation_system( ctx: DerivationContext, op: scf.ForOp, ) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: - del ctx [block] = op.region.blocks yield_op = _terminator(block, scf.YieldOp) operand_or_results_for_variable: OperandOrResultsForVariable = {} @@ -609,7 +781,7 @@ def _for_equation_system( # in the operands but not in the results. 
num_leading_args = 3 for index, o in enumerate(op.operands): - if not is_vector(o): + if not is_vector(o) and not (ctx.enable_smem_inference and _is_smem_ref(o)): continue result_index = index - num_leading_args operand = OperandOrResult(op, VariableType.OPERAND, index) @@ -617,9 +789,8 @@ def _for_equation_system( yield_operand = OperandOrResult( yield_op, VariableType.OPERAND, result_index ) - operand_or_results_for_variable[eqns.Variable(operand)] = [ - operand, result, yield_operand, - ] + var = eqns.Variable(operand) if is_vector(o) else ctx.producer_ref(operand) + operand_or_results_for_variable[var] = [operand, result, yield_operand] return eqns.EquationSystem(), operand_or_results_for_variable, [] @@ -701,18 +872,95 @@ def _layout_cast_equation_system( ) +def _infer_tiling_for_mma_ref( + ref_ty: ir.MemRefType, max_swizzle: mgpu.SwizzlingMode +) -> tuple[int, int]: + element_bytewidth = utils.bytewidth(ref_ty.element_type) + strides, _ = ref_ty.get_strides_and_offset() + min_dim_index = np.argmin(strides) + minor_dim = ref_ty.shape[min_dim_index] + + # Try tiling with all swizzling modes starting from the largest one. + for swizzle in [ + mgpu.SwizzlingMode.k128ByteSwizzle, + mgpu.SwizzlingMode.k64ByteSwizzle, + mgpu.SwizzlingMode.k32ByteSwizzle, + mgpu.SwizzlingMode.kNoSwizzle, + ]: + if swizzle > max_swizzle: + continue + swizzle_elems = swizzle // element_bytewidth + if minor_dim % swizzle_elems == 0: + minor_tiling = swizzle_elems + break + else: + # No valid tile transform can be inferred. + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") + + major_tiling = 8 + transposed = min_dim_index != len(strides) - 1 + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) + return tiling + + +def _infer_mma_tiling( + a_type: ir.Type, b_type: ir.Type +) -> tuple[tuple[int, int] | None, tuple[int, int]]: + b_tiling = _infer_tiling_for_mma_ref( + ir.MemRefType(b_type), max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) + b_swizzle = _compute_swizzle(b_type, lc.TileTransform(b_tiling)) + if ir.MemRefType.isinstance(a_type): + a_tiling = _infer_tiling_for_mma_ref( + cast(ir.MemRefType, a_type), max_swizzle=b_swizzle + ) + a_swizzle = _compute_swizzle(a_type, lc.TileTransform(a_tiling)) + if a_swizzle != b_swizzle: + # The swizzle for a and b has to match. 
+ b_tiling = _infer_tiling_for_mma_ref( + ir.MemRefType(b_type), max_swizzle=a_swizzle + ) + b_swizzle = _compute_swizzle(b_type, lc.TileTransform(b_tiling)) + assert a_swizzle == b_swizzle + return a_tiling, b_tiling + return None, b_tiling + + @_add_equation_system_derivation_rule(mgpu.WGMMAOp) def _wgmma_equation_system( ctx: DerivationContext, op: mgpu.WGMMAOp, ) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: - del ctx - operands_or_results = vector_operands_and_results(op) - variable = eqns.Variable(operands_or_results[0]) + assignments: dict[eqns.Variable, eqns.Constant] = {} + # Registers + vector_operands_or_results = vector_operands_and_results(op) + vec_variable = eqns.Variable(vector_operands_or_results[0]) + assignments[vec_variable] = eqns.RegisterLayout(fa.WGMMA_LAYOUT) + oo_for_var = {vec_variable: vector_operands_or_results} + + # SMEM + if ctx.enable_smem_inference: + a_tiling, b_tiling = _infer_mma_tiling(op.a.type, op.b.type) + b = OperandOrResult(op, VariableType.OPERAND, 2) + b_var = ctx.producer_ref(b) + + assignments[b_var] = eqns.SMEMTiling(lc.TileTransform(b_tiling)) + oo_for_var[b_var] = [b] + + if a_tiling is not None: + # a is in SMEM + a = OperandOrResult(op, VariableType.OPERAND, 1) + a_var = ctx.producer_ref(a) + assignments[a_var] = eqns.SMEMTiling(lc.TileTransform(a_tiling)) + oo_for_var[a_var] = [a] + system = eqns.EquationSystem( - assignments={variable: eqns.RegisterLayout(fa.WGMMA_LAYOUT)} + assignments=assignments, ) - return system, {variable: operands_or_results}, [] + return system, oo_for_var, [] @_add_equation_system_derivation_rule(vector.BroadcastOp) @@ -859,17 +1107,32 @@ def _custom_primitive_equation_system( ctx: DerivationContext, op: mgpu.CustomPrimitiveOp, ) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: - del ctx assignments: dict[eqns.Variable, eqns.Constant] = {} + equations: list[eqns.Equation] = [] in_layouts = iter(op.in_layouts) + in_transforms = iter(op.in_transforms) variables: list[eqns.Variable] = [] for i, operand in enumerate(op.operands): - if ir.VectorType.isinstance(operand.type): + if is_vector(operand): v = eqns.Variable(OperandOrResult(op, VariableType.OPERAND, i)) variables.append(v) assignments[v] = eqns.RegisterLayout( layouts_lib.from_layout_attr(next(in_layouts)) ) + elif ctx.enable_smem_inference and _is_smem_ref(operand): + + # Here we need to create a new variable, even though it is equal to the + # source operand. This is because we directly assign the new variable and + # if we did that to the source there could be conflicting assignments. 
+ operand_or_result = OperandOrResult(op, VariableType.OPERAND, i) + source_var = ctx.producer_ref(operand_or_result) + v = eqns.Variable(operand_or_result) + equations.append(eqns.Equation(lhs=source_var, rhs=v)) + variables.append(v) + transforms = next(in_transforms) + ref_ty = operand_or_result.value.type + tiling = _extract_smem_tiling_from_custom_transform_attrs(ref_ty, transforms) + assignments[v] = eqns.SMEMTiling(tiling) out_layouts = iter(op.out_layouts) for i, result in enumerate(op.results): @@ -880,7 +1143,7 @@ def _custom_primitive_equation_system( layouts_lib.from_layout_attr(next(out_layouts)) ) return ( - eqns.EquationSystem(assignments=assignments), + eqns.EquationSystem(equations=equations, assignments=assignments), {v: [v.key] for v in variables}, [], ) @@ -919,14 +1182,21 @@ def _tmem_alloc_equation_system( ) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]: del ctx result = OperandOrResult(op, VariableType.RESULT, 0) - variable = eqns.Variable(result) + result_var = eqns.Variable(result) layout = tcgen05._infer_tmem_layout( tuple(op.result.type.shape), op.collective, packing=1 ) + + in_smem = OperandOrResult(op, VariableType.OPERAND, 0) + in_smem_var = eqns.Variable(in_smem) + assignments = {in_smem_var: eqns.SMEMTiling(None)} + operands_for_variable = {result_var: [result], in_smem_var: [in_smem]} + # This is a hint, not a hard constraint. This will be the default layout if # none can be inferred. - hint = Hint(variable, eqns.TMEMLayout(layout)) - return eqns.EquationSystem(), {variable: [result]}, [hint] + hint = Hint(result_var, eqns.TMEMLayout(layout)) + system = eqns.EquationSystem(assignments=assignments) + return system, operands_for_variable, [hint] @_add_equation_system_derivation_rule(mgpu.TmemDeallocOp) @@ -947,6 +1217,7 @@ def _tcgen05_mma_equation_system( assignments: dict[eqns.Variable, eqns.Constant] = {} operands_for_variable: OperandOrResultsForVariable = {} + # TMEM acc = OperandOrResult(op, VariableType.OPERAND, 0) acc_variable = ctx.producer_ref(acc) acc_layout = tcgen05._infer_tmem_layout( @@ -955,17 +1226,35 @@ def _tcgen05_mma_equation_system( assignments[acc_variable] = eqns.TMEMLayout(acc_layout) operands_for_variable[acc_variable] = [acc] - if utils.is_tmem_ref(op.a): + if _is_tmem_ref(op.a): a = OperandOrResult(op, VariableType.OPERAND, 1) - a_variable = ctx.producer_ref(a) + a_var = ctx.producer_ref(a) packing = 32 // utils.bitwidth(op.a.type.element_type) a_layout = tcgen05._infer_tmem_layout( tuple(op.a.type.shape), op.collective, packing ) - assignments[a_variable] = eqns.TMEMLayout(a_layout) - operands_for_variable[a_variable] = [a] + assignments[a_var] = eqns.TMEMLayout(a_layout) + operands_for_variable[a_var] = [a] + + # SMEM + if ctx.enable_smem_inference: + a_tiling, b_tiling = _infer_mma_tiling(op.a.type, op.b.type) + b = OperandOrResult(op, VariableType.OPERAND, 2) + b_var = ctx.producer_ref(b) + assignments[b_var] = eqns.SMEMTiling(lc.TileTransform(b_tiling)) + operands_for_variable[b_var] = [b] + + if _is_smem_ref(op.a): + a = OperandOrResult(op, VariableType.OPERAND, 1) + a_var = ctx.producer_ref(a) + assignments[a_var] = eqns.SMEMTiling(lc.TileTransform(a_tiling)) + operands_for_variable[a_var] = [a] - return eqns.EquationSystem(assignments), operands_for_variable, [] + system = eqns.EquationSystem( + assignments=assignments, + ) + + return system, operands_for_variable, [] @_add_equation_system_derivation_rule(mgpu.AsyncLoadTmemOp) @@ -1010,6 +1299,171 @@ def _async_store_tmem_equation_system( ) 
+@_add_equation_system_derivation_rule(mgpu.SliceSMEMOp)
+def _slice_smem_equation_system(
+    ctx: DerivationContext,
+    op: mgpu.SliceSMEMOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  del ctx
+  res = OperandOrResult(op, VariableType.RESULT, 0)
+  res_var = eqns.Variable(res)
+  return (eqns.EquationSystem(), {res_var: [res]}, [])
+
+
+@_add_equation_system_derivation_rule(memref.SubViewOp)
+def _memref_subview_equation_system(
+    ctx: DerivationContext,
+    op: memref.SubViewOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  source = OperandOrResult(op, VariableType.OPERAND, 0)
+  source_var = ctx.producer_ref(source)
+  dest = OperandOrResult(op, VariableType.RESULT, 0)
+  dest_var = eqns.Variable(dest)
+
+  # Note that even though source_var is constrained to equal dest_var, we
+  # should still use two different variables here. The reason is that the
+  # shapes of the IR values these variables refer to are not necessarily the
+  # same, and these shapes are used during layout inference.
+  equations = [eqns.Equation(source_var, dest_var)]
+
+  # Collect all the constraints from all dimensions.
+  dimensions_to_tile = []
+  index_of_last_dynamic_size = None
+  dynamic_offset_index = 0
+  for i in range(len(op.static_sizes)):
+    if ir.ShapedType.is_dynamic_size(op.static_sizes[i]):
+      index_of_last_dynamic_size = i
+
+    offset = op.static_offsets[i]
+    if ir.ShapedType.is_dynamic_size(offset):
+      offset = op.offsets[dynamic_offset_index]
+      dynamic_offset_index += 1
+
+    dims = (op.static_sizes[i], op.source.type.shape[i], offset)
+    dimensions_to_tile.append(dims)
+
+  # Drop all dimensions up to and including the last dynamic size. Dynamic
+  # sizes are not supported yet. Note that it is not trivial to directly
+  # compute the final array in the loop above, because we need to keep track
+  # of the dynamic offsets.
+  if index_of_last_dynamic_size is not None:
+    dimensions_to_tile = dimensions_to_tile[index_of_last_dynamic_size + 1 :]
+
+  constraints = [eqns.Divides(dest_var, tuple(dimensions_to_tile))]
+  return (
+      eqns.EquationSystem(constraints=constraints, equations=equations),
+      {source_var: [source], dest_var: [dest]},
+      [],
+  )
+
+
+@_add_equation_system_derivation_rule(memref.ViewOp)
+def _memref_view_op_equation_system(
+    ctx: DerivationContext,
+    op: memref.ViewOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  del ctx
+
+  # The source is expected to come from a DynamicSharedMemoryOp, which does
+  # not participate in layout inference, so no variable exists for it.
+  if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp):
+    raise NotImplementedError(
+        "Memref view transforms are only inferred when the op is a direct user "
+        f"of a DynamicSharedMemoryOp, but got {op}."
+    )
+
+  res = OperandOrResult(op, VariableType.RESULT, 0)
+  res_var = eqns.Variable(res)
+  return eqns.EquationSystem(), {res_var: [res]}, []
+
+
+@_add_equation_system_derivation_rule(memref.CastOp)
+def _memref_cast_op_equation_system(
+    ctx: DerivationContext,
+    op: memref.CastOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  source = OperandOrResult(op, VariableType.OPERAND, 0)
+  dest = OperandOrResult(op, VariableType.RESULT, 0)
+  var = ctx.producer_ref(source)
+  return (eqns.EquationSystem(), {var: [source, dest]}, [])
+
+
+# `memref.load` and `memref.store` access barrier phases, which are scalars---
+# the rule needn't do anything interesting, but we need to have it.
+@_add_equation_system_derivation_rule(memref.LoadOp)
+@_add_equation_system_derivation_rule(memref.StoreOp)
+def _memref_load_store_op_equation_system(
+    ctx: DerivationContext,
+    op: memref.LoadOp | memref.StoreOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  del ctx
+
+  ref_shape = ir.MemRefType(op.memref.type).shape
+  if ref_shape != [] and ref_shape != [1]:
+    raise NotImplementedError(
+        f"Only scalar memrefs are supported, got {ref_shape}"
+    )
+
+  ref_op_index = 0 if isinstance(op, memref.LoadOp) else 1
+  ref = OperandOrResult(op, VariableType.OPERAND, ref_op_index)
+  var = eqns.Variable(ref)
+  assignments = {var: eqns.SMEMTiling(None)}
+  return eqns.EquationSystem(assignments=assignments), {var: [ref]}, []
+
+
+def _extract_smem_tiling_from_custom_transform_attrs(
+    ref_type: ir.MemRefType,
+    transform_attrs: ir.ArrayAttr,
+) -> eqns.SMEMTiling:
+  transforms = [layouts_lib.from_transform_attr(x) for x in transform_attrs]
+  match transforms:
+    case []:
+      tile_transform = None
+      swizzle = None
+    case [lc.TileTransform() as t]:
+      tile_transform = t
+      swizzle = None
+    case [lc.TileTransform() as t, mgpu.SwizzlingMode() as s]:
+      tile_transform = t
+      swizzle = s
+    case _:
+      raise NotImplementedError(f"Unsupported transforms {transforms}")
+
+  if swizzle is not None:
+    computed_swizzle = _compute_swizzle(ref_type, tile_transform)
+    if computed_swizzle != swizzle:
+      raise NotImplementedError(
+          f"Inconsistent swizzling modes {computed_swizzle} and {swizzle}"
+      )
+
+  return eqns.SMEMTiling(tile_transform)
+
+
+@_add_equation_system_derivation_rule(mgpu.WithTransformsOp)
+def _with_transforms_equation_system(
+    ctx: DerivationContext,
+    op: mgpu.WithTransformsOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  source = OperandOrResult(op, VariableType.OPERAND, 0)
+  dest = OperandOrResult(op, VariableType.RESULT, 0)
+  var = ctx.producer_ref(source)
+  tiling = _extract_smem_tiling_from_custom_transform_attrs(op.ref.type, op.transforms)
+  assignments = {var: tiling}
+  return eqns.EquationSystem(assignments=assignments), {var: [source, dest]}, []
+
+
+@_add_equation_system_derivation_rule(mgpu.AsyncLoadOp)
+@_add_equation_system_derivation_rule(mgpu.AsyncStoreOp)
+def _async_load_equation_system(
+    ctx: DerivationContext,
+    op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp,
+) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
+  operand_index = 1 if isinstance(op, mgpu.AsyncLoadOp) else 0
+  operand = OperandOrResult(op, VariableType.OPERAND, operand_index)
+  var = ctx.producer_ref(operand)
+  return (eqns.EquationSystem(), {var: [operand]}, [])
+
+
 def _ensure_all_layouts_are_set(
     op: ir.OpView, enable_smem_inference: bool
 ) -> None:
diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py
index 9e0e5174c4c7..6a7936f2c2e3 100644
--- a/tests/mosaic/gpu_layout_inference_test.py
+++ b/tests/mosaic/gpu_layout_inference_test.py
@@ -18,6 +18,7 @@
 from absl.testing import parameterized
 import jax
+from jax import numpy as jnp
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir as mlir_interpreter
@@ -26,15 +27,18 @@
 from jax._src.lib.mlir.dialects import builtin
 from jax._src.lib.mlir.dialects import llvm
 from jax._src.lib.mlir.dialects import math
+from jax._src.lib.mlir.dialects import memref
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import equations as eqns from jax.experimental.mosaic.gpu import fragmented_array as fa +from jax.experimental.mosaic.gpu import inference_utils from jax.experimental.mosaic.gpu import launch_context as lc from jax.experimental.mosaic.gpu import layout_inference from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import tcgen05 +import numpy as np config.parse_flags_with_absl() @@ -411,7 +415,9 @@ def test_infer_layout_from_body_op_to_yield_op_to_for_op(self): shape = (64, 64) with ir.InsertionPoint(self.module.body): c_ty = ir.VectorType.get(shape, ir.BF16Type.get()) - ab_type = ir.MemRefType.get(shape, ir.BF16Type.get()) + ab_type = ir.MemRefType.get( + shape, ir.BF16Type.get(), memory_space=mgpu.utils.smem() + ) i32 = ir.IntegerType.get_signless(32) lower_bound, upper_bound, step, a, b, c = undefs( i32, i32, i32, ab_type, ab_type, c_ty @@ -792,7 +798,7 @@ def test_infer_wgmma_layout_correctly(self, lhs_memory_space): with ir.InsertionPoint(self.module.body): vec_ty = ir.VectorType.get(shape, f32) - ref_ty = ir.MemRefType.get(shape, f32) + ref_ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.smem()) lhs_ty = ref_ty if lhs_memory_space == "shared" else vec_ty acc, lhs, rhs = undefs(vec_ty, lhs_ty, ref_ty) wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs) @@ -1153,6 +1159,515 @@ def test_compute_swizzle(self, shape, type, transposed, tiling, want_swizzle): swizzle = layout_inference._compute_swizzle(ref_ty, tile_transform) self.assertEqual(swizzle, mgpu.dialect.SwizzlingMode(want_swizzle)) + @parameterized.parameters([False, True]) + def test_conjure_smem_assignment_from_is_transferrable(self, transposed): + # Create a var to use in the equation system. + shape = (128, 128) + f32 = ir.F32Type.get() + layout = ir.StridedLayoutAttr.get(0, [1, 128]) if transposed else None + ref_ty = ir.MemRefType.get(shape, f32, layout=layout, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + op_or_result = layout_inference.OperandOrResult( + operation=ref.owner, + type=layout_inference.VariableType.RESULT, + index=0, + ) + var = eqns.Variable(op_or_result) + t = lambda x, y: (y, x) if transposed else (x, y) + + def conjure(constraints) -> list[tuple[eqns.Variable, eqns.Constant]]: + system = eqns.EquationSystem(constraints=constraints) + return list(layout_inference.conjure_assignment(set([var]), system, [])) + + # Yield only empty tiling with no constraints. + self.assertEqual(conjure([]), [(var, eqns.SMEMTiling(None))]) + + # Yield empty if not an mma layout. + layout = eqns.RegisterLayout(fa.WGSplatFragLayout(shape)) + constraints = [eqns.IsTransferable(layout, var, (128, 128))] + conjured = conjure(constraints) + self.assertEqual(conjured, [(var, eqns.SMEMTiling(None))]) + + wgmma_layout = eqns.RegisterLayout(fa.WGMMA_LAYOUT) + + # Yield also maximal tiling with no Divides constraints. + constraints = [eqns.IsTransferable(wgmma_layout, var, (128, 128))] + conjured = conjure(constraints) + self.assertEqual(conjured, [ + (var, eqns.SMEMTiling(lc.TileTransform(t(8, 32)))), + (var, eqns.SMEMTiling(None)), + ] + ) + + # Yield also valid tiling with Divides constraints. 
+ constraints = [ + eqns.IsTransferable(wgmma_layout, var, (128, 128)), + eqns.Divides(var, ((64,), (64,))), + eqns.Divides(var, ((32,), (16,))), + ] + conjured = conjure(constraints) + self.assertEqual(conjured, [ + (var, eqns.SMEMTiling(lc.TileTransform((32, 8) if transposed else (8, 16)))), + (var, eqns.SMEMTiling(None)), + ] + ) + + # Yield also 1-tiling with Divides constraints with ir.Value. + i32 = ir.IntegerType.get_signless(32) + ir_value = arith.constant(i32, 0) + constraints = [ + eqns.IsTransferable(wgmma_layout, var, (128, 128)), + eqns.Divides(var, ((32, ir_value), (32,))), + ] + conjured = conjure(constraints) + self.assertEqual(conjured, [ + (var, eqns.SMEMTiling(lc.TileTransform(t(8, 32)))), + (var, eqns.SMEMTiling(lc.TileTransform((1, 8) if transposed else (1, 32)))), + (var, eqns.SMEMTiling(None)), + ] + ) + + def test_conjure_smem_assignment_from_transposed(self): + shape = (128, 128) + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + op_or_result = layout_inference.OperandOrResult( + operation=ref.owner, + type=layout_inference.VariableType.RESULT, + index=0, + ) + var = eqns.Variable(op_or_result) + constraints = [eqns.Transposed(eqns.SMEMTiling(lc.TileTransform((8, 16))), var)] + s = eqns.EquationSystem(constraints=constraints) + conjured = list(layout_inference.conjure_assignment(set([var]), s, [])) + self.assertEqual(conjured, [ + (var, eqns.SMEMTiling(lc.TileTransform((16, 8)))), + (var, eqns.SMEMTiling(None)), + ] + ) + + def test_memref_load_store_op_transforms_are_empty(self): + with ir.InsertionPoint(self.module.body): + i32 = ir.IntegerType.get_signless(32) + ref_ty = ir.MemRefType.get((), i32, memory_space=mgpu.utils.smem()) + + [val, load_ref, store_ref] = undefs(i32, ref_ty, ref_ty) + load_op = memref.LoadOp(load_ref, []) + store_op = memref.StoreOp(val, store_ref, []) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + + want = ir.ArrayAttr.get([ir.ArrayAttr.get([])]) + self.assertEqual(inference_utils.in_transforms(load_op), want) + self.assertEqual(inference_utils.in_transforms(store_op), want) + + @parameterized.parameters( + (swizzle, dtype) + for swizzle in mgpu.dialect.SwizzlingMode + for dtype in [jnp.bfloat16, jnp.float32] + ) + def test_infer_transforms_for_wgmma_op(self, swizzle, dtype): + swizzle_elems = swizzle // np.dtype(dtype).itemsize + m = 64 + # Note: `group_m` and `group_k` should be coprime with 2 for the test to be + # correct. Otherwise, we may infer larger swizzles than the test intends to + # check. 
+ group_m, group_k = 3, 3 + lhs_shape = (group_m * m, group_k * swizzle_elems) + rhs_shape = (group_k * swizzle_elems, group_k * swizzle_elems) + out_shape = (group_m * m, group_k * swizzle_elems) + + with ir.InsertionPoint(self.module.body): + elt_ty = mgpu.utils.dtype_to_ir_type(dtype) + lhs_ty = ir.MemRefType.get(lhs_shape, elt_ty, memory_space=mgpu.utils.smem()) + rhs_ty = ir.MemRefType.get(rhs_shape, elt_ty, memory_space=mgpu.utils.smem()) + acc_ty = ir.VectorType.get(out_shape, elt_ty) + [acc, lhs, rhs] = undefs(acc_ty, lhs_ty, rhs_ty) + wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + + arg_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, swizzle_elems)), + mgpu.dialect.SwizzleTransformAttr.get(int(swizzle)), + ]) + + self.assertSequenceEqual( + inference_utils.in_transforms(wgmma_op), + [arg_transforms, arg_transforms], + ) + + @parameterized.parameters(mgpu.dialect.AsyncLoadOp, mgpu.dialect.AsyncStoreOp) + def test_infer_transforms_for_async_load_store(self, op_type): + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + gmem_ref, smem_ref, barrier = undefs(gmem_ty, smem_ty, barrier_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 32)), + mgpu.dialect.SwizzleTransformAttr.get(64), + ]) + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + transformed_smem_ref = mgpu.dialect.with_transforms(smem_ref, transforms) + if op_type == mgpu.dialect.AsyncLoadOp: + op = mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=transformed_smem_ref, + barrier=barrier, + indices=[zero, zero], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + else: + op = mgpu.dialect.AsyncStoreOp( + source=smem_ref, + destination=gmem_ref, + indices=[zero, zero], + slice_lengths=shape, + ) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + + self.assertSequenceEqual( + inference_utils.in_transforms(op), [transforms] + ) + + @parameterized.product( + op_type=(vector.LoadOp, vector.StoreOp), + layout=( + fa.WGMMA_LAYOUT, + fa.WGMMA_ROW_LAYOUT, + fa.WGMMA_COL_LAYOUT, + tcgen05.TMEM_NATIVE_LAYOUT, + fa.WGStridedFragLayout((64, 64), vec_size=4), + fa.WGSplatFragLayout((64, 64)), + ), + ) + def test_infer_transforms_for_vector_load_store_op(self, op_type, layout): + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + with ir.InsertionPoint(self.module.body): + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + [smem_ref] = undefs(smem_ty) + + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + if op_type == vector.LoadOp: + vector_op = vector.LoadOp( + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) + ) + + layout_cast(vector_op.result, layout) + else: + [value_to_store] = undefs(ir.VectorType.get(shape, elt_ty)) + vector_op = vector.StoreOp(value_to_store, smem_ref, [zero] * len(shape)) + layout_cast(value_to_store, layout) + + if layout == fa.WGMMA_LAYOUT: + expected_transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 64)), + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + else: + expected_transforms = ir.ArrayAttr.get([]) + + if op_type == vector.LoadOp and isinstance(layout, fa.WGSplatFragLayout): + with self.assertRaisesRegex(ValueError, "Failed to infer"): + 
mgpu.infer_layout(self.module, enable_smem_inference=True)
+    else:
+      mgpu.infer_layout(self.module, enable_smem_inference=True)
+      self.assertSequenceEqual(
+          inference_utils.in_transforms(vector_op), [expected_transforms]
+      )
+
+  def test_infer_transforms_for_slice_smem_op_two_identical_consumers(self):
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    with ir.InsertionPoint(self.module.body):
+      i32 = ir.IntegerType.get_signless(32)
+      [offset] = undefs(i32)
+      ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
+      slice_smem_op = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
+
+      transforms = ir.ArrayAttr.get([
+          mgpu.dialect.TileTransformAttr.get((8, 64)),
+          mgpu.dialect.SwizzleTransformAttr.get(128),
+      ])
+      mgpu.dialect.with_transforms(slice_smem_op.result, transforms)
+      mgpu.dialect.with_transforms(slice_smem_op.result, transforms)
+
+    mgpu.infer_layout(self.module, enable_smem_inference=True)
+    self.assertSequenceEqual(
+        inference_utils.out_transforms(slice_smem_op), [transforms]
+    )
+
+  def test_infer_transforms_for_slice_smem_op_two_mismatching_consumers(self):
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    with ir.InsertionPoint(self.module.body):
+      i32 = ir.IntegerType.get_signless(32)
+      [offset] = undefs(i32)
+      ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
+      slice_smem_op = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
+
+      transforms1 = ir.ArrayAttr.get([
+          mgpu.dialect.TileTransformAttr.get((8, 64)),
+          mgpu.dialect.SwizzleTransformAttr.get(128),
+      ])
+      transforms2 = ir.ArrayAttr.get([
+          mgpu.dialect.TileTransformAttr.get((16, 64)),
+          mgpu.dialect.SwizzleTransformAttr.get(128),
+      ])
+      mgpu.dialect.with_transforms(slice_smem_op.result, transforms1)
+      mgpu.dialect.with_transforms(slice_smem_op.result, transforms2)
+
+    with self.assertRaisesRegex(ValueError, "Failed to infer"):
+      mgpu.infer_layout(self.module, enable_smem_inference=True)
+
+  def test_infer_transforms_sets_default_empty_transforms(self):
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    with ir.InsertionPoint(self.module.body):
+      gmem_ty = ir.MemRefType.get(shape, elt_ty)
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
+      barrier_ty = ir.Type.parse("!mosaic_gpu.barrier")
+      [gmem_ref, smem_ref, barrier] = undefs(gmem_ty, smem_ty, barrier_ty)
+
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      async_load_op = mgpu.dialect.AsyncLoadOp(
+          source=gmem_ref,
+          destination=smem_ref,
+          barrier=barrier,
+          indices=[zero, zero],
+          slice_lengths=shape,
+          collective=ir.ArrayAttr.get([]),
+      )
+
+    mgpu.infer_layout(self.module, enable_smem_inference=True)
+    [in_transform] = inference_utils.in_transforms(async_load_op)
+    self.assertSequenceEqual(in_transform, ir.ArrayAttr.get([]))
+
+  @parameterized.parameters([False, True])
+  def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms(
+      self, annotate_input
+  ):
+    shape = (2, 64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
+    out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=mgpu.utils.smem())
+
+    with ir.InsertionPoint(self.module.body):
+      [in_ref] = undefs(in_ref_ty)
+
+      transforms = ir.ArrayAttr.get([
+          mgpu.dialect.TileTransformAttr.get((32, 16)),
+          mgpu.dialect.SwizzleTransformAttr.get(32),
+      ])
+
+      if annotate_input:
+        in_ref = mgpu.dialect.with_transforms(in_ref, transforms)
+
+      subview_op = memref.SubViewOp(
+          out_ref_ty,
+          in_ref,
+          [],
+          [],
+          [],
+          static_offsets = [1, 0, 0],
+          static_sizes
= [2, 64, 8], + static_strides = [1, 1, 1] + ) + + if not annotate_input: + mgpu.dialect.with_transforms(subview_op.result, transforms) + + with self.assertRaisesRegex(ValueError, "Failed to infer"): + mgpu.infer_layout(self.module, enable_smem_inference=True) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_sibling_subviews_and_distant_op( + self, even_offsets + ): + # This test uses the following op tree extracted from this ragged dot + # kernel: + # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py + # + # subview_op0 (slice = 64, 64) + # - subview_op1 (slice = 2, 64) + # - subview_op2 (slice = 4, 64, either at an even or odd offset) + # - subview_op3 (slice = 8, 64) + # - user_op0 (in_transforms = [tile(64, 64), swizzle(32)]) + # + # First the in_transforms of user_op0 have to be propagated up to + # subview_op0. Then they have to be propagated down and resolved. Finally + # all subview ops need to have the same transforms. + + source_shape = (64, 64) + elt_ty = ir.BF16Type.get() + source_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=mgpu.utils.smem()) + + slice1_shape = (2, 64) + slice2_shape = (4, 64) + slice3_shape = (8, 64) + + slice0_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice1_ref_ty = ir.MemRefType.get(slice1_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice2_ref_ty = ir.MemRefType.get(slice2_shape, elt_ty, memory_space=mgpu.utils.smem()) + slice3_ref_ty = ir.MemRefType.get(slice3_shape, elt_ty, memory_space=mgpu.utils.smem()) + + want_tt = mgpu.dialect.TileTransformAttr.get((2 if even_offsets else 1, 64)) + + with ir.InsertionPoint(self.module.body): + [source_ref] = undefs(source_ref_ty) + subview_op0 = memref.SubViewOp( + slice0_ref_ty, + source_ref, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=source_shape, + static_strides=[1, 1], + ) + + transforms_0 = ir.ArrayAttr.get([want_tt]) + mgpu.dialect.WithTransformsOp(subview_op0.result, transforms_0) + + subview_op1 = memref.SubViewOp( + slice1_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=slice1_shape, + static_strides=[1, 1], + ) + + subview_op2 = memref.SubViewOp( + slice2_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[16 if even_offsets else 15, 0], + static_sizes=slice2_shape, + static_strides=[1, 1], + ) + + # The following ops are just to test the dynamic offsets support. 
+ c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + c64 = c(64) + c32 = c(32) + c16 = c(16) + subi = arith.subi(c64, c32) + maxsi = arith.maxsi(c16, subi) + addi = arith.addi(maxsi, subi) + andi = arith.andi(addi, maxsi) + idx = arith.index_cast(ir.IndexType.get(), andi) + subview_op3 = memref.SubViewOp( + slice3_ref_ty, + subview_op0, + [idx], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=slice3_shape, + static_strides=[1, 1], + ) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + + want = ir.ArrayAttr.get([ + want_tt, + mgpu.dialect.SwizzleTransformAttr.get(128), + ]) + + self.assertSequenceEqual(inference_utils.in_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op3), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op3), [want]) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_handles_dynamic_offsets( + self, annotate_input + ): + shape = (32, 32, 32) + elt_ty = ir.BF16Type.get() + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem()) + out_ref_ty = ir.MemRefType.get((16, 16, 32), elt_ty, memory_space=mgpu.utils.smem()) + + with ir.InsertionPoint(self.module.body): + [in_ref] = undefs(in_ref_ty) + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((1, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + in_ref = mgpu.dialect.with_transforms(in_ref, transforms) + + c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [c(16), c(4)], + [], + [], + static_offsets=[ + ir.ShapedType.get_dynamic_size(), + ir.ShapedType.get_dynamic_size(), + 0, + ], + static_sizes=[16, 16, 32], + static_strides=[1, 1, 1], + ) + + if not annotate_input: + mgpu.dialect.with_transforms(subview_op.result, transforms) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + + self.assertSequenceEqual( + inference_utils.in_transforms(subview_op), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(subview_op), [transforms] + ) + + def test_custom_primitive_op_retains_transforms(self): + with ir.InsertionPoint(self.module.body): + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((64, 64)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + op = mgpu.dialect.custom_primitive( + result=[], + operands_=[], + in_layouts=[], + in_transforms=[transforms], + out_layouts=[], + ) + + mgpu.infer_layout(self.module, enable_smem_inference=True) + self.assertSequenceEqual(inference_utils.in_transforms(op), [transforms]) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
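
Reviewer note (not part of the patch): a sketch of how the new `memref.subview`
rule feeds tiling inference. `_memref_subview_equation_system` records, for each
dimension with a static size, the triple (subview size, source size, offset) in
an `eqns.Divides` constraint, and `_extract_tiling_candidate` proposes as tile
sizes the per-dimension GCD of all recorded values, keeping only the trailing
tiled dimensions. The standalone Python sketch below illustrates that
computation; `candidate_tiling` is a hypothetical helper, and the real code
additionally merges Divides constraints across sibling subviews and retries
with 1 substituted for dynamic values.

  import math

  def candidate_tiling(dimensions_to_tile, num_tiled_dims):
    # One tuple of recorded values per dimension, mirroring
    # eqns.Divides.dimensions_to_tile. Dynamic (non-int) entries are skipped,
    # like the ignore_dynamic pass in _extract_tiling_candidate.
    gcds = [
        math.gcd(*[x for x in dims if isinstance(x, int)])
        for dims in dimensions_to_tile
    ]
    # A valid tile size must divide every recorded value, so the GCD is the
    # largest candidate. Keep only the trailing num_tiled_dims dimensions.
    return tuple(gcds[len(gcds) - num_tiled_dims:])

  # A (2, 64) slice at offset (0, 0) of a (64, 64) SMEM ref records the
  # triples (2, 64, 0) and (64, 64, 0), so the candidate tile is (2, 64),
  # consistent with want_tt in
  # test_infer_transforms_for_sibling_subviews_and_distant_op (even offsets).
  print(candidate_tiling([(2, 64, 0), (64, 64, 0)], num_tiled_dims=2))  # (2, 64)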