
[PyTorch] Python DType enum #3039

Merged
vthumbe1503 merged 44 commits into NVIDIA:main from vthumbe1503:te_dtype on Jun 4, 2026

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented May 22, 2026

Description

Replace the pybind tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. C++ still sees transformer_engine::DType (the C++ definition behind tex.DType), so a pybind type caster is defined to make sure the Python DType maps to the C++ DType correctly.

Motivation

  • CPU overheads: tex.DType is a pybind enum, so every access/compare/convert in Python crosses into C-extension code.
  • torch.compile: tex.DType won't work with torch.compile — TorchDynamo doesn't understand pybind enums, so it graph-breaks (or fails to trace) when one flows through a compiled region.
  • Checkpointing: tex.DType lives in tensor/quantizer state and lands in checkpoints; pickling a pybind enum is fragile and awkward to allow-list vs. a stdlib python enum.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Canonical DType: Add a pure-Python DType(IntEnum) in constants.py as the single source of truth. Its members are defined from the C++ enum values, and an import-time assert verifies that it stays in sync with the pybind enum tex.DType.
  • Migration: Repoint TE_DType maps and move pytorch modules, examples, benchmarks, and tests off raw tex.DType onto constants.DType.
  • Backward compatibility: Add DTypeSupported = Union[DType, tex.DType]; tex.DType is still accepted at Quantizer/QuantizedTensor constructor boundaries and stays allow-listed for loading old checkpoints.
  • Python → C++: Register a type_caster for transformer_engine::DType in the PyTorch extensions so that all pybind-registered functions use it. This lets Python code pass int or Python DType values into tex functions, and pybind converts them to the transformer_engine::DType that the tex C++ API expects.
  • C++ → Python: Add cached MakePythonDType (csrc/common.*) and use it at quantizer/quantizedtensor construction.
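
The canonical enum and the constructor-boundary normalization listed above can be sketched roughly as follows. This is a minimal illustration, not the code in constants.py: FakeTexDType is a hypothetical stand-in for the pybind enum, ToyQuantizer is a made-up class, the integer values are illustrative only, and the body of cast() is an assumption about how such a normalizer might look.

```python
from enum import IntEnum
from typing import Union


class FakeTexDType:
    """Hypothetical stand-in for the pybind enum tex.DType (not the real class)."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __int__(self) -> int:
        return self.value


class DType(IntEnum):
    """Pure-Python enum; the integer values here are illustrative only."""

    kFloat32 = 4
    kBFloat16 = 6
    kFloat8E4M3 = 7

    @classmethod
    def cast(cls, value: "DTypeSupported") -> "DType":
        """Normalize a DType member or the stand-in enum to a DType member."""
        return cls(int(value))


# Union accepted at constructor boundaries, mirroring DTypeSupported in the PR.
DTypeSupported = Union[DType, FakeTexDType]


class ToyQuantizer:
    """Hypothetical quantizer that stores only the canonical Python enum."""

    def __init__(self, dtype: DTypeSupported) -> None:
        self.dtype = DType.cast(dtype)


old_style = ToyQuantizer(FakeTexDType(7))      # legacy enum still accepted
new_style = ToyQuantizer(DType.kFloat8E4M3)    # canonical enum
print(old_style.dtype is new_style.dtype)
```

Normalizing once at the boundary means the rest of the Python code only ever compares and hashes plain IntEnum members, which is the property the motivation section relies on.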

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits May 22, 2026 21:50
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title from "initial prototype" to "TE_DType in python" May 22, 2026
// when this runs, so the GIL is held and Python imports are legal.
static pybind11::object te_dtype_cls =
pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType");
return te_dtype_cls(static_cast<int>(dtype));
Collaborator Author

Find a way to bind the C++ and Python DType through the pybind cast mechanism.

Collaborator Author

This is done for Python -> C++.

For C++ -> Python, this can't be avoided.

Comment thread transformer_engine/pytorch/__init__.py Outdated
# pybind11 enum used as Quantizer.dtype
tex.DType,
# Python IntEnum used as Quantizer.dtype
TE_DType,
Collaborator Author

Save/load backward compatibility should be preserved.

Collaborator Author

Done

vthumbe1503 and others added 14 commits May 31, 2026 19:45
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title from "TE_DType in python" to "[PyTorch] Python DType enum" Jun 1, 2026
vthumbe1503 and others added 2 commits June 1, 2026 08:27
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 1, 2026 08:32
@vthumbe1503
Collaborator Author

/te-ci

@greptile-apps
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR replaces the pybind11 tex.DType enum with a canonical Python DType(IntEnum) throughout transformer_engine.pytorch, addressing CPU overhead from cross-extension calls, torch.compile incompatibility, and fragile pickling. The core change introduces DType in constants.py, a custom pybind type-caster (pybind_dtype_caster.h) for transparent Python→C++ conversion, and MakePythonDType for the C++→Python path.

  • DType introduction: A new IntEnum with custom __eq__/__ne__/__hash__ is added to constants.py, replacing tex.DType across 66 files while keeping the pybind enum accepted at constructor boundaries for backward compatibility.
  • C++ interop layer: pybind_dtype_caster.h enables any C++ function expecting transformer_engine::DType to transparently accept DType, int, or tex.DType; MakePythonDType in common.cpp caches and returns canonical Python DType members when constructing quantized tensors from C++.
  • Copyreg / checkpoint compatibility: Both DType and tex.DType are added to the pickling allowlist in __init__.py, preserving loading of old checkpoints.

Confidence Score: 4/5

The PR is a well-scoped refactoring that can be merged; the two open observations do not block correctness under normal use.

The 66-file migration is comprehensive and each constructor boundary normalizes values through DType.cast() or MakePythonDType. The one area deserving a second look is the tex.DType.__hash__ behavior after the runtime __eq__ override: cross-type dict/set lookups in TE_DType_To_Torch and similar maps rely on hash(tex.DType.kX) == hash(DType.kX), which holds in current pybind11 implementations but is not explicitly guaranteed or enforced by the new code.

transformer_engine/pytorch/csrc/extensions/pybind.cpp (tex.DType.__hash__ not explicitly set after the __eq__ override) and transformer_engine/pytorch/csrc/common.cpp (GIL assumption in MakePythonDType magic-static initializer).
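
The hash concern raised above can be illustrated with a small self-contained sketch. It does not use pybind11 or Transformer Engine: ValueHashed and MismatchedHash are hypothetical stand-ins for an extension enum, and the integer values are illustrative. The point is only that cross-type equality is not enough for dict or set lookups; the hashes must agree as well.

```python
from enum import IntEnum


class DType(IntEnum):
    """Illustrative Python enum; values are placeholders."""

    kFloat32 = 4
    kFloat8E4M3 = 7


class ValueHashed:
    """Hypothetical foreign enum-like object that compares and hashes by value."""

    def __init__(self, value: int) -> None:
        self.value = value

    def __int__(self) -> int:
        return self.value

    def __eq__(self, other):
        if isinstance(other, (int, ValueHashed)):
            return int(other) == self.value
        return NotImplemented

    def __hash__(self) -> int:
        # Same hash as the equal integer, so lookups keyed by DType succeed.
        return hash(self.value)


class MismatchedHash(ValueHashed):
    """Same equality as ValueHashed, but a hash that disagrees with the int hash."""

    def __hash__(self) -> int:
        return self.value + 1000


table = {DType.kFloat8E4M3: "fp8"}

# Both stand-ins compare equal to the Python enum member...
print(ValueHashed(7) == DType.kFloat8E4M3, MismatchedHash(7) == DType.kFloat8E4M3)
# ...but only the one whose hash matches can be found in the dict.
print(table.get(ValueHashed(7)), table.get(MismatchedHash(7)))
```

An IntEnum member hashes like its integer value, so a foreign type has to hash the same way for maps such as TE_DType_To_Torch to accept either key type.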

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/constants.py | Introduces DType(IntEnum) with custom __eq__/__ne__/__hash__ and a cast() classmethod; updates TE_DType map to use new enum; derives TE_DType_To_Torch by dict-reversal. The import-time assert sync-check can be bypassed in optimized builds. |
| transformer_engine/pytorch/csrc/extensions/pybind_dtype_caster.h | New custom type-caster lets C++ functions accept DType IntEnum, plain int, or tex.DType pybind enum transparently; cast() still returns tex.DType (intentional for backward compat); header guard uses plural CASTERS_H while filename is singular. |
| transformer_engine/pytorch/csrc/common.cpp | Adds MakePythonDType with a per-value cache and explicit null-safety check via NVTE_CHECK; cache is initialized once using a C++11 magic static that imports transformer_engine.pytorch.DType lazily. |
| transformer_engine/pytorch/csrc/common.h | Adds #include extensions/pybind_dtype_caster.h (singular, matching the actual filename) and declares MakePythonDType; include path is correct. |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Registers __eq__/__ne__ on tex.DType to compare by integer value; __hash__ is not explicitly registered to match, leaving an implicit reliance on pybind11's default enum hash behavior for cross-type dict/set correctness. |
| transformer_engine/pytorch/csrc/quantizer.cpp | All py::cast(this->dtype) / py::cast(dtype) calls replaced with MakePythonDType(...) for FP8 / MXFP8 / NVFP4 quantizers; comprehensive coverage across all create/convert tensor code paths. |
| transformer_engine/pytorch/__init__.py | Imports DType early (after load_framework_extension) and adds it to the copyreg allowlist alongside tex.DType; ordering is correct and backward-compat is preserved. |
| transformer_engine/pytorch/tensor/float8_tensor.py | Migrates Float8Quantizer and Float8CurrentScalingQuantizer dtype fields to DType; constructors accept Union[DType, tex.DType] and normalize with DType.cast(). |
| transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py | Storage layer now stores _fp8_dtype as DType; constructor normalizes with DType.cast(); TE_DType_To_Torch import updated accordingly. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | All bulk-allocate Python tensor constructors updated to pass MakePythonDType(fp8_dtype) / MakePythonDType(fp4_dtype) instead of the raw C++ enum. |

Sequence Diagram

sequenceDiagram
    participant PY as Python user code
    participant DType as constants.DType IntEnum
    participant Caster as pybind_dtype_caster.h
    participant CPP as C++ quantizer/cast
    participant MakePY as MakePythonDType

    Note over PY,MakePY: Python to C++ direction
    PY->>Caster: pass DType.kFloat8E4M3 to C++ fn
    Caster->>Caster: PyLong_Check true, IntEnum is int subclass
    Caster-->>CPP: static_cast to DType enum

    Note over PY,MakePY: Backward-compat path
    PY->>Caster: pass tex.DType.kFloat8E4M3 deprecated
    Caster->>Caster: type_caster_base fallback
    Caster-->>CPP: reads underlying C++ enum value

    Note over PY,MakePY: C++ to Python direction
    CPP->>MakePY: MakePythonDType kFloat8E4M3
    MakePY->>MakePY: check magic-static cache
    alt cache miss on first call
        MakePY->>PY: import transformer_engine.pytorch
        PY-->>MakePY: DType class
        MakePY->>MakePY: populate cache per value
    end
    MakePY-->>PY: DType.kFloat8E4M3 canonical Python enum

    Note over PY,MakePY: Cross-type equality check
    PY->>DType: "DType.kFloat8E4M3 == tex.DType.kFloat8E4M3"
    DType->>DType: isinstance check then int comparison
    DType-->>PY: True

Reviews (17): Last reviewed commit: "Merge branch 'main' into te_dtype"

Comment on lines +44 to +49
# Fail fast at import time if a new enumerator is added
# on the C++ side without being mirrored above.
assert {m.name for m in DType} == set(tex.DType.__members__), (
"DType is out of sync with transformer_engine_torch.DType; "
"add the new pybind enumerator to DType in constants.py."
)
Contributor

P2 Import-time sync check can be silently skipped

Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.
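
A minimal sketch of the suggested replacement, assuming a helper that receives the extension enum's member names as plain strings. The function name check_enum_in_sync, the error wording, and the enum values are hypothetical; only the use of an explicit if ... raise instead of assert comes from the comment above.

```python
from enum import IntEnum
from typing import Iterable


class DType(IntEnum):
    """Illustrative Python enum; values are placeholders."""

    kFloat32 = 4
    kFloat8E4M3 = 7


def check_enum_in_sync(py_enum, ext_member_names: Iterable[str]) -> None:
    """Raise if the Python enum and the extension enum list different names.

    Unlike `assert`, this check still runs under `python -O`.
    """
    py_names = {member.name for member in py_enum}
    ext_names = set(ext_member_names)
    if py_names != ext_names:
        raise RuntimeError(
            "DType is out of sync with the extension enum: "
            f"missing={sorted(ext_names - py_names)}, "
            f"extra={sorted(py_names - ext_names)}"
        )


# In sync: returns silently.
check_enum_in_sync(DType, ["kFloat32", "kFloat8E4M3"])
```

Reporting the missing and extra names separately makes it clear which side needs to be updated.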

Comment thread transformer_engine/common/util/dtype_pybind_conversion.h Outdated
Comment thread transformer_engine/pytorch/constants.py
@vthumbe1503 vthumbe1503 requested a review from ptrendx June 1, 2026 18:13
Comment thread transformer_engine/pytorch/csrc/common.cpp Outdated
vthumbe1503 and others added 5 commits June 2, 2026 19:21
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch amd64 arm64

@vthumbe1503
Collaborator Author

/te-ci arm64 amd64

Comment thread transformer_engine/pytorch/csrc/extensions/pybind_dtype_caster.h
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp
vthumbe1503 and others added 2 commits June 4, 2026 05:28
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

timmoon10 previously approved these changes Jun 4, 2026
Member

@timmoon10 timmoon10 left a comment

LGTM. Aside from the functional considerations with checkpointing/pickling/torch.compile, this is quite nice from a design perspective. It was uncomfortable that downstream dependencies needed to import tex to do advanced quantizer configuration.

Member

Nit: This is kind of a strange place for this header since it's not exposed externally in tex. That said, the csrc dir is generally kind of disorganized right now. Maybe reorganizing would be good for a future PR, something like:

  • common.h: Header for internal utilities
  • api.h: Header for functions exposed in tex
  • pybind: Dir with Pybind11 logic
  • api: Dir with functions exposed in tex
  • quantizers: Dir with quantizer impls
  • utils: Dir with misc utilities

This is far beyond the scope of this PR though.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/csrc/common.h Outdated
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 merged commit f458abe into NVIDIA:main Jun 4, 2026
12 of 15 checks passed
4 participants