Solve pytorch-triton and triton package contention #2540
| num_ctas, # arg2: num_ctas (int) | ||
| compiled.metadata.shared, # arg3: shared_mem_bytes (int) | ||
| compiled.asm["ptx"], # arg4: ptx (str) | ||
| "", # arg5: ttir (str) - empty |
This will soon be the same as main, since the change is being made in #1921 (to be merged). It is only in this PR so that I can test Triton calls locally with the nightly JAX container without running into errors caused by JAX 0.8.2+.
Greptile Summary
This PR resolves the package contention between `pytorch-triton` and `triton`.
Key changes:
Impact: Users running mixed JAX+PyTorch environments can now properly configure triton package selection. The detection logic warns users if they accidentally install the broken placeholder package.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
    participant User
    participant SetupPy as setup.py
    participant BuildTools as build_tools/*.py
    participant TritonUtils as triton_extensions/utils.py
    participant Triton as triton package
    User->>SetupPy: pip install (PyTorch)
    SetupPy->>BuildTools: install_requirements()
    BuildTools-->>SetupPy: ["pytorch-triton", ...]
    Note over SetupPy: Requires PyTorch index
    User->>SetupPy: pip install (JAX)
    SetupPy->>BuildTools: test_requirements()
    alt NVTE_USE_PYTORCH_TRITON=1
        BuildTools-->>SetupPy: ["pytorch-triton"]
    else Default
        BuildTools-->>SetupPy: ["triton"]
    end
    User->>TritonUtils: import triton_extensions
    TritonUtils->>Triton: import triton
    TritonUtils->>TritonUtils: _detect_triton_package()
    alt Placeholder package (0.0.1)
        TritonUtils-->>User: ImportError with fix instructions
    else pytorch-triton detected
        alt NVTE_USE_PYTORCH_TRITON=1
            TritonUtils-->>User: Silent (acknowledged)
        else Not acknowledged
            TritonUtils-->>User: UserWarning
        end
    else Standard triton
        TritonUtils-->>User: Normal operation
    end
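The import-time branch of the diagram could be sketched roughly as follows. This is not the code from this PR: `check_triton_install`, its parameters, its return strings, and the message texts are hypothetical names chosen for illustration, and the real `_detect_triton_package()` may obtain the package name and version differently.

```python
import os
import warnings
from typing import Mapping

# Version of the placeholder package, as shown in the diagram.
PLACEHOLDER_VERSION = "0.0.1"


def env_acknowledged(environ: Mapping[str, str] = os.environ) -> bool:
    """True when NVTE_USE_PYTORCH_TRITON is set to a truthy value."""
    value = environ.get("NVTE_USE_PYTORCH_TRITON", "0")
    return value.lower() in ("1", "true", "yes")


def check_triton_install(
    dist_name: str, version: str, environ: Mapping[str, str] = os.environ
) -> str:
    """Hypothetical stand-in for the detection step; returns the branch taken."""
    if dist_name == "pytorch-triton" and version == PLACEHOLDER_VERSION:
        # Placeholder package: fail early with instructions (wording is illustrative).
        raise ImportError(
            "The installed pytorch-triton is a placeholder package. "
            "Reinstall it from the PyTorch package index, or install triton."
        )
    if dist_name == "pytorch-triton":
        if env_acknowledged(environ):
            return "silent"  # user opted in via the environment variable
        warnings.warn(
            "pytorch-triton detected; set NVTE_USE_PYTORCH_TRITON=1 to acknowledge.",
            UserWarning,
        )
        return "warned"
    return "normal"  # standard triton
```

Passing the environment as a mapping keeps the sketch easy to exercise without touching the real process environment.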
Additional Comments (1)
- transformer_engine/jax/triton_extensions/utils.py, line 322
  syntax: Typo: `compile.name` should be `compiled.name`. The variable `compile` is not defined in this scope (only `compiled` exists, from line 300), so the name resolves to the Python builtin `compile`, which has no `name` attribute. This will raise an `AttributeError` at runtime for JAX versions < 0.8.2.
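A minimal sketch of this failure mode, assuming nothing else named `compile` is in scope: in standard Python the bare name resolves to the builtin `compile` function, so the mistyped attribute access fails with `AttributeError` rather than `NameError`. The helper name `mistyped_lookup_error` is made up for this illustration.

```python
def mistyped_lookup_error() -> str:
    """Return the name of the exception raised by the mistyped reference."""
    try:
        compile.name  # typo for compiled.name; `compile` is the Python builtin
    except (NameError, AttributeError) as exc:
        return type(exc).__name__
    return "no error"


print(mistyped_lookup_error())
```

Either way the typo only surfaces on the code path that reaches this line, which is why it can go unnoticed on newer JAX versions.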
4 files reviewed, 1 comment
Additional Comments (1)
- build_tools/pytorch.py, line 21
  style: Placeholder text `<version??>` should be replaced with an actual version (e.g., `cu121` or `cu124`) or made generic.
4 files reviewed, 1 comment
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | ||
| "1", | ||
| "true", | ||
| "yes", | ||
| ) |
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | |
| "1", | |
| "true", | |
| "yes", | |
| ) | |
| use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) |
It's better to be specific with env vars and follow the convention from the rest of the codebase.
| @@ -22,7 +32,7 @@ def install_requirements() -> List[str]: | |||
| "packaging", | |||
| "pydantic", | |||
| "nvdlfw-inspect", | |||
| "triton", | |||
| "pytorch-triton", | |||
If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency
pytorch-triton should be the default for transformerengine-pytorch, not just a placeholder. It should be used whenever the PyTorch framework is used.
triton is the default for JAX, except in two scenarios:
- Both JAX and PyTorch are installed, and TE PyTorch is being used to call the Triton kernels.
- The user specifies NVTE_USE_PYTORCH_TRITON=1 while using TE JAX, to make sure there is no performance difference from using different versions of Triton between PyTorch and JAX.
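The selection policy described above can be sketched as a small function. This is only an illustration under stated assumptions: `triton_requirement` and its `framework` argument are hypothetical and not part of the build tools in this PR, and the first scenario (both frameworks installed, kernels called through TE PyTorch) is covered here simply by the PyTorch branch.

```python
import os
from typing import List, Mapping


def triton_requirement(
    framework: str, environ: Mapping[str, str] = os.environ
) -> List[str]:
    """Pick the Triton distribution to require for a given framework (sketch)."""
    if framework == "pytorch":
        # PyTorch builds always use pytorch-triton.
        return ["pytorch-triton"]
    if framework == "jax":
        # JAX defaults to triton unless the user opts in explicitly.
        use_pytorch_triton = bool(int(environ.get("NVTE_USE_PYTORCH_TRITON", "0")))
        return ["pytorch-triton" if use_pytorch_triton else "triton"]
    raise ValueError(f"unknown framework: {framework!r}")
```

The env var is parsed with `bool(int(...))` here, matching the form suggested earlier in the review.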
after the break I now understand what you mean here
…for jax Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (1)
- build_tools/jax.py, line 37
  logic: Inconsistent env var parsing: `utils.py` accepts `"1"`, `"true"`, `"yes"` (via `.lower()` check), but this line only accepts `"0"` or `"1"` (via `int()` cast). If the user sets `NVTE_USE_PYTORCH_TRITON=true`, this will raise `ValueError: invalid literal for int() with base 10: 'true'`.
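The mismatch between the two parsing styles can be reproduced in isolation. The function names `parse_lenient` and `parse_strict` are made up for this sketch; they mirror the two expressions quoted in the review rather than any helper in the codebase.

```python
def parse_lenient(value: str) -> bool:
    """Style described for utils.py: accepts "1", "true", "yes" (case-insensitive)."""
    return value.lower() in ("1", "true", "yes")


def parse_strict(value: str) -> bool:
    """Style described for build_tools/jax.py: integer strings only."""
    return bool(int(value))


for raw in ("1", "0", "true"):
    try:
        print(raw, parse_lenient(raw), parse_strict(raw))
    except ValueError as exc:
        # "true" is accepted by the lenient parser but rejected by int().
        print(raw, parse_lenient(raw), f"ValueError: {exc}")
```

Using a single parsing style in both places would make the build-time and import-time behavior agree for any given value.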
4 files reviewed, 1 comment
(force-pushed from e2773e1 to 469ed68)
Additional Comments (1)
- build_tools/pytorch.py, line 35
  style: Listing `pytorch-triton` here won't work with a standard `pip install` from PyPI, since the real package is only available from PyTorch's package index. Users running `pip install .` without specifying `--index-url` will get the placeholder package that raises "Should never be installed".
  Consider adding an installation note to the README or setup.py, or documenting that this package requires:
  pip install --index-url https://download.pytorch.org/whl/cu121 pytorch-triton
4 files reviewed, 1 comment
zhongbozhu left a comment
Can we merge this fix soon?
cc @nWEIdia
Description
The `pytorch-triton` and `triton` packages install to the same location, site-packages/triton, and `triton` does not work with PyTorch's torch.compile() call because PyTorch has added a few things to its own version of Triton (creating `pytorch-triton` to make it work, and validating it with the release of torch). However, `pytorch-triton` should in theory (and in experiments) still be compatible with how JAX uses it*.
Fixes # (issue)
Type of change
Changes
Checklist: