diff --git a/bigram.py b/bigram.py
index faa3a14a..4f6d69bd 100644
--- a/bigram.py
+++ b/bigram.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+import os
+import urllib.request
 
 # hyperparameters
 batch_size = 32 # how many independent sequences will we process in parallel?
@@ -14,7 +16,15 @@
 
 torch.manual_seed(1337)
 
-# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+# fetch the tiny shakespeare dataset on first run (stdlib only, no external wget binary required)
+data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
+if not os.path.exists('input.txt'):
+    # write to a temporary name and rename afterwards, so an interrupted download is never mistaken for the dataset
+    with urllib.request.urlopen(data_url) as response, open('input.txt.part', 'wb') as out:
+        out.write(response.read())
+    os.replace('input.txt.part', 'input.txt')
+    print('downloaded input.txt')
+
 with open('input.txt', 'r', encoding='utf-8') as f:
     text = f.read()
 
diff --git a/gpt.py b/gpt.py
index 81f50c42..bb270448 100644
--- a/gpt.py
+++ b/gpt.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+import os
+import urllib.request
 
 # hyperparameters
 batch_size = 64 # how many independent sequences will we process in parallel?
@@ -18,7 +20,15 @@
 
 torch.manual_seed(1337)
 
-# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+# fetch the tiny shakespeare dataset on first run (stdlib only, no external wget binary required)
+data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
+if not os.path.exists('input.txt'):
+    # write to a temporary name and rename afterwards, so an interrupted download is never mistaken for the dataset
+    with urllib.request.urlopen(data_url) as response, open('input.txt.part', 'wb') as out:
+        out.write(response.read())
+    os.replace('input.txt.part', 'input.txt')
+    print('downloaded input.txt')
+
 with open('input.txt', 'r', encoding='utf-8') as f:
     text = f.read()
 