Commit 5fdfbec
[PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture (#3065)
* [PyTorch] Propagate skip_fp8_weight_update in GroupedLinear during FP8 CUDA graph capture
GroupedLinear.forward hardcoded None for skip_fp8_weight_update, so the
FP8 graph-capture skip tensor was never forwarded during CUDA graph
replay. Mirror Linear.forward: when fp8_graph_capturing() is true, read
quantization_state.skip_fp8_weight_update_tensor, force is_first_microbatch
to False, and thread the tensor into the forward call (the slot
_GroupedLinear.forward already unpacks).
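
The control flow described above can be sketched in plain Python. This is only an illustration, not the Transformer Engine code: `QuantizationState`, `resolve_skip_flag`, and the list used in place of a tensor are hypothetical stand-ins, and the `is not None` guard is an assumption about how the Linear.forward pattern behaves.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class QuantizationState:
    """Hypothetical stand-in for the global FP8 state named in the commit message."""

    graph_capturing: bool = False
    skip_fp8_weight_update_tensor: Optional[Any] = None


def fp8_graph_capturing(state: QuantizationState) -> bool:
    # Stand-in for the fp8_graph_capturing() check mentioned above.
    return state.graph_capturing


def resolve_skip_flag(
    state: QuantizationState, is_first_microbatch: Optional[bool]
) -> Tuple[Optional[Any], Optional[bool]]:
    """Return (skip_fp8_weight_update, is_first_microbatch) to pass to forward."""
    skip_fp8_weight_update = None  # the value that was previously hardcoded
    if fp8_graph_capturing(state):
        # During graph capture, read the skip tensor from the shared state.
        skip_fp8_weight_update = state.skip_fp8_weight_update_tensor
        if skip_fp8_weight_update is not None:  # assumed guard
            # The skip tensor now controls weight updates, so the Python-side
            # flag is forced to False.
            is_first_microbatch = False
    return skip_fp8_weight_update, is_first_microbatch


# Capturing: the tensor is forwarded and is_first_microbatch is overridden.
capturing = QuantizationState(graph_capturing=True, skip_fp8_weight_update_tensor=[0.0])
skip_cap, first_cap = resolve_skip_flag(capturing, True)

# Eager: nothing is forwarded and the caller's flag is left unchanged.
eager = QuantizationState(graph_capturing=False)
skip_eager, first_eager = resolve_skip_flag(eager, True)
```

Outside graph capture the function returns `None` and leaves `is_first_microbatch` as the caller set it, matching the earlier behaviour.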
Fixes #3051
Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
* [PyTorch] Add CUDA graph FP8 weight-caching test for GroupedLinear
Exercises skip_fp8_weight_update propagation in GroupedLinear during FP8
CUDA graph capture. With fp8_weight_caching enabled, graphed and eager
runs only match when is_first_microbatch is threaded into the weight-
update skip tensor for every microbatch, which the prior None hardcode
prevented.
Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
---------
Signed-off-by: LeSingh1 <sshaurya914@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 4bf946d commit 5fdfbec
2 files changed
Lines changed: 97 additions & 0 deletions
| File | Added lines (new-file numbering) | Additions |
|---|---|---|
| First file (name and diff content not preserved) | 12, 220-251, 351-359, 546-591 | 88 |
| Second file (name and diff content not preserved) | 1687-1695 | 9 |