Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18526,12 +18526,13 @@ def _emitTdmIterateInit(self, mod, kernel, tc, dtype, du, mt, perIssueLoadRowDiv

tile_dim1 = self._tdmIterTileDim1(kernel, tc, du, dtype)
rows_per_il = mt // perIssueLoadRowDivisor
iter_count = rows_per_il // tile_dim1
lds_inc = (lbspp + pad_bytes) >> dss

if tile_dim1 == 0 or rows_per_il % tile_dim1 != 0:
raise RuntimeError(
f"TDM iterate {tc}: rows_per_issueLoad({rows_per_il}) not divisible by tile_dim1({tile_dim1}).")

iter_count = rows_per_il // tile_dim1
lds_inc = (lbspp + pad_bytes) >> dss
if not (0 < iter_count <= 256):
raise RuntimeError(
f"TDM iterate {tc}: iter_count({iter_count}) outside HW range 1~256 (field encodes n-1).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,12 +948,6 @@ def assignProblemIndependentDerivedParameters(state, printRejectionReason: bool,
if state["DirectToVgprMXSA"] or state["DirectToVgprMXSB"]:
reject(state, printRejectionReason, "UseSubtileImpl=1 PrefetchAcrossPersistent not supported with DirectToVgpr MX scale tensors")

# TODO: Support other LdsBlockSizePerPadMXSA/B for gfx1250.
if state["ISA"] == (12, 5, 0):
if ((state["LdsBlockSizePerPadMXSA"] > 0) or (state["LdsBlockSizePerPadMXSB"] > 0 )):
reject(state, "LdsBlockSizePerPadMXSA/LdsBlockSizePerPadMXSB support -1 and 0 for gfx1250")
return

state["Multicast"] = False
state["ClusterBarrier"] = False
if state["ClusterDim"] != [1, 1]:
Expand Down Expand Up @@ -3019,12 +3013,17 @@ def checkLdsBlockSizePerPadForTDM(ldsBlockSizePerPadA: int, ldsBlockSizePerPadB:
if state["TDMInst"]:
pads = {"A": ldsBlockSizePerPadA, "B": ldsBlockSizePerPadB, "MXSA": ldsBlockSizePerPadMXSA, "MXSB": ldsBlockSizePerPadMXSB}
for tc, val in pads.items():
if val == 0: continue
# A/B in iterate-mode bypass the pad_interval encoding; skip their
# check. MXSA/MXSB do not support iterate-mode, so their LBSPP
# must still satisfy the pad_interval constraints.
if tc in ("A", "B") and state.get("_TDMIterateMode%s" % tc, False):
if val == 0:
reject(state, printRejectionReason,
f"TDMIterateMode set for {tc} but LdsBlockSizePerPad{tc}=0; "
f"iterate-mode needs a non-zero pad block.")
return
continue
if val == 0: continue
dwords = val // 4
if dwords == 0 or (dwords & (dwords - 1)) != 0:
reject(state, printRejectionReason, f"LdsBlockSizePerPad{tc}={val}: val//4={dwords} must be a positive power of 2 for TDM hardware encoding")
Expand Down
Loading