Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion transformer_engine/plugin/core/backends/vendor/cuda/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def try_load_lib(name, search_patterns):
try_load_lib("nvrtc", [f"libnvrtc{ext}*"])
try_load_lib("curand", [f"libcurand{ext}*"])

te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
te_path_override = os.environ.get("TE_LIB_PATH")
if te_path_override:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, only the te_path in vendor/cuda was updated.
Don't the te_path entries under vendor/ for other hardware vendors (such as hygon) also need to be updated?

te_path = Path(te_path_override)
else:
te_path = Path(importlib.util.find_spec("transformer_engine").origin).parent.parent
for search_dir in [te_path, te_path / "transformer_engine"]:
if search_dir.exists():
matches = list(search_dir.glob(f"libtransformer_engine{ext}*"))
Expand Down
Loading