Skip to content

Commit c9899cf

Browse files
committed
Use tf32 in pytorch
1 parent 6806019 commit c9899cf

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

algoperf/pytorch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
23+
torch.set_float32_matmul_precision('high')
2324
use_pytorch_ddp = 'LOCAL_RANK' in os.environ
2425
rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
2526
device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

0 commit comments

Comments
 (0)