|  | 
|  | 1 | +# | 
|  | 2 | +# Licensed to the Apache Software Foundation (ASF) under one | 
|  | 3 | +# or more contributor license agreements.  See the NOTICE file | 
|  | 4 | +# distributed with this work for additional information | 
|  | 5 | +# regarding copyright ownership.  The ASF licenses this file | 
|  | 6 | +# to you under the Apache License, Version 2.0 (the | 
|  | 7 | +# "License"); you may not use this file except in compliance | 
|  | 8 | +# with the License.  You may obtain a copy of the License at | 
|  | 9 | +# | 
|  | 10 | +#   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 11 | +# | 
|  | 12 | +# Unless required by applicable law or agreed to in writing, | 
|  | 13 | +# software distributed under the License is distributed on an | 
|  | 14 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | 15 | +# KIND, either express or implied.  See the License for the | 
|  | 16 | +# specific language governing permissions and limitations | 
|  | 17 | +# under the License. | 
|  | 18 | +# | 
|  | 19 | + | 
|  | 20 | +try: | 
|  | 21 | +    import pickle | 
|  | 22 | +except ImportError: | 
|  | 23 | +    import cPickle as pickle | 
|  | 24 | + | 
|  | 25 | +import numpy as np | 
|  | 26 | +import os | 
|  | 27 | +import sys | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +def load_dataset(filepath): | 
|  | 31 | +    with open(filepath, 'rb') as fd: | 
|  | 32 | +        try: | 
|  | 33 | +            cifar10 = pickle.load(fd, encoding='latin1') | 
|  | 34 | +        except TypeError: | 
|  | 35 | +            cifar10 = pickle.load(fd) | 
|  | 36 | +    image = cifar10['data'].astype(dtype=np.uint8) | 
|  | 37 | +    image = image.reshape((-1, 3, 32, 32)) | 
|  | 38 | +    label = np.asarray(cifar10['labels'], dtype=np.uint8) | 
|  | 39 | +    label = label.reshape(label.size, 1) | 
|  | 40 | +    return image, label | 
|  | 41 | + | 
|  | 42 | + | 
|  | 43 | +#def load_train_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py', num_batches=5): | 
|  | 44 | +def load_train_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py', num_batches=5): | 
|  | 45 | +    labels = [] | 
|  | 46 | +    batchsize = 10000 | 
|  | 47 | +    images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8) | 
|  | 48 | +    for did in range(1, num_batches + 1): | 
|  | 49 | +        fname_train_data = dir_path + "/data_batch_{}".format(did) | 
|  | 50 | +        image, label = load_dataset(check_dataset_exist(fname_train_data)) | 
|  | 51 | +        images[(did - 1) * batchsize:did * batchsize] = image | 
|  | 52 | +        labels.extend(label) | 
|  | 53 | +    images = np.array(images, dtype=np.float32) | 
|  | 54 | +    labels = np.array(labels, dtype=np.int32) | 
|  | 55 | +    return images, labels | 
|  | 56 | + | 
|  | 57 | + | 
|  | 58 | +#def load_test_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py'): | 
|  | 59 | +def load_test_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py'): | 
|  | 60 | +    images, labels = load_dataset(check_dataset_exist(dir_path + "/test_batch")) | 
|  | 61 | +    return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32) | 
|  | 62 | + | 
|  | 63 | + | 
|  | 64 | +def check_dataset_exist(dirpath): | 
|  | 65 | +    if not os.path.exists(dirpath): | 
|  | 66 | +        print( | 
|  | 67 | +            'Please download the cifar10 dataset using python data/download_cifar10.py' | 
|  | 68 | +        ) | 
|  | 69 | +        sys.exit(0) | 
|  | 70 | +    return dirpath | 
|  | 71 | + | 
|  | 72 | + | 
|  | 73 | +def normalize(train_x, val_x): | 
|  | 74 | +    mean = [0.4914, 0.4822, 0.4465] | 
|  | 75 | +    std = [0.2023, 0.1994, 0.2010] | 
|  | 76 | +    train_x /= 255 | 
|  | 77 | +    val_x /= 255 | 
|  | 78 | +    for ch in range(0, 2): | 
|  | 79 | +        train_x[:, ch, :, :] -= mean[ch] | 
|  | 80 | +        train_x[:, ch, :, :] /= std[ch] | 
|  | 81 | +        val_x[:, ch, :, :] -= mean[ch] | 
|  | 82 | +        val_x[:, ch, :, :] /= std[ch] | 
|  | 83 | +    return train_x, val_x | 
|  | 84 | + | 
|  | 85 | +def load():  # Need to pass in the path for loading training data | 
|  | 86 | +    train_x, train_y = load_train_data() | 
|  | 87 | +    val_x, val_y = load_test_data() | 
|  | 88 | +    train_x, val_x = normalize(train_x, val_x) | 
|  | 89 | +    train_y = train_y.flatten() | 
|  | 90 | +    val_y = val_y.flatten() | 
|  | 91 | +    return train_x, train_y, val_x, val_y | 
0 commit comments