Fix broken transpose match in OptimizeDotOperands #6819

Open · wants to merge 1 commit into main

Conversation

@saagarjha (Contributor)

SwizzleShmemConvert currently only handles dot(convert(trans(src)) #dot_operand). Unfortunately, AccelerateMatmul has the tendency to rewrite dot(convert(trans(convert(src))) #dot_operand) into dot(trans(convert(src)) #dot_operand). This breaks ldmatrix lowering in TritonGPUToLLVM and tanks performance with transposed operands. We can fix this by recognizing the case where the trans directly feeds into dot in addition to the case where there's a convert in the middle.

Fixes #6569.
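
For reference, here is a minimal sketch of the shape of the new match. This is an illustration rather than the actual diff, and the op/accessor names (triton::TransOp, triton::gpu::ConvertLayoutOp, getSrc()) follow the TritonGPU C++ API as I understand it:

```cpp
// Sketch only (not the actual diff): find the transpose feeding a dot operand,
// accepting both
//   dot(convert(trans(src)) #dot_operand)   // the case handled today
//   dot(trans(convert(src)) #dot_operand)   // the form AccelerateMatmul produces
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

static mlir::triton::TransOp getTransFeedingDot(mlir::Value dotOperand) {
  mlir::Operation *def = dotOperand.getDefiningOp();
  // Look through an optional convert_layout sitting between the dot and the trans.
  if (auto cvt = llvm::dyn_cast_or_null<mlir::triton::gpu::ConvertLayoutOp>(def))
    def = cvt.getSrc().getDefiningOp();
  // In both shapes, the producer we care about is a tt.trans.
  return llvm::dyn_cast_or_null<mlir::triton::TransOp>(def);
}
```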

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because I am not sure what the best way to test this is; please advise.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@ThomasRaoux (Collaborator) left a comment

Is this solving a performance or a functional problem? With the new pattern match we may force the input into shared memory even if the data is already in registers, so I don't think it is always profitable.

@saagarjha (Contributor, Author)

This solves a performance problem with TN matrix multiplies (see the issue linked in the commit for more details). The idea is that the data is moved to shared memory preemptively because the transpose is going to spill to shared memory anyway, and with this form later stages in the pipeline can recognize it and use the optimized ldmatrix form, which knows how to transpose data "natively". This avoids the extremely expensive load/store pattern that results from trying to do the transpose manually in shared memory.

@ThomasRaoux (Collaborator)

because the transpose is going to spill to shared memory anyway

But that's not always true: a transpose doesn't mean the values will go through shared memory.

@saagarjha (Contributor, Author)

I guess? I do see that there's a cvtNeedsSharedMemory function, so could we just look at the two layouts and ask it if shared memory is necessary?
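
Something like the following, assuming the declaration in triton/Analysis/Utility.h is roughly bool cvtNeedsSharedMemory(RankedTensorType, RankedTensorType) (a sketch only; the real signature may differ):

```cpp
// Sketch only: gate the rewrite on whether the src->dst layout conversion
// would have to go through shared memory anyway.
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Analysis/Utility.h"

static bool shouldForceSharedMemory(mlir::RankedTensorType srcTy,
                                    mlir::RankedTensorType dstTy) {
  // Only worth rewriting if the conversion would round-trip through shared
  // memory anyway; otherwise the data may already be usable from registers.
  return mlir::cvtNeedsSharedMemory(srcTy, dstTy);
}
```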

@ThomasRaoux (Collaborator)

I guess? I do see that there's a cvtNeedsSharedMemory function, so could we just look at the two layouts and ask it if shared memory is necessary?

That's why the current pattern only kicks in when there is a convert, since that conversion would go through shared memory. I don't fully understand why this is not the case for you.

@saagarjha (Contributor, Author) commented May 15, 2025

The problem is that the optimizer is too myopic. If you have a tt.load, tt.trans, tt.dot sequence, at least one shared memory conversion will be generated to ensure that both anchor ops (load, dot) get the layouts they want. You can load in a blocked layout, perform a transpose, and then do another conversion before the dot to get into mma form. Or you can load in a blocked layout and do a conversion to get into "mma layout but transposed" form in one step. ldmatrix will also go through shared memory, but it can do the entire sequence in a "hardware accelerated" way, rather than element by element, which is how the default conversions are lowered.

OptimizeDotOperands figures out that the conversion right before a dot will hit shared memory and rewrites it to use ldmatrix (that's the first sequence I mentioned above). But AccelerateMatmul converts code from the first form into the second one, which breaks recognizing the idiom, because OptimizeDotOperands doesn't know the "mma but transposed" conversion is going to hit shared memory. It doesn't "look up" far enough; it just looks for "trans, convert, dot", rather than "convert, trans, dot". I guess this PR doesn't either but it would be easy to change that.
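
To illustrate the "look up far enough" part, a rough sketch (again not the actual change, and with the same naming assumptions as the sketch in the PR description):

```cpp
// Sketch only: once the trans feeding the dot is found (previous sketch),
// also peel an optional convert_layout below it, so the
// "convert, trans, dot" shape exposes the original source value as well.
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

static mlir::Value getTransSource(mlir::triton::TransOp trans) {
  mlir::Value src = trans.getSrc();
  if (auto cvt = src.getDefiningOp<mlir::triton::gpu::ConvertLayoutOp>())
    return cvt.getSrc();
  return src;
}
```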

@lezcano (Contributor) commented May 17, 2025

How urgent is this for you? In the next month we'll be landing fast transposes for all datatypes via arbitrary swizzling and we'll remove this pass among others. Would that work for you?

@lezcano (Contributor) commented May 17, 2025

Also, the issue you are seeing with respect to the interaction between AccelerateMatmul and this pass is that AccelerateMatmul does what it does expecting the pipeliner to kick in, but I guess that's not the case in your kernel as there is no loop.

Also, I think this pass now never kicks in, since the pattern it matches should not occur (if our remove_convert_layout pass works as expected) now that we can propagate layouts past transposes.

The fix here would be to just look at the convert_layout (not the ops around it) and check whether we want to transpose the matrix, by matching the output layout with the transpose of the input. But again, all this will be moot in a month, so perhaps it's better to just wait?
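
Roughly something like this (a sketch only; transposeEncoding is a hypothetical helper standing in for whatever layout-transposition utility we would actually use):

```cpp
// Sketch only: decide from the convert_layout itself, without matching the
// ops around it.
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Hypothetical helper: returns `encoding` with its dimensions permuted by `order`.
mlir::Attribute transposeEncoding(mlir::Attribute encoding,
                                  llvm::ArrayRef<int32_t> order);

static bool convertActsAsTranspose(mlir::triton::gpu::ConvertLayoutOp cvt) {
  auto srcTy = llvm::cast<mlir::RankedTensorType>(cvt.getSrc().getType());
  auto dstTy = llvm::cast<mlir::RankedTensorType>(cvt.getType());
  // For a 2D operand a transpose swaps the two dimensions: order = {1, 0}.
  mlir::Attribute transposed = transposeEncoding(srcTy.getEncoding(), {1, 0});
  return transposed && transposed == dstTy.getEncoding();
}
```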

@saagarjha (Contributor, Author)

That's a nuanced question. We have several kernels that aren't hitting the performance numbers we want, so we've been patching the compiler to make it generate the code we "expect" for them. Unfortunately these kernels are an important workload for us, so I'd say the changes are pretty critical. In fact, they are so critical that we are writing our kernels in conjunction with the patches, which we apply to our (private) fork of Triton.

We don't actually want to have a fork at all, but our changes are pretty sloppy. Most of them recognize specific patterns that don't really make sense to upstream, as they are unlikely to be useful to anyone else since they're not very general. However, some of them seem worthwhile to send back, like this one. We already have this patch, and it's not super difficult to carry it ourselves, so I wouldn't say that it is "urgent" that this gets merged. If you land changes in the future that obviate the need for this patch, we'd be glad to drop it.

One thing which we don't have is the context for how we should actually be making our changes. As you can probably guess from the PR, we've noticed something that doesn't work and picked an option that fixes it for us. It's hard to tell what "should" be happening–that is, which passes are supposed to optimize what. All we know is that Triton ends up generating code that isn't great. We're happy to file issues or send in fixes when things don't work as they should, but we're still trying to figure out when that's happening.

FWIW, for this particular issue, our "simple" example doesn't have a loop but we pulled this out of a kernel that did. I can go back and review it with the information you've provided but off the top of my head I think it was failing to pipeline the loads at all, or doing it incorrectly (which I was planning on investigating next). When writing this PR I had assumed that the operand was "just supposed to" end up in shared memory (by looking at the TritonGPUToLLVM pass) and my aim was to have it happen. I didn't realize that this was meant to happen because the load had been pipelined. I'll go back and take another look at what was going on here. I'll let you know if I figure anything out, or perhaps we'll just have something higher priority come up and we can just wait for you to fix it for us first ;)

Successfully merging this pull request may close these issues.

tl.dot on transposed matrix tries to rearrange matrix in shared memory