New softmax forward kernel in dev/cuda
#576
Conversation
Shouldn't the performance comparison be kernel 8 before and after these changes? The current numbers don't really say anything about this PR. Personally, I'd also say that `__shfl_xor_sync` doesn't really improve the readability of the code. If you see that for the first time, unless you're way smarter than me, you probably need to sketch out what is actually happening there.
I agree that `__shfl_xor_sync` is not obvious at first sight. In my view, if just for reduction, `__shfl_down_sync` is enough. Here is the comparison. These changes don't improve the kernel that much, but they are not harmful, at least I think so.

Comparison Before and After Changes
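For contrast, here is a minimal sketch of a plain `__shfl_down_sync` warp reduction, where only lane 0 ends up with the full result (an illustrative sketch, not the exact code from this PR):

```cuda
// Minimal warp sum reduction using __shfl_down_sync (illustrative sketch).
// At each step, lane i adds the value held by lane i + offset; after the
// loop, only lane 0 holds the reduction over all 32 lanes.
__device__ float warpReduceSumDown(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val; // valid in lane 0 only
}
```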
xor shuffle is the way to go if you need the result in all threads. What it implements is a butterfly-shuffle-type operation, which ensures that in the end every thread actually computes the reduction over all values. It's actually not that complicated once you visualize what is going on; e.g., this picture was the first that came up on Google. Then, of course, you need to convince yourself that the xor operation actually does create this pattern. Because it is so useful but at the same time non-trivial, we actually do have a helper function in `common.h`.

The changes actually give more improvement than I would have guessed :)
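A minimal sketch of the butterfly pattern described above, assuming a full 32-thread warp (the actual helper in `common.h` may differ in details):

```cuda
// Butterfly (xor) warp sum reduction: at each step, every lane exchanges
// its partial sum with the lane whose index differs in exactly one bit, so
// after log2(32) = 5 steps ALL 32 lanes hold the full sum.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val; // valid in every lane
}
```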
Thank you for your comprehensive explanation. It really helps me understand these warp reduce functions. I do know llm.c provides a warp reduce API in common.h. Since the reduction in online softmax requires additional comparisons, I haven't used those helpers, which are suited to more general cases. BTW, I see that in most warp reduce API implementations, the offset is changed by halving each step. Thank you again for your review and comments.
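As an illustration of those additional comparisons, here is one hedged sketch of how the online-softmax (max, sum) pair could be folded across a warp; names and structure are hypothetical, not the PR's actual code:

```cuda
// Hypothetical warp reduction of an online-softmax partial state (maxval,
// sumval). Merging two partials (m1, s1) and (m2, s2) gives
//   m = max(m1, m2),  s = s1 * expf(m1 - m) + s2 * expf(m2 - m),
// which needs a comparison per step that a plain sum reduction does not.
__device__ void warpReduceMaxSum(float& maxval, float& sumval) {
    for (int offset = 16; offset > 0; offset /= 2) {
        float other_max = __shfl_xor_sync(0xFFFFFFFF, maxval, offset);
        float other_sum = __shfl_xor_sync(0xFFFFFFFF, sumval, offset);
        if (other_max > maxval) {
            // the other lane saw a larger max: rescale our own partial sum
            sumval = sumval * expf(maxval - other_max) + other_sum;
            maxval = other_max;
        } else {
            // our max wins: rescale the incoming partial sum instead
            sumval += other_sum * expf(other_max - maxval);
        }
    }
}
```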
Switch the warp reduce function from `__shfl_down_sync()` to `__shfl_xor_sync()` and remove the `__shfl_sync()` code. Add a comment at the beginning of the file.
Force-pushed from d936f4c to ddfe7d3
softmax_forward8 in dev/cuda
I implement a new softmax forward kernel, named Kernel 9, using the online softmax technique (a sketch of the algorithm is included at the end of this description).

Benchmark
Made tiny modifications to make the code cleaner and performance better.
OS: Ubuntu 22.04
Device: NVIDIA GeForce RTX 3070 Laptop GPU
Performance
softmax_forward6 (online softmax with cooperative groups) vs. softmax_forward8 (online softmax without cgs)
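For reference, a minimal sketch of the one-pass online softmax algorithm that both kernels build on (a serial, per-row illustration under my own naming, not the optimized kernel from this PR):

```cuda
// Online softmax for a single row of length N: one pass maintains a running
// max and a sum that is rescaled whenever the max grows, then a second pass
// normalizes. This avoids the separate max pass of the naive algorithm.
__device__ void softmax_row_online(const float* x, float* out, int N) {
    float maxval = -INFINITY;
    float sumval = 0.0f;
    for (int i = 0; i < N; i++) {
        if (x[i] > maxval) {
            // new maximum: rescale the accumulated sum, then add exp(0) = 1
            sumval = sumval * expf(maxval - x[i]) + 1.0f;
            maxval = x[i];
        } else {
            sumval += expf(x[i] - maxval);
        }
    }
    for (int i = 0; i < N; i++) {
        out[i] = expf(x[i] - maxval) / sumval;
    }
}
```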