diff --git a/bigram.py b/bigram.py index faa3a14a..6649b8ec 100644 --- a/bigram.py +++ b/bigram.py @@ -8,7 +8,9 @@ max_iters = 3000 eval_interval = 300 learning_rate = 1e-2 -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = ('cuda' if torch.cuda.is_available() + else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + else 'cpu') eval_iters = 200 # ------------ diff --git a/gpt.py b/gpt.py index e4fc68d6..2e92a613 100644 --- a/gpt.py +++ b/gpt.py @@ -8,7 +8,9 @@ max_iters = 5000 eval_interval = 500 learning_rate = 3e-4 -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = ('cuda' if torch.cuda.is_available() + else 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() + else 'cpu') eval_iters = 200 n_embd = 384 n_head = 6