
New softmax forward kernel in dev/cuda #576

Closed

KarhouTam wants to merge 3 commits from the improve-softmax-forward8 branch

Conversation

@KarhouTam (Contributor) commented Jun 10, 2024

Made some small modifications to make the code cleaner and improve performance.

OS: Ubuntu 22.04
Device: NVIDIA GeForce RTX 3070 Laptop GPU

Performance

softmax_forward6 (online softmax with cooperative groups)

block_size   32 | time 14.3109 ms | per token 1.75 µs
block_size   64 | time 13.4489 ms | per token 1.64 µs
block_size  128 | time 16.3121 ms | per token 1.99 µs
block_size  256 | time 17.4961 ms | per token 2.14 µs
block_size  512 | time 22.0295 ms | per token 2.69 µs
block_size 1024 | time 22.6308 ms | per token 2.76 µs

softmax_forward8 (online softmax without cgs)

block_size   32 | time 14.2730 ms | per token 1.74 µs
block_size   64 | time 13.2274 ms | per token 1.61 µs
block_size  128 | time 14.4429 ms | per token 1.76 µs
block_size  256 | time 17.3742 ms | per token 2.12 µs
block_size  512 | time 18.2752 ms | per token 2.23 µs
block_size 1024 | time 20.2799 ms | per token 2.48 µs

@ngc92 (Contributor) commented Jun 10, 2024

shouldn't the performance comparison be kernel8 before and after these changes? The current numbers don't really say anything about this PR

personally, I'd also say that __shfl_xor_sync doesn't really improve readability of the code. If you see that for the first time, unless you're way smarter than me, you probably need to sketch out what is actually happening there.

@KarhouTam (Contributor, Author) commented Jun 10, 2024

I agree that `__shfl_xor_sync()` is more complicated than `__shfl_down_sync()`.

In my view, if it's just for a reduction, both `__shfl_down_sync()` and `__shfl_xor_sync()` can do the job properly, but I found that using `__shfl_down_sync()` for the reduction requires a final `val = __shfl_sync(0xFFFFFFFF, val, 0)` to get the result into every lane, which `__shfl_xor_sync()` doesn't. I'm still looking for the specific reason behind this.
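
For concreteness, a minimal sketch of the `__shfl_down_sync()` idiom and why it needs the extra shuffle (illustrative only, not the exact PR code):

```cuda
// Tree reduction with __shfl_down_sync(): after the loop, only lane 0 holds
// the complete warp sum, so a final __shfl_sync() is needed to broadcast it
// to the other 31 lanes.
__device__ float reduce_down(float val) {
    for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    return __shfl_sync(0xFFFFFFFF, val, 0); // broadcast lane 0's result
}
```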

Here is the comparison. These changes don't improve the kernel that much, but they are not harmful, at least I think so.

Comparison Before and After Changes

Using __shfl_down_sync() with __shfl_sync()

block_size   32 | time 15.2935 ms | per token 1.87 µs
block_size   64 | time 15.2867 ms | per token 1.87 µs
block_size  128 | time 17.3360 ms | per token 2.12 µs
block_size  256 | time 19.3772 ms | per token 2.37 µs
block_size  512 | time 21.4683 ms | per token 2.62 µs
block_size 1024 | time 25.6833 ms | per token 3.14 µs

Using __shfl_xor_sync() without __shfl_sync()

block_size   32 | time 14.1287 ms | per token 1.72 µs
block_size   64 | time 13.2909 ms | per token 1.62 µs
block_size  128 | time 13.8408 ms | per token 1.69 µs
block_size  256 | time 16.6469 ms | per token 2.03 µs
block_size  512 | time 18.0218 ms | per token 2.20 µs
block_size 1024 | time 19.5242 ms | per token 2.38 µs

Thanks for your code review.
I'm just a CUDA newbie trying to participate in open-source projects to improve myself.
Respect to you and your contributions 🙌.

@ngc92 (Contributor) commented Jun 10, 2024

xor shuffle is the way to go if you need the result in all threads. What it implements is a butterfly-shuffle type operation that ensures that, in the end, every thread actually computes the reduction over all values. It's actually not that complicated once you visualize what is going on; e.g., this picture was the first that came up on Google. Then, of course, you need to convince yourself that the xor operation actually does create this pattern.
(In contrast, shfl down would have all the arrows pointing to the left, so that only the leftmost element would actually get the full reduction; that's why you need one more shuffle to broadcast the result.)

Because it is so useful but at the same time non-trivial, we actually do have a helper function in common.h that implements exactly this operation: warpReduceSum.
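
A sketch of what that helper boils down to (check common.h for the exact definition), with the lane pairings traced in the comments:

```cuda
// Butterfly reduction: at each step, lane i exchanges values with lane i ^ offset.
//   offset 16: lanes 0<->16, 1<->17, ... (the two half-warps combine)
//   offset  8: lanes 0<->8,  1<->9,  ... (within each half)
//   ...
//   offset  1: lanes 0<->1,  2<->3,  ... (adjacent pairs)
// After log2(32) = 5 steps, every lane holds the sum over all 32 lanes.
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    return val;
}
```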

The changes actually give more improvement than I would have guessed :)

@KarhouTam (Contributor, Author) commented Jun 10, 2024

Thank you for your comprehensive explanation. It really helps me understand these warp reduce functions.

I do know llm.c provides warp-reduce helpers in common.h. But since the reduction in online softmax requires additional comparisons, I haven't used them; they fit the more general cases.

BTW, I see that in most warp-reduce implementations, the offset is changed by `offset /= 2`. Considering that the warp size is fixed at 32 in most scenarios, would it be more efficient to use `offset >>= 1` instead? Conventional wisdom says a bitwise shift is computationally lighter than a division. Or does the compiler actually do that optimization for us, so there is no difference?
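
To make the "additional comparisons" concrete, here is a hedged sketch of how an online-softmax warp reduction can merge per-lane (max, sum) states; the function name and exact form are illustrative, not necessarily the PR's code:

```cuda
// Each lane carries a running maximum and a sum of exp(x - max). Merging two
// lanes' states needs a comparison plus a rescale of both partial sums, which
// is why the generic warpReduceSum helper doesn't apply directly.
__device__ void onlineSoftmaxWarpReduce(float &maxval, float &sumval) {
    for (int offset = 16; offset > 0; offset /= 2) {
        float other_max = __shfl_xor_sync(0xFFFFFFFF, maxval, offset);
        float other_sum = __shfl_xor_sync(0xFFFFFFFF, sumval, offset);
        float new_max = fmaxf(maxval, other_max);
        // rescale both partial sums to the new shared maximum before adding
        sumval = sumval * expf(maxval - new_max)
               + other_sum * expf(other_max - new_max);
        maxval = new_max;
    }
}
```

(I suspect the compiler emits the same shift instruction for this loop either way, but I'd like to confirm.)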

Thanks again for your review and comments.

Switch warp reduce func from `__shfl_down_sync()` to `__shfl_xor_sync()` and remove the `__shfl_sync()` code.

Add comment at the beginning of file.
@KarhouTam force-pushed the improve-softmax-forward8 branch from d936f4c to ddfe7d3 on June 22, 2024
@KarhouTam changed the title from "Improve the performance of softmax_forward8 in dev/cuda" to "New softmax forward kernel in dev/cuda" on Jun 22, 2024
@KarhouTam (Contributor, Author) commented
I implemented a new softmax forward kernel named softmax_forward_online_kernel9.

Kernel 9 uses x128 (128-bit packed memory access) to significantly accelerate loads.
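
The idea behind x128 is to move 128 bits (four floats) per memory instruction instead of one float at a time. A minimal illustration using CUDA's built-in float4 (a simplification of the technique, not the kernel's actual code; assumes 16-byte-aligned rows and C divisible by 4):

```cuda
// Each thread strides over one row in float4 chunks: a single 16-byte load
// fetches four values at once, cutting the number of load instructions ~4x.
__global__ void row_sum_vectorized(const float* x, float* out, int C) {
    const float4* row = reinterpret_cast<const float4*>(x + (size_t)blockIdx.x * C);
    float sum = 0.0f;
    for (int i = threadIdx.x; i < C / 4; i += blockDim.x) {
        float4 v = row[i]; // one 128-bit load
        sum += v.x + v.y + v.z + v.w;
    }
    // ... warp/block reduction of `sum` and the store to out[] omitted ...
}
```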

Benchmark

softmax_forward_online_kernel9

block_size   32 | time 14.5774 ms | per token 1.78 µs
block_size   64 | time 7.6058 ms | per token 0.93 µs
block_size  128 | time 5.8938 ms | per token 0.72 µs
block_size  256 | time 5.2009 ms | per token 0.63 µs
block_size  512 | time 4.3161 ms | per token 0.53 µs
block_size 1024 | time 5.4758 ms | per token 0.67 µs

Compared to the others:

softmax_forward_online_kernel8

Kernel 8 is online softmax forward without cooperative groups

block_size   32 | time 14.6022 ms | per token 1.78 µs
block_size   64 | time 14.3459 ms | per token 1.75 µs
block_size  128 | time 15.1699 ms | per token 1.85 µs
block_size  256 | time 19.2555 ms | per token 2.35 µs
block_size  512 | time 19.0878 ms | per token 2.33 µs
block_size 1024 | time 22.3713 ms | per token 2.73 µs

softmax_forward_online_kernel7

Kernel 7 is normal softmax optimized for very large C.

block_size   32 | time 21.3318 ms | per token 2.60 µs
block_size   64 | time 29.1571 ms | per token 3.56 µs
block_size  128 | time 38.7531 ms | per token 4.73 µs
block_size  256 | time 44.2141 ms | per token 5.40 µs
block_size  512 | time 48.5004 ms | per token 5.92 µs
block_size 1024 | time 50.4821 ms | per token 6.16 µs

softmax_forward_online_kernel6

Kernel 6 is online softmax forward with cooperative groups

block_size   32 | time 14.7439 ms | per token 1.80 µs
block_size   64 | time 14.3677 ms | per token 1.75 µs
block_size  128 | time 17.9260 ms | per token 2.19 µs
block_size  256 | time 20.7341 ms | per token 2.53 µs
block_size  512 | time 21.6989 ms | per token 2.65 µs
block_size 1024 | time 22.9575 ms | per token 2.80 µs

@KarhouTam KarhouTam closed this Jun 24, 2024
@KarhouTam KarhouTam deleted the improve-softmax-forward8 branch June 24, 2024 09:31