[PyTorch] Python DType enum #3039
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
// when this runs, so the GIL is held and Python imports are legal.
static pybind11::object te_dtype_cls =
    pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType");
return te_dtype_cls(static_cast<int>(dtype));
Find a way to bind the C++ and Python DType through the pybind11 cast mechanism.
This is done for the Python -> C++ direction.
For the C++ -> Python direction, this can't be avoided.
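The two directions discussed in this thread can be sketched in plain Python. This is only an illustration, not the PR's caster: `DType` here is a stand-in class, the numeric values are assumptions, and `to_cpp` / `to_python` are hypothetical helpers that mimic what the caster and the cached class call do.

```python
from enum import IntEnum


class DType(IntEnum):
    """Illustrative stand-in; the numeric values are assumptions."""

    kFloat32 = 4  # hypothetical enumerator, not taken from this PR
    kFloat8E4M3 = 7


def to_cpp(value: int) -> int:
    """Mimic the Python -> C++ direction: any int subclass is accepted."""
    if not isinstance(value, int):  # analogous to PyLong_Check failing
        raise TypeError(f"expected an int-like DType, got {type(value).__name__}")
    return int(value)  # plays the role of the static_cast to the C++ enum


def to_python(raw: int) -> DType:
    """Mimic the C++ -> Python direction: call the enum class with the raw int."""
    return DType(raw)


raw = to_cpp(DType.kFloat8E4M3)
round_tripped = to_python(raw)
```

The Python -> C++ side needs no lookup because an `IntEnum` member already is an `int`; the reverse side has to go through the Python class to obtain the canonical member, which is why the import in the snippet above remains.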
# pybind11 enum used as Quantizer.dtype
tex.DType,
# Python IntEnum used as Quantizer.dtype
TE_DType,
Save/load backward compatibility should be preserved.
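One way to think about this requirement is a normalization step applied when state is loaded. The sketch below is an assumption about how that could look, not the PR's implementation: `LegacyDType` is a hypothetical stand-in for a value of the old pybind11 enum, `normalize_dtype` is a hypothetical helper, and the numeric value is assumed.

```python
from enum import IntEnum


class DType(IntEnum):
    """Illustrative stand-in; the numeric value is an assumption."""

    kFloat8E4M3 = 7


class LegacyDType:
    """Hypothetical stand-in for an old pybind11 enum value in saved state."""

    def __init__(self, value: int):
        self._value = value

    def __int__(self) -> int:
        return self._value


def normalize_dtype(value) -> DType:
    """Map either enum flavor (or a raw int) to the canonical Python member."""
    if isinstance(value, DType):
        return value
    return DType(int(value))


old_state = {"dtype": LegacyDType(7)}       # saved before the migration
new_state = {"dtype": DType.kFloat8E4M3}    # saved after the migration
loaded = [normalize_dtype(state["dtype"]) for state in (old_state, new_state)]
```

Both entries end up as the same canonical member, so code downstream of the load path only ever sees one type.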
/te-ci
Greptile Summary
This PR replaces the pybind11
Confidence Score: 4/5
The PR is a well-scoped refactoring that can be merged; the two open observations do not block correctness under normal use. The 66-file migration is comprehensive and each constructor boundary normalizes values through transformer_engine/pytorch/csrc/extensions/pybind.cpp (tex.DType.__hash__ not explicitly set after __eq__ override) and transformer_engine/pytorch/csrc/common.cpp (GIL assumption in MakePythonDType magic-static initializer).
Important Files Changed
Sequence Diagram
sequenceDiagram
participant PY as Python user code
participant DType as constants.DType IntEnum
participant Caster as pybind_dtype_caster.h
participant CPP as C++ quantizer/cast
participant MakePY as MakePythonDType
Note over PY,MakePY: Python to C++ direction
PY->>Caster: pass DType.kFloat8E4M3 to C++ fn
Caster->>Caster: PyLong_Check true, IntEnum is int subclass
Caster-->>CPP: static_cast to DType enum
Note over PY,MakePY: Backward-compat path
PY->>Caster: pass tex.DType.kFloat8E4M3 deprecated
Caster->>Caster: type_caster_base fallback
Caster-->>CPP: reads underlying C++ enum value
Note over PY,MakePY: C++ to Python direction
CPP->>MakePY: MakePythonDType kFloat8E4M3
MakePY->>MakePY: check magic-static cache
alt cache miss on first call
MakePY->>PY: import transformer_engine.pytorch
PY-->>MakePY: DType class
MakePY->>MakePY: populate cache per value
end
MakePY-->>PY: DType.kFloat8E4M3 canonical Python enum
Note over PY,MakePY: Cross-type equality check
PY->>DType: "DType.kFloat8E4M3 == tex.DType.kFloat8E4M3"
DType->>DType: isinstance check then int comparison
DType-->>PY: True
Reviews (17): Last reviewed commit: "Merge branch 'main' into te_dtype"
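The cross-type equality step in the diagram, together with the `__hash__` observation in the summary, can be illustrated with a small Python sketch. It is not the PR's code: `LegacyDType` is a hypothetical stand-in for the deprecated pybind11 enum, the numeric value is assumed, and the override is placed on the Python enum purely for illustration. The general Python rule it shows is that a class defining `__eq__` without `__hash__` becomes unhashable.

```python
from enum import IntEnum


class LegacyDType:
    """Hypothetical stand-in for a deprecated pybind11 enum value."""

    def __init__(self, value: int):
        self._value = value

    def __int__(self) -> int:
        return self._value


class DType(IntEnum):
    kFloat8E4M3 = 7  # numeric value is an assumption

    def __eq__(self, other):
        # isinstance check first, then fall back to integer comparison.
        if isinstance(other, LegacyDType):
            return int(self) == int(other)
        return int.__eq__(self, other)

    def __ne__(self, other):
        # int defines its own __ne__, so keep it consistent with __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    # Defining __eq__ in a class body resets __hash__ to None unless it is
    # defined explicitly, so restore integer hashing here.
    def __hash__(self):
        return hash(int(self))


lookup = {DType.kFloat8E4M3: "fp8"}
```

Without the explicit `__hash__`, members could no longer be used as dictionary keys or set elements, which is the kind of breakage the summary flags for the side that overrides `__eq__`.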
# Fail fast at import time if a new enumerator is added
# on the C++ side without being mirrored above.
assert {m.name for m in DType} == set(tex.DType.__members__), (
    "DType is out of sync with transformer_engine_torch.DType; "
    "add the new pybind enumerator to DType in constants.py."
)
Import-time sync check can be silently skipped
Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.
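A minimal sketch of the suggested `if ... raise` form follows. It assumes a stand-in `DType` with made-up members and a hypothetical helper `check_in_sync` that receives the C++ enumerator names as a plain list, so it runs without the extension module; the real check would compare against `tex.DType.__members__`.

```python
from enum import IntEnum


class DType(IntEnum):
    """Illustrative stand-in; member names and values are assumptions."""

    kByte = 0
    kFloat8E4M3 = 7


def check_in_sync(py_enum, cpp_member_names) -> None:
    """Raise unconditionally (even under -O) if the two name sets differ."""
    py_names = {member.name for member in py_enum}
    cpp_names = set(cpp_member_names)
    if py_names != cpp_names:
        missing = sorted(cpp_names - py_names)
        extra = sorted(py_names - cpp_names)
        raise RuntimeError(
            "DType is out of sync with the C++ enum; "
            f"missing in Python: {missing}, extra in Python: {extra}"
        )


# Passes silently when both sides list the same enumerators.
check_in_sync(DType, ["kByte", "kFloat8E4M3"])
```

Because the check is an ordinary `if` statement rather than an `assert`, it is not stripped when Python runs with `-O`.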
/te-ci pytorch amd64 arm64
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci arm64 amd64
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci pytorch
timmoon10 left a comment
LGTM. Aside from the functional considerations with checkpointing/pickling/torch.compile, this is quite nice from a design perspective. It was uncomfortable that downstream dependencies needed to import tex to do advanced quantizer configuration.
Nit: This is kind of a strange place for this header since it's not exposed externally in tex. That said, the csrc dir is generally kind of disorganized right now. Maybe reorganizing would be good for a future PR, something like:
- common.h: Header for internal utilities
- api.h: Header for functions exposed in tex
- pybind: Dir with Pybind11 logic
- api: Dir with functions exposed in tex
- quantizers: Dir with quantizer impls
- utils: Dir with misc utilities
This is far beyond the scope of this PR though.
/te-ci pytorch
Description
Replace the pybind11 tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. The C++ side still uses transformer_engine::DType (the C++ definition behind tex.DType), so a pybind11 type caster is defined to ensure the Python DType is mapped to the C++ DType correctly.