Skip to content

A tiny deep learning training framework implemented from scratch in C++ that follows PyTorch's API.

License

Notifications You must be signed in to change notification settings

keith2018/TinyTorch

Repository files navigation

TinyTorch

A tiny deep learning training framework, implemented from scratch in C++, that follows PyTorch's API.

Build status (CMake): Linux · macOS · Windows

Components

  • Module
    • Linear
    • Conv2D
    • MaxPool2D
    • Dropout
    • Softmax
    • LogSoftmax
    • Relu
    • Sequential
  • Loss
    • MSELoss
    • NLLLoss
  • Optimizer
    • SGD
    • Adagrad
    • RMSprop
    • Adadelta
    • Adam
    • AdamW
  • Data
    • Dataset
    • DataLoader
    • Transform

Automatic differentiation

MNIST training demo:

#include "Torch.h"

using namespace TinyTorch;

// Small convolutional classifier used by the MNIST demo: two 3x3 convolutions,
// max-pooling, two fully-connected layers with dropout, and a log-softmax head
// over 10 outputs.
//
// NOTE(review): fc1 expects 9216 = 64 * 12 * 12 features, which matches a
// 1x28x28 input if the 4th Conv2D argument is stride 1 with no padding
// (28 -> 26 -> 24 after the convolutions, -> 12 after 2x2 pooling) — confirm.
class Net : public nn::Module {
 public:
  Net() { registerModules({conv1, conv2, dropout1, dropout2, fc1, fc2}); }

  // Runs the network on `x` and returns log-probabilities along dim 1.
  // `x` is only read: every intermediate result goes into the local `out`, so
  // the caller's tensor is not overwritten with the network output (callers
  // such as train()/test() keep using their input tensor after this call).
  Tensor forward(Tensor &x) override {
    Tensor out = conv1(x);
    out = Function::relu(out);
    out = conv2(out);
    out = Function::relu(out);
    out = Function::maxPool2d(out, 2);
    out = dropout1(out);
    out = Tensor::flatten(out, 1);
    out = fc1(out);
    out = Function::relu(out);
    out = dropout2(out);
    out = fc2(out);
    out = Function::logSoftmax(out, 1);
    return out;
  }

 private:
  nn::Conv2D conv1{1, 32, 3, 1};
  nn::Conv2D conv2{32, 64, 3, 1};
  nn::Dropout dropout1{0.25};
  nn::Dropout dropout2{0.5};
  nn::Linear fc1{9216, 128};
  nn::Linear fc2{128, 10};
};

// Runs one training epoch: for every batch from `dataLoader`, computes the NLL
// loss of `model`'s output against the targets, back-propagates, and applies
// one `optimizer` step. Progress and the batch loss are logged after each step.
//
// Each batch is indexed as {inputs, targets} at positions 0 and 1.
// All counters are cast to int32_t before reaching LOGD so they match its
// "%d" conversions no matter which integer types the accessors return
// (passing e.g. a size_t where "%d" is expected is undefined behavior).
void train(nn::Module &model, data::DataLoader &dataLoader,
           optim::Optimizer &optimizer, int32_t epoch) {
  model.train();
  // Hoisted out of the loop; assumes the dataset size is fixed while iterating.
  const auto totalDataCnt = static_cast<int32_t>(dataLoader.dataset().size());
  for (auto [batchIdx, batch] : dataLoader) {
    auto &data = batch[0];
    auto &target = batch[1];
    // Read the batch size before the forward pass: forward() takes its input
    // by non-const reference and may reassign it.
    const auto batchCnt = static_cast<int32_t>(data.shape()[0]);

    optimizer.zeroGrad();
    auto output = model(data);
    auto loss = Function::nllloss(output, target);
    loss.backward();
    optimizer.step();

    const auto currDataCnt =
        static_cast<int32_t>(batchIdx * dataLoader.batchSize()) + batchCnt;
    LOGD("Train Epoch: %d [%d/%d %.2f%%], loss: %.6f", epoch, currDataCnt,
         totalDataCnt, 100.f * currDataCnt / static_cast<float>(totalDataCnt),
         loss.item());
  }
}

// Evaluates `model` on every batch from `dataLoader` with gradient tracking
// disabled, and logs the running accuracy after each batch.
//
// A prediction is the argmax over dim 1 of the model output. It counts as
// correct when it equals the target (reshaped to the prediction's shape).
// Counters are cast to int32_t so they match LOGD's "%d" conversions.
void test(nn::Module &model, data::DataLoader &dataLoader) {
  model.eval();
  int32_t total = 0;
  int32_t correct = 0;
  // Hoisted out of the loop; assumes the dataset size is fixed while iterating.
  const auto totalDataCnt = static_cast<int32_t>(dataLoader.dataset().size());
  withNoGrad {
    for (auto [batchIdx, batch] : dataLoader) {
      auto &data = batch[0];
      auto &target = batch[1];
      // Read the batch size before the forward pass: forward() takes its
      // input by non-const reference and may reassign it.
      const auto batchCnt = static_cast<int32_t>(data.shape()[0]);

      auto output = model(data);
      total += static_cast<int32_t>(target.shape()[0]);
      auto pred = output.data().argmax(1, true);
      correct += static_cast<int32_t>(
          (pred == target.data().view(pred.shape())).sum());

      const auto currDataCnt =
          static_cast<int32_t>(batchIdx * dataLoader.batchSize()) + batchCnt;
      LOGD("Test [%d/%d %.2f%%], Accuracy: [%d/%d (%.2f%%)]", currDataCnt,
           totalDataCnt,
           100.f * currDataCnt / static_cast<float>(totalDataCnt), correct,
           total, 100.f * correct / static_cast<float>(total));
    }
  }
}

// End-to-end MNIST example.
// - Loads the train/test splits from ./data/.
// - Trains `Net` with Adadelta for `epochs` epochs, stepping an LR scheduler
//   after each epoch.
// - Evaluates after every epoch and saves one checkpoint per epoch
//   ("mnist_cnn_epoch_<N>.model").
// - Logs the total elapsed time.
void demo_mnist() {
  LOGD("demo_mnist ...");
  Timer timer;
  timer.start();

  // config
  auto lr = 1.f;
  auto epochs = 2;
  auto batchSize = 64;

  // NOTE(review): Normalize's arguments are presumably (mean, std) — confirm.
  auto transform = std::make_shared<data::transforms::Compose>(
      data::transforms::Normalize(0.1307f, 0.3081f));

  auto dataDir = "./data/";
  auto trainDataset = std::make_shared<data::DatasetMNIST>(
      dataDir, data::DatasetMNIST::TRAIN, transform);
  auto testDataset = std::make_shared<data::DatasetMNIST>(
      dataDir, data::DatasetMNIST::TEST, transform);

  // size() may return an unsigned or wider type; cast so it matches "%d".
  LOGD("train size: %d", static_cast<int32_t>(trainDataset->size()));
  LOGD("test size: %d", static_cast<int32_t>(testDataset->size()));

  auto trainDataloader = data::DataLoader(trainDataset, batchSize, true);
  // NOTE(review): the third argument presumably enables shuffling, which
  // evaluation does not need — consider passing false here.
  auto testDataloader = data::DataLoader(testDataset, batchSize, true);

  auto model = Net();
  auto optimizer = optim::Adadelta(model.parameters(), lr);
  // NOTE(review): presumably scales the learning rate by 0.7 every 1 step;
  // step() is called once per epoch below — confirm.
  auto scheduler = optim::lr_scheduler::StepLR(optimizer, 1, 0.7f);

  for (auto epoch = 0; epoch < epochs; epoch++) {
    train(model, trainDataloader, optimizer, epoch);
    test(model, testDataloader);
    scheduler.step();

    std::ostringstream saveName;
    saveName << "mnist_cnn_epoch_" << epoch << ".model";
    save(model, saveName.str().c_str());
  }

  timer.stop();
  // elapseMillis()'s integer type is not pinned here; cast to match "%lld".
  LOGD("Time cost: %lld ms", static_cast<long long>(timer.elapseMillis()));
}

Build

# Configure and build in Release mode. `cmake -B` creates ./build itself
# if it does not exist, so the mkdir below is optional.
mkdir build
cmake -B ./build -DCMAKE_BUILD_TYPE=Release
cmake --build ./build --config Release

Demo

# Run the demo from its binary directory. demo_mnist reads MNIST from
# "./data/", i.e. relative to the current working directory.
cd demo/bin
./TinyTorch_demo

Test

# Run the tests. -C selects the configuration built above; it is required by
# multi-config generators (e.g. Visual Studio) and harmless otherwise.
cd build
ctest -C Release

Dependencies

License

This code is licensed under the MIT License (see LICENSE).

About

A tiny deep learning training framework implemented from scratch in C++ that follows PyTorch's API.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published