Add TMA tile::gather4 / tile::scatter4 support #2129

Open
ighoshsubho wants to merge 4 commits into tile-ai:main from ighoshsubho:feat/tma-gather4-scatter4
Conversation

@ighoshsubho

@ighoshsubho ighoshsubho commented Apr 30, 2026

Adds first-class support for the Hopper+ TMA tile::gather4 (load) and tile::scatter4 (store) PTX modes. Previously TileLang's TMA path only covered rectangular tiled and im2col descriptors; sparse-row access patterns (MoE token gather, KV-cache compaction, embedding tile lookup, sparse attention) had to fall back to per-thread cp.async. This PR exposes the gather/scatter modes end-to-end: device PTX wrappers, TIR builtins, CUDA codegen, and Python user API.

Major changes

  • Implements TMA tile::gather4 (load) and tile::scatter4 (store) PTX modes for Hopper+ (verified on B200 / sm_100).
  • New TIR builtins tl.tma_load_gather4 / tl.tma_store_scatter4, CUDA codegen, device-side templates in copy_sm90.h.
  • New user-facing helpers T.tma_gather4 / T.tma_scatter4 (and T.tma_gather4_bytes for expect_tx).
  • Reuses the existing create_tma_descriptor + LowerHopperIntrin hoisting pipeline; descriptor is a 2D tiled CUtensorMap with box dim along the gathered axis = 1.

Usage

if T.shuffle_elect(num_threads):
    T.mbarrier_expect_tx(mbar, T.tma_gather4_bytes(K_box, "float16"))
    T.tma_gather4(Src, smem, col, [r0, r1, r2, r3], barrier=mbar)
    T.barrier_arrive(mbar)
T.mbarrier_wait_parity(mbar, 0)
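
As a host-side reference for what the snippet above requests, here is a minimal sketch in plain Python. It assumes that gather4 copies four rows of K_box elements, starting at column col, into a 4 x K_box tile, and that the expect_tx byte count is 4 * K_box * sizeof(dtype). The names gather4_bytes, gather4_reference and DTYPE_BYTES are hypothetical and are not part of the TileLang API.

```python
# Reference model only: nested Python lists stand in for global and shared memory.
DTYPE_BYTES = {"float16": 2, "bfloat16": 2, "float32": 4}  # assumed subset of dtypes


def gather4_bytes(k_box, dtype):
    """Bytes delivered by one gather4: four rows of k_box elements (assumed formula)."""
    return 4 * k_box * DTYPE_BYTES[dtype]


def gather4_reference(src, col, rows, k_box):
    """Return a 4 x k_box tile where tile[i] == src[rows[i]][col : col + k_box]."""
    if len(rows) != 4:
        raise ValueError("gather4 takes exactly four row indices")
    return [list(src[r][col:col + k_box]) for r in rows]


# Small example: a 6 x 8 source whose entries encode their own (row, column).
src = [[r * 100 + c for c in range(8)] for r in range(6)]
tile = gather4_reference(src, col=2, rows=[5, 0, 3, 3], k_box=4)
```

The row indices need not be sorted or distinct in this model; whether the hardware path places further restrictions on them is not covered here.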

Test plan

  • testing/python/language/test_tilelang_language_tma_gather_scatter.py — round-trips 4 rows of a (64, 64) fp16 tensor through gather4 → smem → scatter4, asserts equality.
  • Emitted CUDA contains tl::tma_load_gather4 / tl::tma_store_scatter4.
  • Runtime correctness verified on B200.
  • Swizzled shared layouts (32B/64B/128B) — descriptor flag is plumbed but not exercised yet; flagged as follow-up.
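
The round-trip property the test asserts can be modelled on the host as follows. This is only a sketch in plain Python: the helper names, the chosen row indices, and the use of col=0 with k_box=64 are assumptions for illustration, not values taken from the test file.

```python
def gather4(src, col, rows, k_box):
    """Model of the load: copy four row slices into a 4 x k_box staging tile."""
    return [list(src[r][col:col + k_box]) for r in rows]


def scatter4(dst, tile, col, rows):
    """Model of the store: write each staged row back to its row index in dst."""
    for tile_row, r in zip(tile, rows):
        dst[r][col:col + len(tile_row)] = tile_row


def round_trip(src, col, rows, k_box):
    """gather4 -> staging tile -> scatter4 into a zero-initialised destination."""
    dst = [[0] * len(src[0]) for _ in src]
    smem = gather4(src, col, rows, k_box)
    scatter4(dst, smem, col, rows)
    return dst


# A (64, 64) source, as in the test plan; integers stand in for fp16 values.
src = [[r * 64 + c for c in range(64)] for r in range(64)]
rows = [3, 17, 42, 60]
dst = round_trip(src, col=0, rows=rows, k_box=64)
```

Only the four selected rows of dst should match src afterwards; every other row stays at its initial value.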

Known follow-ups (not in this PR)

  • User must add T.reads / T.writes annotations because the descriptor build references raw buffer.data at parse time. Cleaner fix: a LowerBulkGather4 lowering pass (parallel to LowerBulkCopy) that builds the descriptor at lowering time. Happy to follow up.
  • Manual leader-thread guard required; auto-wrap could be folded into the helper.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added TMA gather4/scatter4 operations for efficient 2D tensor memory access on SM100+ devices
    • New public API functions enable gather/scatter operations with flexible eviction policy configuration
    • Improved layout constraint handling for bank-swizzle operations
  • Tests

    • Added comprehensive test suite for new gather/scatter functionality

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented Apr 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This PR adds comprehensive TMA gather4/scatter4 support across the TileLang stack, including intrinsic registrations, copy operator infrastructure, CUDA codegen pathways, Python front-end API, and test coverage for SM100+ hardware.

Changes

TMA Gather4/Scatter4 Feature

| Layer / File(s) | Summary |
|---|---|
| Intrinsic Registration (src/op/builtin.cc, src/op/builtin.h) | Registers two new public builtin ops: tl.tma_load_gather4 and tl.tma_store_scatter4, marked as opaque calls with unspecified input count (-1). |
| Copy Operator Infrastructure (src/op/copy.h) | Extends CopyInst enum with kBulkLoadGather4 and kBulkStoreScatter4. Adds accessors GetIsGather4(), GetIsScatter4(), GetGather4Rows(), GetGather4Col() on CopyNode. Declares new LowerBulkCopyGather4 lowering method. |
| Copy Operator Lowering (src/op/copy.cc) | Implements precedence routing in GetCopyInst for gather4/scatter4 markers. Extends Lower dispatch to call new LowerBulkCopyGather4 path. New method handles descriptor creation, barrier management, and PTX emission for 2D gather4/scatter4 operations. |
| CUDA Codegen (src/target/codegen_cuda.cc) | Extends codegen to support tl::tma_load_gather4 and tl::tma_store_scatter4, enforcing argument arity, extracting eviction policy, and emitting intrinsic calls with cache hint specialization. |
| CUDA Templates (src/tl_templates/cuda/copy_sm100.h) | Implements templated tma_load_gather4 and tma_store_scatter4 functions (SM100+, CUDA >= 12.8) with CUtensorMap descriptor handling, barrier indexing, and cp.async.bulk 2D gather/scatter instructions. |
| Python Frontend (tilelang/language/copy_op.py) | Adds high-level functions tma_gather4, tma_gather4_bytes, tma_scatter4 with dtype validation, TMA descriptor construction, eviction policy mapping, and barrier integration. |
| Public API Exports (tilelang/language/__init__.py) | Re-exports tma_gather4, tma_gather4_bytes, tma_scatter4 from copy_op module. |
| Layout Constraints (src/layout/gemm_layouts.cc) | Relaxes stride checks in MakeQuarterBankSwizzleLayout2D, MakeHalfBankSwizzleLayout2D, MakeFullBankSwizzleLayout2D from stride % 8 == 0 to stride == 4 \|\| stride % 8 == 0 with explanatory comments. |
| Tests (testing/python/language/test_tilelang_language_tma_gather_scatter.py) | Adds SM100-gated test module with program builders (gather_scatter_program, gather_scatter_swizzled_program), execution helpers, and test functions validating functional correctness against expected tensor results and emitted CUDA symbols. |

Sequence Diagram

sequenceDiagram
    participant PyAPI as Python API<br/>(tilelang.language)
    participant CopyLow as Copy Operator<br/>(src/op/copy.cc)
    participant Builtin as TMA Builtin<br/>(src/op/builtin.cc)
    participant Codegen as CUDA Codegen<br/>(src/target/codegen_cuda.cc)
    participant Templates as CUDA Templates<br/>(copy_sm100.h)
    participant Hardware as SM100 Hardware

    PyAPI->>PyAPI: tma_gather4(src, dst, col, rows,<br/>barrier, eviction_policy)
    PyAPI->>CopyLow: Annotation: is_gather4=true,<br/>gather4_rows, gather4_col
    CopyLow->>CopyLow: GetCopyInst() → kBulkLoadGather4
    CopyLow->>CopyLow: LowerBulkCopyGather4() creates<br/>TMADesc, barrier, PTX call
    CopyLow->>Builtin: Emit call to tl.tma_load_gather4
    Builtin->>Codegen: Route to CUDA codegen
    Codegen->>Codegen: Extract eviction_policy,<br/>map to CacheHintSm90
    Codegen->>Templates: Generate template<br/>instantiation
    Templates->>Hardware: cp.async.bulk.tensor.2d.shared::cta<br/>.global.tile::gather4 with<br/>CUtensorMap descriptor
    Hardware-->>Templates: Load 4 rows into shared memory

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

  • tile-ai/tilelang#746: Touches overlapping TMA/bulk-copy infrastructure (src/op/builtin., src/op/copy., CUDA codegen) for related TMA operations.
  • tile-ai/tilelang#1667: Modifies CUDA TMA codepaths and CUtensorMap/TMA intrinsic utilities shared with this PR's gather4/scatter4 implementation.
  • tile-ai/tilelang#933: Extends same Copy operator APIs (CopyInst enum and CopyNode lowering methods) by adding tmem load/store, parallel in scope to gather4/scatter4 lowering.

Suggested reviewers

  • Rachmanino
  • LeiWang1999

Poem

🐰 Four rows scattered to the winds,
Four rows gathered back again,
SM100 hardware bends,
With TMA's mystical pen.
A rabbit hops through gather-lanes,
Where shared and global memories reign! 🚀

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 47.73% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Title check | ✅ Passed | The title accurately and concisely summarizes the primary change: adding TMA tile::gather4/scatter4 support, which aligns with all file modifications across builtins, codegen, templates, Python APIs, and tests. |

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/tl_templates/cuda/copy_sm90.h`:
- Around line 417-439: The tma_load_gather4 function is using the wrong
namespace and lacks a CUDA 12.8+ guard: wrap the Blackwell-only gather4
implementation with the same version guard used by other TMA loads (i.e., only
compile when CUDA toolkit >= 12.8 / Blackwell) and change the inline-asm
namespace from "shared::cluster.global" to "shared::cta.global" in the asm
string; update the guard to surround the function or the asm block for
tma_load_gather4 so older toolkits skip this code and compilation errors are
prevented.

In `@testing/python/language/test_tilelang_language_tma_gather_scatter.py`:
- Around line 43-75: Gate the test to run only when CUDA is available and the
device compute capability is sm_90+ by adding a runtime skip or pytest marker:
import torch and pytest, then either decorate test_gather_scatter_basic with
`@pytest.mark.skipif`(not (torch.cuda.is_available() and
torch.cuda.get_device_properties(0).major >= 9), reason="requires CUDA sm_90+"),
or add an early check at the top of run_gather_scatter that calls pytest.skip
with the same condition; reference the test function test_gather_scatter_basic
and the helper run_gather_scatter so the compile/kernel invocation and CUDA
tensor allocations are only executed on sm_90+ devices.

In `@tilelang/language/copy_op.py`:
- Around line 297-359: The tma_gather4 path currently only checks rank and rows
but must validate dtype and layout invariants before calling
create_tma_descriptor/_tma_load_gather4: ensure src.dtype == dst.dtype, ensure
dst.shape[0] == 4 (leading dimension is 4 for the shared tile), and ensure the
innermost global stride is 1 (check the computed g_stride_rev[0] or
src.strides[1] when strides exist); if any check fails raise a ValueError with a
clear message. Add the same set of validations to the analogous gather4/scatter4
helper (the block around lines 389-449) referencing create_tma_descriptor,
_tma_load_gather4/_tma_store_scatter4 and the variables src, dst, src.strides,
dst.shape so the descriptor is only built for packed 4xK_box tiles of matching
element type and packed global layout.
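
The operand checks requested for copy_op.py can be sketched as below. This is a hedged illustration only: FakeBuffer is a hypothetical stand-in for tir.Buffer that carries just the fields being checked, and check_gather4_operands is not a function from the PR.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for tir.Buffer; only the fields checked below."""
    shape: Tuple[int, ...]
    dtype: str
    strides: Optional[Tuple[int, ...]] = None  # None is taken to mean packed row-major


def check_gather4_operands(src, dst):
    """Raise ValueError unless (src, dst) looks like a packed 4 x K_box gather."""
    if len(src.shape) != 2 or len(dst.shape) != 2:
        raise ValueError("gather4 expects rank-2 source and destination buffers")
    if src.dtype != dst.dtype:
        raise ValueError(f"dtype mismatch: src is {src.dtype}, dst is {dst.dtype}")
    if dst.shape[0] != 4:
        raise ValueError(f"shared tile must have 4 rows, got {dst.shape[0]}")
    if src.strides is not None and src.strides[-1] != 1:
        raise ValueError("innermost global stride must be 1 (packed rows)")


# A well-formed pair: (64, 64) fp16 source and a 4 x 64 fp16 shared tile.
check_gather4_operands(FakeBuffer((64, 64), "float16"), FakeBuffer((4, 64), "float16"))
```

Running the checks before the descriptor is built means a malformed call fails at parse time with a readable message rather than inside descriptor encoding.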
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cee9ba02-9641-473e-8859-3a49f37bee6b

📥 Commits

Reviewing files that changed from the base of the PR and between cabc702 and 7c9aad7.

📒 Files selected for processing (8)
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/copy_sm90.h
  • testing/python/language/test_tilelang_language_tma_gather_scatter.py
  • tilelang/language/__init__.py
  • tilelang/language/builtin.py
  • tilelang/language/copy_op.py

Comment thread src/tl_templates/cuda/copy_sm90.h Outdated
Comment thread tilelang/language/copy_op.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/language/copy_op.py`:
- Around line 256-265: The _swizzle_str_to_int helper currently accepts
"32B"/"64B"/"128B" but we must reject non-none swizzle modes until shared-layout
support exists; update _swizzle_str_to_int to only accept None/"none"/0 and
return _CU_TENSOR_MAP_SWIZZLE_NONE, and change the other branches to raise
ValueError for any other input (including "32B","64B","128B",32,64,128), and
apply the same restrictive behavior to the other similar helper(s) referenced in
the file (the equivalent functions around the later blocks noted at lines
~342-356 and ~439-453) so the public API cannot encode swizzled modes
prematurely.
- Around line 297-300: The current checks in tma_gather4 (and similarly in
tma_scatter4 at the other site) only validate that src/dst are tir.Buffer
instances but do not assert their storage scopes; update the checks to also
verify the buffer storage scope strings so tma_gather4 enforces
src.storage_scope == "global" and dst.storage_scope == "shared" (and
tma_scatter4 enforces the opposite: src == "shared", dst == "global"). Locate
the checks around the tma_gather4 and tma_scatter4 helpers (the existing
isinstance checks for src and dst) and add explicit scope checks against the
buffer's storage scope property (e.g., buffer.scope or buffer.storage_scope API
used in this codebase), raising TypeError with a clear message if the scope is
wrong.
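
A minimal sketch of the two checks described above, written as standalone helpers. The function names and the set of shared-scope spellings are assumptions for illustration; the constant 0 is the numeric value of CU_TENSOR_MAP_SWIZZLE_NONE in the CUDA driver API.

```python
_CU_TENSOR_MAP_SWIZZLE_NONE = 0  # value of CU_TENSOR_MAP_SWIZZLE_NONE in the CUDA driver API

# Assumed spellings of shared-memory scopes; adjust to what the codebase uses.
_SHARED_SCOPES = {"shared", "shared.dyn"}


def swizzle_to_int(swizzle):
    """Accept only the 'no swizzle' spellings and reject everything else."""
    if swizzle is None or swizzle == "none" or swizzle == 0:
        return _CU_TENSOR_MAP_SWIZZLE_NONE
    raise ValueError(
        f"swizzle mode {swizzle!r} is not supported yet; use None, 'none', or 0"
    )


def check_scopes(op, src_scope, dst_scope):
    """gather4 copies global -> shared; scatter4 copies shared -> global."""
    if op == "tma_gather4":
        ok = src_scope == "global" and dst_scope in _SHARED_SCOPES
    elif op == "tma_scatter4":
        ok = src_scope in _SHARED_SCOPES and dst_scope == "global"
    else:
        raise ValueError(f"unknown op {op!r}")
    if not ok:
        raise TypeError(f"{op}: unexpected scopes src={src_scope!r}, dst={dst_scope!r}")
```

Rejecting the swizzled spellings outright keeps the public API from encoding a mode that the shared layout does not yet honour.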
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f3a11354-f33c-47fe-b638-0dab0b208773

📥 Commits

Reviewing files that changed from the base of the PR and between 7c9aad7 and a05f621.

📒 Files selected for processing (3)
  • src/tl_templates/cuda/copy_sm90.h
  • testing/python/language/test_tilelang_language_tma_gather_scatter.py
  • tilelang/language/copy_op.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/copy_sm90.h

Comment thread tilelang/language/copy_op.py Outdated
Comment thread tilelang/language/copy_op.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tilelang/language/copy_op.py`:
- Around line 296-297: The docstring for tma_gather4 conflicts with the runtime:
_swizzle_str_to_int no longer accepts "32B"/"64B"/"128B", so update
tma_gather4's swizzle parameter docs to reflect the actual contract (accepts
None or integer values 0, 32, 64, 128) and remove the quoted string variants;
mention that int 0/32/64/128 are the valid inputs and that other formats will be
rejected by _swizzle_str_to_int.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ab68ca95-a876-4a8b-a374-b3d7948649d9

📥 Commits

Reviewing files that changed from the base of the PR and between a05f621 and 7efc451.

📒 Files selected for processing (1)
  • tilelang/language/copy_op.py

Comment thread tilelang/language/copy_op.py Outdated
@ighoshsubho ighoshsubho force-pushed the feat/tma-gather4-scatter4 branch from 7efc451 to 9aace3e April 30, 2026 07:30
@Rachmanino
Collaborator

Awesome contribution! This is an absolutely useful feature for kernel development on Blackwell. Left some minor suggestions :)

@Rachmanino Rachmanino self-requested a review May 6, 2026 05:50
Collaborator

Better to move these templates to copy_sm100.h, as they were introduced on sm100.

Author

Good catch. Yes, cp.async.bulk.tensor.*.tile::gather4 / scatter4 was only introduced in sm100; I'll move both out of copy_sm90.h and into copy_sm100.h in the next push.

Collaborator

@Rachmanino Rachmanino May 6, 2026

I think there's no need to define these two functions in builtin.py? They kinda duplicate those in copy_op.py, which can simply be lowered to tir.call_intrin(xxx).

Author

Agreed, the tma_load_gather4 and tma_store_scatter4 wrappers are thin shims around tir.call_intrin. I'll drop them and inline the call_intrin directly in copy_op.py's tma_gather4 / tma_scatter4 so there's a single Python entry point per op.

Collaborator

Not sure, but is there any way the compiler can figure out the CUtensorMapDataType and swizzle mode for us, just like it already does for plain TMA copy?

Author

Yes, that's the right direction. This PR has explicit dtype/swizzle plumbing in order to land a working baseline; tile::gather4 adds constraints that don't apply to plain TMA (rank-2 descriptor, leading box dim must be 4), so I went with manual encoding while sorting that out. Modular's Mojo wrapper exposes the same swizzle_mode knob, so it's not a hardware constraint.

The right follow-up, I guess, is to extend the same path the regular bulk copy already uses (CopyNode::LowerBulkCopy), inferring desc.data_type via to_CUtensorMapDataType(global_tensor->dtype) and picking desc.swizzle from the shared tile's SwizzleMode.

Happy to do it in this PR if you'd prefer it bundled, or as a follow up, let me know. cc @Rachmanino
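
The dtype half of that inference amounts to a lookup from the buffer's element type to a CUtensorMapDataType enumerator. A rough Python sketch follows; the table covers only an assumed subset of dtypes, infer_tensor_map_dtype is a hypothetical name, and the real to_CUtensorMapDataType helper may map additional or different types.

```python
# Enumerator names come from the CUDA driver API's CUtensorMapDataType enum.
# The set of dtypes listed here is an assumed subset chosen for illustration.
_DTYPE_TO_TENSOR_MAP_ENUM = {
    "uint8": "CU_TENSOR_MAP_DATA_TYPE_UINT8",
    "int32": "CU_TENSOR_MAP_DATA_TYPE_INT32",
    "float16": "CU_TENSOR_MAP_DATA_TYPE_FLOAT16",
    "bfloat16": "CU_TENSOR_MAP_DATA_TYPE_BFLOAT16",
    "float32": "CU_TENSOR_MAP_DATA_TYPE_FLOAT32",
}


def infer_tensor_map_dtype(dtype):
    """Pick the descriptor data type from the global buffer's dtype string."""
    try:
        return _DTYPE_TO_TENSOR_MAP_ENUM[dtype]
    except KeyError:
        raise ValueError(f"no TMA descriptor data type known for dtype {dtype!r}") from None
```

With a lookup like this in the lowering path, the user-facing call no longer needs a dtype argument at all, because the global buffer already carries it.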

@Rachmanino
Collaborator

Also feel free to quote me whenever you wanna run the CI routine

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/layout/gemm_layouts.cc`:
- Around line 434-437: DetectSwizzleMode currently enforces info.stride % 8 == 0
before returning quarter/half/full swizzle modes, so newly-allowed stride==4
layouts (used by tile::gather4/scatter4) get classified as
CU_TENSOR_MAP_SWIZZLE_NONE; update DetectSwizzleMode to treat info.stride==4 as
a valid candidate for the 4-row swizzle pattern (and similarly handle any checks
that previously required stride%8==0), by adding a branch that accepts stride==4
and applies the same swizzle-detection logic used for 8-multiple strides
(returning quarter/half/full as appropriate based on the computed pattern),
ensuring references to stride/info.stride and the CU_TENSOR_MAP_SWIZZLE_NONE
return are adjusted accordingly.

In `@src/op/copy.cc`:
- Around line 1912-2012: LowerBulkCopyGather4 currently skips the generic TMA
legality checks and can build a TMADesc with misaligned box/strides; before
calling create_tma_descriptor() in LowerBulkCopyGather4, reapply the same
validations used by the regular TMA lowering: verify the shared box byte-size is
128-bit aligned (i.e., K_box * dtype.bytes() % 16 == 0), ensure each global
stride (after TMABytesFromElements conversion) is 16-byte aligned and within the
allowed range (< 2^40 bytes), and fail fast (ICHECK or return an error) if any
check fails; place these checks right after computing
desc.global_stride/desc.smem_box and before Call(DataType::Handle(),
create_tma_descriptor(), ...), referencing TMADesc, desc.global_stride,
desc.smem_box, create_tma_descriptor(), and LowerBulkCopyGather4 so the
specialized gather4/scatter4 path enforces the same constraints as the generic
TMA lowering.
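
The legality checks listed for src/op/copy.cc can be written out as plain arithmetic. The sketch below is in Python rather than the C++ of the lowering pass; the function name is hypothetical, and it assumes the strides passed in are the non-innermost global strides already converted to bytes.

```python
def check_gather4_tma_legality(k_box, elem_bytes, global_strides_bytes):
    """Raise ValueError if the box or any global stride breaks the TMA limits
    named in the review: 16-byte alignment and strides below 2**40 bytes."""
    box_bytes = k_box * elem_bytes
    if box_bytes % 16 != 0:
        raise ValueError(f"shared box row is {box_bytes} bytes, not a multiple of 16")
    for i, stride in enumerate(global_strides_bytes):
        if stride % 16 != 0:
            raise ValueError(f"global stride {i} is {stride} bytes, not 16-byte aligned")
        if stride >= 1 << 40:
            raise ValueError(f"global stride {i} is {stride} bytes, must be < 2**40")


# A (64, 64) fp16 tensor with K_box = 64: rows are 128 bytes apart.
check_gather4_tma_legality(k_box=64, elem_bytes=2, global_strides_bytes=[128])
```

For example, K_box = 4 with fp16 gives an 8-byte box row, which this check rejects before any descriptor is built.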
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 44a9d366-6e18-4d1c-b63f-465abe03b9a6

📥 Commits

Reviewing files that changed from the base of the PR and between 9aace3e and d8992e5.

📒 Files selected for processing (6)
  • src/layout/gemm_layouts.cc
  • src/op/copy.cc
  • src/op/copy.h
  • src/tl_templates/cuda/copy_sm100.h
  • testing/python/language/test_tilelang_language_tma_gather_scatter.py
  • tilelang/language/copy_op.py

Comment thread src/layout/gemm_layouts.cc
Comment thread src/op/copy.cc Outdated
@ighoshsubho ighoshsubho force-pushed the feat/tma-gather4-scatter4 branch from d8992e5 to c0ef91a May 6, 2026 10:01
@ighoshsubho
Author

@Rachmanino re-pushed: descriptor build moved into LowerBulkGather4 (CUDA backend), dtype + swizzle now compiler-inferred (DetectSwizzleMode extended for stride==4, plus generic TMA legality checks per coderabbit). Templates moved to copy_sm100.h, builtin.py wrappers dropped. 4 gather4 tests + 21 gemm regressions are green on B200; feel free to kick off CI.

@ighoshsubho ighoshsubho changed the title from "Add TMA tile::gather4 / tile::scatter4 support (sm_90+)" to "Add TMA tile::gather4 / tile::scatter4 support" May 6, 2026
@ighoshsubho ighoshsubho force-pushed the feat/tma-gather4-scatter4 branch from c0ef91a to d4842b6 May 6, 2026 11:56
@ighoshsubho
Author

all merge conflicts resolved

@ighoshsubho
Author

Some ROCm tests seem to be failing. I don't think it's related to this PR, but @Rachmanino let me know if it's something on my side and I'll fix it.

@Rachmanino
Collaborator

Yep, these ROCm failures are irrelevant. I have no further concerns. cc @LeiWang1999
