Skip to content

Commit 3263218

Browse files
authored
Merge pull request #1344 from lzjpaul/25-10-19-dev
Create the dataset for the peft example
2 parents fa805d1 + 1001e01 commit 3263218

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
try:
21+
import pickle
22+
except ImportError:
23+
import cPickle as pickle
24+
25+
import numpy as np
26+
import os
27+
import sys
28+
29+
30+
def load_dataset(filepath):
31+
with open(filepath, 'rb') as fd:
32+
try:
33+
cifar10 = pickle.load(fd, encoding='latin1')
34+
except TypeError:
35+
cifar10 = pickle.load(fd)
36+
image = cifar10['data'].astype(dtype=np.uint8)
37+
image = image.reshape((-1, 3, 32, 32))
38+
label = np.asarray(cifar10['labels'], dtype=np.uint8)
39+
label = label.reshape(label.size, 1)
40+
return image, label
41+
42+
43+
#def load_train_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py', num_batches=5):
44+
def load_train_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py', num_batches=5):
45+
labels = []
46+
batchsize = 10000
47+
images = np.empty((num_batches * batchsize, 3, 32, 32), dtype=np.uint8)
48+
for did in range(1, num_batches + 1):
49+
fname_train_data = dir_path + "/data_batch_{}".format(did)
50+
image, label = load_dataset(check_dataset_exist(fname_train_data))
51+
images[(did - 1) * batchsize:did * batchsize] = image
52+
labels.extend(label)
53+
images = np.array(images, dtype=np.float32)
54+
labels = np.array(labels, dtype=np.int32)
55+
return images, labels
56+
57+
58+
#def load_test_data(dir_path='/scratch1/07801/nusbin20/gordon-bell/cifar-10-batches-py'):
59+
def load_test_data(dir_path='/scratch/snx3000/lyongbin/singa_my/cifar10_log/cifar-10-batches-py'):
60+
images, labels = load_dataset(check_dataset_exist(dir_path + "/test_batch"))
61+
return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)
62+
63+
64+
def check_dataset_exist(dirpath):
65+
if not os.path.exists(dirpath):
66+
print(
67+
'Please download the cifar10 dataset using python data/download_cifar10.py'
68+
)
69+
sys.exit(0)
70+
return dirpath
71+
72+
73+
def normalize(train_x, val_x):
74+
mean = [0.4914, 0.4822, 0.4465]
75+
std = [0.2023, 0.1994, 0.2010]
76+
train_x /= 255
77+
val_x /= 255
78+
for ch in range(0, 2):
79+
train_x[:, ch, :, :] -= mean[ch]
80+
train_x[:, ch, :, :] /= std[ch]
81+
val_x[:, ch, :, :] -= mean[ch]
82+
val_x[:, ch, :, :] /= std[ch]
83+
return train_x, val_x
84+
85+
def load(): # Need to pass in the path for loading training data
86+
train_x, train_y = load_train_data()
87+
val_x, val_y = load_test_data()
88+
train_x, val_x = normalize(train_x, val_x)
89+
train_y = train_y.flatten()
90+
val_y = val_y.flatten()
91+
return train_x, train_y, val_x, val_y

0 commit comments

Comments
 (0)