diff --git a/examples/pytorch/FastCells/fastcell_example.py b/examples/pytorch/FastCells/fastcell_example.py
index 9d55dd9d7..f4a5ec8d1 100644
--- a/examples/pytorch/FastCells/fastcell_example.py
+++ b/examples/pytorch/FastCells/fastcell_example.py
@@ -47,6 +47,14 @@ def main():
     assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
         "Timesteps have to be integer"
 
+    timeSteps = int(dataDimension / inputDims)
+    Xtrain = Xtrain.reshape((-1, timeSteps, inputDims))
+    Xtest = Xtest.reshape((-1, timeSteps, inputDims))
+
+    if not batch_first:
+        Xtrain = np.swapaxes(Xtrain, 0, 1)
+        Xtest = np.swapaxes(Xtest, 0, 1)
+
     currDir = helpermethods.createTimeStampDir(dataDir, cell)
     helpermethods.dumpCommand(sys.argv, currDir)
 
diff --git a/pytorch/edgeml_pytorch/trainer/fastTrainer.py b/pytorch/edgeml_pytorch/trainer/fastTrainer.py
index 3f0ebd338..8a6c44d59 100644
--- a/pytorch/edgeml_pytorch/trainer/fastTrainer.py
+++ b/pytorch/edgeml_pytorch/trainer/fastTrainer.py
@@ -9,6 +9,14 @@ from edgeml_pytorch.graph.rnn import *
 import numpy as np
 
 
+class SimpleFC(nn.Module):
+    def __init__(self, input_size, num_classes, name="SimpleFC"):
+        super(SimpleFC, self).__init__()
+        self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
+        self.FCbias = nn.Parameter(torch.randn([num_classes]))
+
+    def forward(self, input):
+        return torch.matmul(input, self.FC) + self.FCbias
+
 class FastTrainer:
 
@@ -49,24 +57,19 @@ def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0,
         self.assertInit()
         self.numMatrices = self.FastObj.num_weight_matrices
         self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]
-
-        self.optimizer = self.optimizer()
-
         self.RNN = BaseRNN(self.FastObj, batch_first=self.batch_first).to(self.device)
 
-        self.FC = nn.Parameter(torch.randn(
-            [self.FastObj.output_size, self.numClasses])).to(self.device)
-        self.FCbias = nn.Parameter(torch.randn(
-            [self.numClasses])).to(self.device)
+        self.simpleFC = SimpleFC(self.FastObj.output_size, self.numClasses).to(self.device)
 
         self.FastParams = self.FastObj.getVars()
+        self.optimizer = self.optimizer()
 
     def classifier(self, feats):
         '''
         Can be raplaced by any classifier
         TODO: Make this a separate class if needed
         '''
-        return torch.matmul(feats, self.FC) + self.FCbias
+        return self.simpleFC(feats)
 
     def computeLogits(self, input):
         '''
@@ -74,19 +77,23 @@ def computeLogits(self, input):
         '''
         if self.FastObj.cellType == "LSTMLR":
             feats, _ = self.RNN(input)
-            logits = self.classifier(feats[-1, :])
         else:
             feats = self.RNN(input)
-            logits = self.classifier(feats[-1, :])
-        return logits, feats[:, -1]
+        if self.batch_first:
+            logits = self.classifier(feats[:, -1])
+            return logits, feats[:, -1]
+        else:
+            logits = self.classifier(feats[-1, :])
+            return logits, feats[-1, :]
 
     def optimizer(self):
         '''
         Optimizer for FastObj Params
         '''
+        paramList = list(self.FastObj.parameters()) + list(self.simpleFC.parameters())
         optimizer = torch.optim.Adam(
-            self.FastObj.parameters(), lr=self.learningRate)
+            paramList, lr=self.learningRate)
         return optimizer
 
@@ -168,12 +175,12 @@ def getModelSize(self):
             hasSparse = hasSparse or sparseFlag
 
         # Replace this with classifier class call
-        nnz, size, sparseFlag = utils.estimateNNZ(self.FC, 1.0)
+        nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FC, 1.0)
         totalnnZ += nnz
         totalSize += size
         hasSparse = hasSparse or sparseFlag
 
-        nnz, size, sparseFlag = utils.estimateNNZ(self.FCbias, 1.0)
+        nnz, size, sparseFlag = utils.estimateNNZ(self.simpleFC.FCbias, 1.0)
         totalnnZ += nnz
         totalSize += size
         hasSparse = hasSparse or sparseFlag
@@ -341,8 +348,8 @@ def saveParams(self, currDir):
             np.save(os.path.join(currDir, "Bo.npy"),
                     self.FastParams[self.totalMatrices + 3].data.cpu())
 
-        np.save(os.path.join(currDir, "FC.npy"), self.FC.data.cpu())
-        np.save(os.path.join(currDir, "FCbias.npy"), self.FCbias.data.cpu())
+        np.save(os.path.join(currDir, "FC.npy"), self.simpleFC.FC.data.cpu())
+        np.save(os.path.join(currDir, "FCbias.npy"), self.simpleFC.FCbias.data.cpu())
 
     def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
               decayStep, decayRate, dataDir, currDir):
@@ -351,7 +358,13 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
         '''
         fileName = str(self.FastObj.cellType) + 'Results_pytorch.txt'
         resultFile = open(os.path.join(dataDir, fileName), 'a+')
-        numIters = int(np.ceil(float(Xtrain.shape[0]) / float(batchSize)))
+        if self.batch_first:
+            self.timeSteps = Xtrain.shape[1]
+            self.numPoints = Xtrain.shape[0]
+        else:
+            self.timeSteps = Xtrain.shape[0]
+            self.numPoints = Xtrain.shape[1]
+        numIters = int(np.ceil(float(self.numPoints) / float(batchSize)))
         totalBatches = numIters * totalEpochs
 
         counter = 0
@@ -362,11 +375,6 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
         ihtDone = 1
         maxTestAcc = -10000
         header = '*' * 20
-        self.timeSteps = int(Xtest.shape[1] / self.inputDims)
-        Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims))
-        Xtest = np.swapaxes(Xtest, 0, 1)
-        Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims))
-        Xtrain = np.swapaxes(Xtrain, 0, 1)
 
         for i in range(0, totalEpochs):
             print("\nEpoch Number: " + str(i), file=self.outFile)
@@ -376,7 +384,7 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
             for param_group in self.optimizer.param_groups:
                 param_group['lr'] = self.learningRate
 
-            shuffled = list(range(Xtrain.shape[1]))
+            shuffled = list(range(self.numPoints))
             np.random.shuffle(shuffled)
             trainAcc = 0.0
             trainLoss = 0.0
@@ -389,7 +397,10 @@ def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
                           (header, msg, header), file=self.outFile)
 
                 k = shuffled[j * batchSize:(j + 1) * batchSize]
-                batchX = Xtrain[:, k, :]
+                if self.batch_first:
+                    batchX = Xtrain[k, :, :]
+                else:
+                    batchX = Xtrain[:, k, :]
                 batchY = Ytrain[k]
 
                 self.optimizer.zero_grad()
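Note on the data-layout convention this patch establishes: the example script now reshapes the flat feature matrix exactly once, to (batch, time, feature), and transposes to time-major (time, batch, feature) only when batch_first is False; every downstream index in fastTrainer.py (the batch slice in train() and the last-step selection in computeLogits()) branches on the same flag. Below is a minimal numpy sketch of that convention; the sizes are hypothetical, and only the identifiers timeSteps, inputDims, batch_first and the slicing expressions are taken from the patch.

    import numpy as np

    numPoints, inputDims, dataDimension = 32, 16, 128  # hypothetical sizes
    batch_first = False

    X = np.random.randn(numPoints, dataDimension)
    timeSteps = dataDimension // inputDims             # 128 / 16 = 8 steps per point
    X = X.reshape((-1, timeSteps, inputDims))          # (batch, time, feature)
    if not batch_first:
        X = np.swapaxes(X, 0, 1)                       # (time, batch, feature)

    k = list(range(4))                                 # a mini-batch of point ids
    batchX = X[k, :, :] if batch_first else X[:, k, :]       # as in train()
    last = batchX[:, -1] if batch_first else batchX[-1, :]   # as in computeLogits()
    assert last.shape == (4, inputDims)

The reordering in __init__ matters as much as the layout fix: torch.optim.Adam snapshots its parameter list at construction, so building the optimizer before the classifier weights existed, and passing it only self.FastObj.parameters(), meant FC and FCbias were never trained (moreover, nn.Parameter(...).to(device) returns a plain non-leaf tensor when it actually moves devices). Registering the weights inside the SimpleFC module and constructing Adam afterwards over both parameter lists fixes both issues.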