Skip to content

Commit e2af5f6

Browse files
committed
Update PyTorch setup for tf32 and change default WMT workload compute dtype to FLOAT32
1 parent a4f9948 commit e2af5f6

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

algoperf/pytorch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
23+
torch.set_float32_matmul_precision('high')
2324
use_pytorch_ddp = 'LOCAL_RANK' in os.environ
2425
rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
2526
device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

algoperf/workloads/wmt/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class BaseWmtWorkload(spec.Workload):
2222
"""A WMT workload."""
2323

2424
_vocab_size: int = 32000
25-
_compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16
25+
_compute_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
2626
_param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
2727

2828
def __init__(self) -> None:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ jax_core_deps = [
9696
"chex==0.1.86",
9797
"ml_dtypes==0.5.1",
9898
"protobuf==4.25.5",
99+
"jmp",
99100
]
100101
jax_cpu = [
101102
"jax==0.7.0",

Comments (0)