
v0.12.0

Released by @drisspg · 17 Jul 17:56

Highlights

We are excited to announce the 0.12.0 release of torchao! This release adds a QAT + Axolotl integration and prototype MXFP/NVFP support on Blackwell GPUs!

QAT + Axolotl Integration

TorchAO’s QAT support has been integrated into Axolotl’s fine-tuning recipes! Check out the docs here or run it yourself using the following commands:

axolotl train examples/llama-3/3b-qat-fsdp2.yaml
axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
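
Under the hood, the Axolotl recipe drives torchao's QAT flow: fake-quantization ops are inserted before fine-tuning, then swapped for real quantized kernels afterwards. Below is a minimal sketch of that flow using torchao's own API; the int8-activation/int4-weight settings are illustrative, and the exact class names and defaults may differ from what the Axolotl recipe configures:

import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    FromIntXQuantizationAwareTrainingConfig,
)

# 1. Insert fake-quantize ops so fine-tuning "sees" quantization error
#    (roughly what `axolotl train` sets up before training starts)
act_cfg = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_cfg = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, IntXQuantizationAwareTrainingConfig(act_cfg, weight_cfg))

# 2. Fine-tune the model as usual with your trainer of choice.

# 3. Remove the fake-quantize wrappers and apply real post-training quantization
#    (roughly what `axolotl quantize` does after training)
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))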

Initial results for Llama3.2-3B by @SalmanMohammadi (axolotl-ai-cloud/axolotl#2590):

Model/Metric | hellaswag acc | hellaswag acc_norm | wikitext bits_per_byte | wikitext byte_perplexity | wikitext word_perplexity
--- | --- | --- | --- | --- | ---
bfloat16 | 0.5552 | 0.7315 | 0.6410 | 1.5594 | 10.7591
bfloat16 PTQ | 0.5393 | 0.7157 | 0.6613 | 1.5815 | 11.6033
QAT + PTQ | 0.5423 | 0.7180 | 0.6567 | 1.5764 | 11.4043
Recovered (QAT + PTQ) | 18.87% | 14.56% | 22.66% | 23.08% | 23.57%
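
The "Recovered" row is the share of the post-training-quantization (PTQ) regression that QAT wins back relative to the bfloat16 baseline, computed as (qat_ptq − ptq) / (bfloat16 − ptq). A quick check against the hellaswag acc column:

# Fraction of the PTQ regression recovered by QAT (hellaswag acc column)
bf16, ptq, qat_ptq = 0.5552, 0.5393, 0.5423
print(f"{(qat_ptq - ptq) / (bf16 - ptq):.2%}")  # 18.87%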

[Prototype | API not finalized] MXFP and NVFP support on Blackwell GPUs

TorchAO now includes prototype support for NVFP4 (NVIDIA's 4-bit floating-point format) and Microscaling (MX) formats on NVIDIA's latest Blackwell GPU architecture. These formats enable efficient inference, achieving up to a 61% end-to-end performance improvement in vLLM on Qwen3 models and nearly 2x speedups for diffusion workloads.

To use:

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
)

# Quantize the model with MXFP8; quantize_ modifies the model in place
quantize_(model, MXFPInferenceConfig(block_size=32))

# ...or quantize to NVFP4 instead (without double scaling)
quantize_(model, NVFP4InferenceConfig())

Note: This is a prototype feature with APIs subject to change. Requires NVIDIA Blackwell GPUs (B200, 5090) with CUDA 12.8+.
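
As a practical guard when experimenting, you can gate the prototype configs on the device's compute capability (Blackwell is sm_100 for B200 and sm_120 for the 5090). The check below is illustrative and not part of the torchao API:

import torch

# Blackwell parts report compute capability major >= 10 (B200: 10.x, RTX 5090: 12.x)
major, _minor = torch.cuda.get_device_capability()
if major < 10:
    raise RuntimeError("MXFP/NVFP prototype configs require a Blackwell GPU")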

BC Breaking

  • Remove preserve_zero and zero_point_domain from choose_qparams_affine (#2149)
  • Rename qparams for tinygemm (#2344)
  • Convert quant_primitives methods private (#2350)
  • Delete Galore (#2397)
  • Remove more Galore bits (#2417)
  • Remove sparsity/prototype/blocksparse (#2205)

Deprecations

  • Clean up prototype folder (#2232)
  • Make float8 training's force_recompute_fp8_weight_in_bwd flag do nothing (#2356)

New Features

  • Enabling MOE Quantization using linear decomposition (#2043)
  • [PT2E][X86] Migrate fusion passes in Inductor to torchao (#2140)
  • 2:4 activation sparsity packing kernels (#2012)
  • Add subclass based method for inference w/ MXFP8 (#2132)
  • Feat: Implementation of the DeepSeek blockwise quantization for fp8 tensors (#1763)
  • Arm_inductor_quantizer for Pt2e quantization (#2139)
  • Add mx_fp4 path (#2201)
  • Add support for KleidiAI int4 kernels on aarch64 Linux (#2169)
  • Add support for fbgemm int4 mm kernel (#2255)
  • Enable fp16+int4 mixed precision path for int4 xpu path with int zero point (#2240)
  • Enable range learning for QAT (#2033)
  • Patch the _is_conv_node function (#2257)
  • Add support for fbgemm fp8 kernels (#2276)
  • Add Float8ActInt4WeightQATQuantizer (#2289)
  • [float8] add _auto_filter_for_recipe to float8 (#2410)
  • NVfp4 (#2408)
  • [float8] Prevent quantize_affine_float8/dequantize_affine_float8 decomposed on inductor (#2379)
  • [CPU] Enable DA8W4 on CPU (#2128)
  • Add exportable coreml codebook quantization op (#2443)
  • Add support for Int4GroupwisePreshuffleTensor for fbgemm (#2421)

Improvement

  • Add serialization support for AOPerModuleConfig (#2186)
  • Set eps in end-to-end QAT flow (#2180)
  • Enable {conv3d, conv_transpose3d} + bn fusion in pt2e (#2212)
  • Update GemLite to support vLLM V1 (#2199)
  • [sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity (#2242)
  • Patch the _is_conv_node function (#2223)
  • Relax int4wo device mismatch error (#2254)
  • Rename AOPerModuleConfig to ModuleFqnToConfig (#2243)
  • [reland2][ROCm] preshuffled weight mm (#2207)
  • GPTQ updates (#2235)
  • Fix QAT range learning, ensure scales get gradients (#2280)
  • Fix slicing and get_plain() in GemLite (#2288)
  • Add slicing support for fbgemm fp8 and int4 (#2308)
  • Add support for bmm and to for fbgemm Tensor (#2337)
  • Add dynamic quantization support to gemlite layout (#2327)
  • Test PARQ with torchao activation quantization (#2370)
  • Update index.rst (#2395)
  • Add inplace quantizer examples (#2345)
  • Build mxfp4 kernel for sm120a (#2285)
  • Enable to_mxfp8 cast for DTensor (#2420)
  • Enable tensor parallelism for MXLinear (#2434)
  • Graduate debug handle in torchao (#2452)
  • Switch alignment to 8 for cutlass 4 upgrade (#2491)
  • Mxfp8 training: add TP sharding strategy for dim1 kernel (#2436)

Bug Fixes

  • [optim] Fix low-bit optim when used with FSDP2+CPUOffload (#2195)
  • Fix Per Row scaling for inference (#2253)
  • Fix benchmark_low_bit_adam.py reference (#2287)
  • [optim] Fix bug when default dtype is BF16 (#2286)
  • [sparse] marlin fixes (#2305)
  • Fix ROCM test failures (#2362)
  • [float8] Add fnuz fp8 dtypes to Float8Layout (#2351)
  • Fixing ruff format for trunk (#2369)
  • Fixing trunk - autoquant test failure (#2363)
  • Remove torchao dependency from torchao build script (#2383)
  • Fix torchao quantized model in fbcode (#2396)
  • Gemlite generate.py fix (#2372)
  • Fixes issue #156414: Fixes bug in implementation of _combine_histogram (Follow up) (#2418)
  • TorchAO new observers (#2508)
  • Fix tutorials (#2516)

Performance

  • Add a triton kernel for swizzling (#2168)

Documentation

  • Add blockwise fp8 gemm benchmarks to README (#2203)
  • [float] document e2e training -> inference flow (#2190)
  • Update Readme (#1526)
  • Mark QAT range learning as prototype for now (#2272)
  • Update float8 training readme to include time measurement (#2291)
  • [BE/docs] Add float8 training api ref to docsite (#2313)
  • Enable doc build to run on PRs (#2315)
  • [BE] [docs] Add float8 pretraining tutorial to docsite (#2304)
  • [BE/docs] Add fp8 rowwise perf table to float8 training readme (#2312)
  • Update Quantization docs to show newer AOConfigs (#2317)
  • Update QAT docs, highlight axolotl integration (#2266)
  • Add static quant tutorial (#2047)
  • Update README.md to include seamless v2 (#2355)
  • Add Tutorial on E2E integration into VLLM and minimal Subclass (#2346)
  • [docs] Replace deprecated configs with Config objects (#2375)
  • Revamp README (#2374)
  • Add pt2e tutorials to torchao doc page (#2384)
  • Add part 2 of end-to-end tutorial: fine-tuning (#2394)
  • Call out axolotl + QAT integration on README (#2442)
  • Float8 readme: remove duplication (#2447)
  • Float8 readme: add key features section (#2448)
  • Update README.md to include Flux-Fast (#2457)
  • Inference tutorial - Part 3 of e2e series (#2343)
  • Update QAT README and API docstrings (#2465)
  • Fix typo : whic -> which (#2495)
  • Fix links for torchao tutorials (#2503)
  • Fix docstrings for quantization API docs (#2471)
  • Tutorial for benchmarking (#2499)

Full Changelog: v0.11.0...v0.12.0-rc2