Fix broken transpose match in OptimizeDotOperands #6819
base: main
Conversation
SwizzleShmemConvert currently only handles dot(convert(trans(src)) #dot_operand). Unfortunately, AccelerateMatmul has the tendency to rewrite dot(convert(trans(convert(src))) #dot_operand) into dot(trans(convert(src)) #dot_operand). This breaks ldmatrix lowering in TritonGPUToLLVM and tanks performance with transposed operands. We can fix this by recognizing the case where the trans directly feeds into dot in addition to the case where there's a convert in the middle. Fixes triton-lang#6569.
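For illustration, here is a minimal sketch of the broadened match described above, assuming the usual Triton/MLIR accessors (getDefiningOp, getSrc) and header paths; it is not the actual patch, and accessor names may differ between Triton versions:

```cpp
// Sketch: find the tt.trans feeding a dot operand, whether the trans feeds the
// dot directly or through a ttg.convert_layout in the middle.
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

static triton::TransOp findTransposeFeedingDot(Value dotOperand) {
  // Existing case: dot(convert(trans(src)) #dot_operand) -- step over the convert.
  if (auto cvt = dotOperand.getDefiningOp<triton::gpu::ConvertLayoutOp>())
    dotOperand = cvt.getSrc();
  // New case: dot(trans(convert(src)) #dot_operand), the form AccelerateMatmul
  // tends to leave behind -- the trans now defines the dot operand directly.
  return dotOperand.getDefiningOp<triton::TransOp>();
}
```

If neither form matches, the returned op is null and the pass can simply leave the IR alone.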
Is this solving a performance or a functional problem? With the new pattern match we may force the input to move into shared memory even if the data is already in registers, so I don't think it is always profitable.
This solves a performance problem with TN matrix multiplies (you can see the issue linked in the commit for more details). The idea is that the data is moved to shared memory preemptively because the transpose is going to spill to shared memory anyway, and in this form later stages of the pipeline can recognize it and use the optimized ldmatrix form that knows how to transpose data "natively". This avoids the extremely expensive load/store pattern that results from trying to do the transpose manually in shared memory.
but that's not always true. A transpose doesn't mean the values will go through shared memory.
I guess? I do see that there's a |
that's why the current pattern only kicks in when there is a convert, as that conversion would go through shared memory. I don't fully understand why this is not the case for you.
The problem is that the optimizer is too myopic. If you have a tt.load, tt.trans, tt.dot sequence, at least one shared memory conversion will be generated to ensure that both anchor ops (load, dot) get the layouts they want. You can load in a blocked layout, perform the transpose, and then do another conversion before the dot to get into mma form. Or you can load in a blocked layout and do a single conversion to get into "mma layout but transposed" form in one step. ldmatrix will also go through shared memory, but it can do the entire sequence in a "hardware accelerated" way, rather than element by element as the default conversions are lowered.

OptimizeDotOperands figures out that the conversion right before a dot will hit shared memory and rewrites it to use ldmatrix (that's the first sequence I mentioned above). But AccelerateMatmul converts code from the first form into the second one, which breaks recognition of the idiom, because OptimizeDotOperands doesn't know the "mma but transposed" conversion is going to hit shared memory. It doesn't "look up" far enough; it just looks for "trans, convert, dot" rather than "convert, trans, dot". I guess this PR doesn't either, but it would be easy to change that.
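To make the "look up far enough" point concrete, here is a hedged sketch of what that extra check could look like for the dot(trans(convert(src))) form; cvtNeedsSharedMemory is assumed to be the existing helper in triton/Analysis/Utility.h that Triton uses to decide whether a layout conversion is lowered through shared memory, and the accessors follow the usual generated naming:

```cpp
// Sketch: for dot(trans(convert(src))), only treat the operand as "going
// through shared memory" if the conversion above the transpose actually
// requires a shared-memory round trip. This also addresses the profitability
// concern: if the conversion stays in registers, we do not force a spill.
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

static bool transposeGoesThroughShared(triton::TransOp trans) {
  auto cvt = trans.getSrc().getDefiningOp<triton::gpu::ConvertLayoutOp>();
  if (!cvt)
    return false;
  auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
  auto dstTy = cast<RankedTensorType>(cvt.getResult().getType());
  // Assumed signature: cvtNeedsSharedMemory(RankedTensorType, RankedTensorType) -> bool.
  return cvtNeedsSharedMemory(srcTy, dstTy);
}
```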
How urgent is this for you? In the next month we'll be landing fast transposes for all datatypes via arbitrary swizzling, and we'll remove this pass among others. Would that work for you?
Also, the issue you are seeing with the interaction between AccelerateMatmul and this pass is that AccelerateMatmul does what it does expecting the pipeliner to kick in, but I guess that's not the case in your kernel as there is no loop. I also think this pass now never kicks in, since the pattern it matches should no longer occur (if our remove_convert_layout pass works as expected) now that we can propagate layouts past transposes. The fix here would be to just look at the convert_layout (not the ops around it) and check whether we want to transpose the matrix, by matching the output layout against the transpose of the input layout. But again, all this will be moot in a month, so perhaps it's better to just wait?
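For reference, a sketch of that suggestion, looking only at the convert_layout itself; transposeEncoding is a hypothetical helper standing in for "the transpose of the input layout" (no such upstream utility is assumed), and the rest is ordinary MLIR accessor code:

```cpp
// Sketch: decide whether a ttg.convert_layout effectively transposes its
// operand by comparing the destination encoding with the source encoding
// after swapping its dimensions.
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper: the encoding `enc` would have after permuting the
// tensor dimensions by `order` (e.g. {1, 0} for a 2-D transpose).
Attribute transposeEncoding(Attribute enc, ArrayRef<int32_t> order);

static bool convertActsAsTranspose(triton::gpu::ConvertLayoutOp cvt) {
  auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
  auto dstTy = cast<RankedTensorType>(cvt.getResult().getType());
  Attribute transposedSrc = transposeEncoding(srcTy.getEncoding(), {1, 0});
  return transposedSrc && transposedSrc == dstTy.getEncoding();
}
```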
That's a nuanced question. We have several kernels that aren't hitting the performance numbers we want, so we've been patching the compiler to make it generate the code we "expect". Unfortunately these are an important workload for us, so I'd say the changes are pretty critical. In fact, they are critical enough that we are writing our kernels in conjunction with the patches, which we apply to our (private) fork of Triton. We don't actually want to have a fork at all, but our changes are pretty sloppy. Most of them recognize specific patterns that don't really make sense to upstream, as they're not very general and unlikely to be useful to anyone else. However, some of them seem worthwhile to send back, like this one. We already have this patch, and it's not super difficult to carry it ourselves, so I wouldn't say it is "urgent" that this gets merged. If you land changes in the future that obviate the need for this patch, we'd be glad to drop it.

One thing we don't have is the context for how we should actually be making our changes. As you can probably guess from the PR, we've noticed something that doesn't work and picked an option that fixes it for us. It's hard to tell what "should" be happening, that is, which passes are supposed to optimize what. All we know is that Triton ends up generating code that isn't great. We're happy to file issues or send in fixes when things don't work as they should, but we're still trying to figure out when that's happening.

FWIW, for this particular issue, our "simple" example doesn't have a loop, but we pulled it out of a kernel that did. I can go back and review it with the information you've provided, but off the top of my head I think it was failing to pipeline the loads at all, or doing it incorrectly (which I was planning on investigating next). When writing this PR I had assumed that the operand was "just supposed to" end up in shared memory (by looking at the TritonGPUToLLVM pass) and my aim was to make that happen. I didn't realize that this was meant to happen because the load had been pipelined. I'll go back and take another look at what was going on here. I'll let you know if I figure anything out, or perhaps something higher priority will come up and we can just wait for you to fix it for us first ;)
New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run pre-commit run --from-ref origin/main --to-ref HEAD.
- Select one of the following:
  - /test for lit tests
  - /unittest for C++ tests
  - /python/test for end-to-end tests
  - I am not sure what the best way to test this is, please advise.
- Select one of the following:
  - The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)