Skip to content

Commit 31d325e

Browse files
authored
Merge pull request #68 from prabhuomkar/blitz
[feature] Popular Tutorials, Deep Learning With PyTorch: A 60 Minute Blitz
2 parents c0d8c48 + e5e700e commit 31d325e

File tree

18 files changed

+726
-24
lines changed

18 files changed

+726
-24
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,4 @@ cmake-build-*/
6969
# extern
7070
extern/*
7171
!extern/CMakeLists.txt
72+
.vscode/

CMakeLists.txt

+6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ add_subdirectory("tutorials/advanced/variational_autoencoder")
6464
add_subdirectory("tutorials/advanced/neural_style_transfer")
6565
add_subdirectory("tutorials/advanced/image_captioning")
6666

67+
# Popular
68+
add_subdirectory("tutorials/popular/blitz/tensors")
69+
add_subdirectory("tutorials/popular/blitz/autograd")
70+
add_subdirectory("tutorials/popular/blitz/neural_networks")
71+
add_subdirectory("tutorials/popular/blitz/training_a_classifier")
72+
6773
if(MSVC)
6874
include(copy_torch_dlls)
6975
copy_torch_dlls(${EXECUTABLE_NAME})

README.md

+27-24
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,33 @@
2525
This repository provides tutorial code in C++ for deep learning researchers to learn PyTorch.
2626
**Python Tutorial**: [https://github.com/yunjey/pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial)
2727

28+
## Table of Contents
29+
30+
### 1. Basics
31+
* [PyTorch Basics](tutorials/basics/pytorch_basics/main.cpp)
32+
* [Linear Regression](tutorials/basics/linear_regression/main.cpp)
33+
* [Logistic Regression](tutorials/basics/logistic_regression/main.cpp)
34+
* [Feedforward Neural Network](tutorials/basics/feedforward_neural_network/src/main.cpp)
35+
36+
### 2. Intermediate
37+
* [Convolutional Neural Network](tutorials/intermediate/convolutional_neural_network/src/main.cpp)
38+
* [Deep Residual Network](tutorials/intermediate/deep_residual_network/src/main.cpp)
39+
* [Recurrent Neural Network](tutorials/intermediate/recurrent_neural_network/src/main.cpp)
40+
* [Bidirectional Recurrent Neural Network](tutorials/intermediate/bidirectional_recurrent_neural_network/src/main.cpp)
41+
* [Language Model (RNN-LM)](tutorials/intermediate/language_model/src/main.cpp)
42+
43+
### 3. Advanced
44+
* [Generative Adversarial Networks](tutorials/advanced/generative_adversarial_network/main.cpp)
45+
* [Variational Auto-Encoder](tutorials/advanced/variational_autoencoder/src/main.cpp)
46+
* [Neural Style Transfer](tutorials/advanced/neural_style_transfer/src/main.cpp)
47+
* [Image Captioning (CNN-AttentionRNN)](tutorials/advanced/image_captioning/src/main.cpp)
48+
49+
### 4. Interactive Tutorials
50+
* [Tensor Slicing](notebooks/tensor_slicing.ipynb)
51+
52+
### 5. Other Popular Tutorials
53+
* [Deep Learning with PyTorch: A 60 Minute Blitz](tutorials/popular/blitz)
54+
2855
# Getting Started
2956

3057
## Requirements
@@ -157,29 +184,5 @@ You can build and run the tutorials (on CPU) in a Docker container using the pro
157184
```
158185
This will - if necessary - build all tutorials and then start the provided tutorial in a container.
159186

160-
## Table of Contents
161-
162-
### 1. Basics
163-
* [PyTorch Basics](tutorials/basics/pytorch_basics/main.cpp)
164-
* [Linear Regression](tutorials/basics/linear_regression/main.cpp)
165-
* [Logistic Regression](tutorials/basics/logistic_regression/main.cpp)
166-
* [Feedforward Neural Network](tutorials/basics/feedforward_neural_network/src/main.cpp)
167-
168-
### 2. Intermediate
169-
* [Convolutional Neural Network](tutorials/intermediate/convolutional_neural_network/src/main.cpp)
170-
* [Deep Residual Network](tutorials/intermediate/deep_residual_network/src/main.cpp)
171-
* [Recurrent Neural Network](tutorials/intermediate/recurrent_neural_network/src/main.cpp)
172-
* [Bidirectional Recurrent Neural Network](tutorials/intermediate/bidirectional_recurrent_neural_network/src/main.cpp)
173-
* [Language Model (RNN-LM)](tutorials/intermediate/language_model/src/main.cpp)
174-
175-
### 3. Advanced
176-
* [Generative Adversarial Networks](tutorials/advanced/generative_adversarial_network/main.cpp)
177-
* [Variational Auto-Encoder](tutorials/advanced/variational_autoencoder/src/main.cpp)
178-
* [Neural Style Transfer](tutorials/advanced/neural_style_transfer/src/main.cpp)
179-
* [Image Captioning (CNN-AttentionRNN)](tutorials/advanced/image_captioning/src/main.cpp)
180-
181-
### 4. Interactive Tutorials
182-
* [Tensor Slicing](notebooks/tensor_slicing.ipynb)
183-
184187
## License
185188
This repository is licensed under MIT as given in [LICENSE](LICENSE).

tutorials/popular/blitz/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Deep Learning with PyTorch: A 60 Minute Blitz
2+
3+
1. *[tensors](tutorials/popular/blitz/tensors)*: What is PyTorch?
4+
https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html
5+
6+
2. *[autograd](tutorials/popular/blitz/autograd)*: Autograd: Automatic Differentiation
7+
https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html
8+
9+
3. *[neural_networks](tutorials/popular/blitz/neural_networks)*: Neural Networks
10+
https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
2+
3+
project(autograd VERSION 1.0.0 LANGUAGES CXX)
4+
5+
if(NOT Torch_FOUND)
6+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake")
7+
find_package(Torch REQUIRED PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch")
8+
endif()
9+
10+
set(EXECUTABLE_NAME autograd)
11+
12+
add_executable(${EXECUTABLE_NAME})
13+
target_sources(${EXECUTABLE_NAME} PRIVATE main.cpp)
14+
15+
target_link_libraries(${EXECUTABLE_NAME} ${TORCH_LIBRARIES})
16+
17+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES
18+
CXX_STANDARD 14
19+
CXX_STANDARD_REQUIRED YES
20+
)
21+
22+
if(MSVC)
23+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR})
24+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_CURRENT_BINARY_DIR})
25+
include(copy_torch_dlls)
26+
copy_torch_dlls(${EXECUTABLE_NAME})
27+
endif(MSVC)
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
#include <iomanip>
5+
6+
int main() {
7+
std::cout << "Deep Learning with PyTorch: A 60 Minute Blitz\n\n";
8+
std::cout << "Autograd: Automatic Differentiation\n\n";
9+
10+
std::cout << "Tensor\n\n";
11+
12+
// Create a tensor and set requires_grad=True to track computation with it:
13+
auto x = torch::ones({2, 2}, torch::TensorOptions().requires_grad(true));
14+
std::cout << "x:\n" << x << '\n';
15+
16+
// Do a tensor operation:
17+
auto y = x + 2;
18+
std::cout << "y:\n" << y << '\n';
19+
20+
// y was created as a result of an operation, so it has a grad_fn:
21+
std::cout << "y.grad_fn:\n" << y.grad_fn() << '\n';
22+
23+
// Do more operations on y:
24+
auto z = y * y * 3;
25+
auto out = z.mean();
26+
std::cout << "z:\n" << z << "out:\n" << out << '\n';
27+
28+
// .requires_grad_(...) changes an existing Tensor’s requires_grad flag in-place:
29+
auto a = torch::randn({2, 2});
30+
a = ((a * 3) / (a - 1));
31+
std::cout << a.requires_grad() << '\n';
32+
a.requires_grad_(true);
33+
std::cout << a.requires_grad() << '\n';
34+
auto b = (a * a).sum();
35+
std::cout << b.grad_fn() << '\n';
36+
37+
std::cout << "Gradients\n\n";
38+
39+
// Let’s backprop now:
40+
out.backward();
41+
42+
// Print gradients d(out)/dx:
43+
std::cout << "x.grad:\n" << x.grad() << '\n';
44+
45+
// Example of vector-Jacobian product:
46+
x = torch::randn(3, torch::TensorOptions().requires_grad(true));
47+
y = x * 2;
48+
while (y.data().norm().item<int>() < 1000) {
49+
y = y * 2;
50+
}
51+
std::cout << "y:\n" << y << '\n';
52+
53+
// Simply pass the vector to backward as argument:
54+
auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::TensorOptions(torch::kFloat));
55+
y.backward(v);
56+
std::cout << "x.grad:\n" << x.grad() << '\n';
57+
58+
// Stop autograd from tracking history on Tensors with .requires_grad=True:
59+
std::cout << "x.requires_grad\n" << x.requires_grad() << '\n';
60+
std::cout << "(x ** 2).requires_grad\n" << (x * x).requires_grad() << '\n';
61+
torch::NoGradGuard no_grad;
62+
std::cout << "(x ** 2).requires_grad\n" << (x * x).requires_grad() << '\n';
63+
64+
// Or by using .detach() to get a new Tensor with the same content but that does not require gradients:
65+
std::cout << "x.requires_grad:\n" << x.requires_grad() << '\n';
66+
y = x.detach();
67+
std::cout << "y.requires_grad:\n" << y.requires_grad() << '\n';
68+
std::cout << "x.eq(y).all():\n" << x.eq(y).all() << '\n';
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
2+
3+
project(neural-networks VERSION 1.0.0 LANGUAGES CXX)
4+
5+
if(NOT Torch_FOUND)
6+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake")
7+
find_package(Torch REQUIRED PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch")
8+
endif()
9+
10+
set(EXECUTABLE_NAME neural-networks)
11+
12+
add_executable(${EXECUTABLE_NAME})
13+
target_sources(${EXECUTABLE_NAME} PRIVATE src/main.cpp
14+
src/nnet.cpp
15+
include/nnet.h
16+
)
17+
18+
target_include_directories(${EXECUTABLE_NAME} PRIVATE include)
19+
20+
target_link_libraries(${EXECUTABLE_NAME} ${TORCH_LIBRARIES})
21+
22+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES
23+
CXX_STANDARD 14
24+
CXX_STANDARD_REQUIRED YES
25+
)
26+
27+
if(MSVC)
28+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR})
29+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_CURRENT_BINARY_DIR})
30+
include(copy_torch_dlls)
31+
copy_torch_dlls(${EXECUTABLE_NAME})
32+
endif(MSVC)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#pragma once
3+
4+
#include <torch/torch.h>
5+
6+
class NetImpl : public torch::nn::Module {
7+
public:
8+
NetImpl();
9+
torch::Tensor forward(torch::Tensor x);
10+
11+
torch::nn::Conv2d conv1;
12+
torch::nn::Conv2d conv2;
13+
torch::nn::Linear fc1;
14+
torch::nn::Linear fc2;
15+
torch::nn::Linear fc3;
16+
17+
private:
18+
int num_flat_features(torch::Tensor x);
19+
};
20+
21+
TORCH_MODULE(Net);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
#include <iomanip>
5+
#include "nnet.h"
6+
7+
int main() {
8+
std::cout << "Deep Learning with PyTorch: A 60 Minute Blitz\n\n";
9+
std::cout << "Neural Networks\n\n";
10+
11+
std::cout << "Define the network\n\n";
12+
Net net = Net();
13+
net->to(torch::kCPU);
14+
std::cout << net << "\n\n";
15+
16+
// The learnable parameters of a model are returned by net.parameters():
17+
auto params = net->parameters();
18+
std::cout << params.size() << '\n';
19+
std::cout << params.at(0).sizes() << "\n\n"; // conv1's .weight
20+
21+
// Let’s try a random 32x32 input:
22+
auto input = torch::randn({1, 1, 32, 32});
23+
auto out = net->forward(input);
24+
std::cout << out << "\n\n";
25+
26+
// Zero the gradient buffers of all parameters and backprops with random gradients:
27+
net->zero_grad();
28+
out.backward(torch::randn({1, 10}));
29+
30+
std::cout << "Loss Function\n\n";
31+
32+
auto output = net->forward(input);
33+
auto target = torch::randn(10); // a dummy target, for example
34+
target = target.view({1, -1}); // make it the same shape as output
35+
torch::nn::MSELoss criterion;
36+
auto loss = criterion(output, target);
37+
std::cout << loss << "\n\n";
38+
39+
// For illustration, let us follow a few steps backward:
40+
std::cout << "loss.grad_fn:\n" << loss.grad_fn() << '\n'; // MSELoss
41+
42+
std::cout << "Backprop\n\n";
43+
44+
// Now we shall call loss.backward(), and have a look at conv1’s bias gradients before and after the backward:
45+
net->zero_grad(); // zeroes the gradient buffers of all parameters
46+
std::cout << "conv1.bias.grad before backward:\n" << net->conv1->bias.grad() << '\n';
47+
loss.backward();
48+
std::cout << "conv1.bias.grad after backward:\n" << net->conv1->bias.grad() << "\n\n";
49+
50+
std::cout << "Update the weights\n\n";
51+
52+
// create your optimizer
53+
auto learning_rate = 0.01;
54+
auto optimizer = torch::optim::SGD(net->parameters(), torch::optim::SGDOptions(learning_rate));
55+
// in your training loop:
56+
optimizer.zero_grad(); // zero the gradient buffers
57+
output = net->forward(input);
58+
loss = criterion(output, target);
59+
loss.backward();
60+
optimizer.step(); // Does the update
61+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2020-present pytorch-cpp Authors
2+
#include "nnet.h"
3+
#include <torch/torch.h>
4+
5+
NetImpl::NetImpl() :
6+
conv1(torch::nn::Conv2dOptions(1, 6, 3)),
7+
conv2(torch::nn::Conv2dOptions(6, 16, 3)),
8+
fc1(torch::nn::LinearOptions(16 * 6 * 6, 120)),
9+
fc2(torch::nn::LinearOptions(120, 84)),
10+
fc3(torch::nn::LinearOptions(84, 10)) {
11+
register_module("conv1", conv1);
12+
register_module("conv2", conv2);
13+
register_module("fc1", fc1);
14+
register_module("fc2", fc2);
15+
register_module("fc3", fc3);
16+
}
17+
18+
int NetImpl::num_flat_features(torch::Tensor x) {
19+
auto sz = x.sizes().slice(1); // all dimensions except the batch dimension
20+
int num_features = 1;
21+
for (auto s : sz) {
22+
num_features *= s;
23+
}
24+
return num_features;
25+
}
26+
27+
torch::Tensor NetImpl::forward(torch::Tensor x) {
28+
// Max pooling over a (2, 2) window
29+
auto out = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions({2, 2}))->forward(torch::relu(conv1->forward(x)));
30+
// If the size is a square you can only specify a single number
31+
out = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2))->forward(torch::relu(conv2->forward(out)));
32+
out = out.view({-1, num_flat_features(out)});
33+
out = torch::relu(fc1->forward(out));
34+
out = torch::relu(fc2->forward(out));
35+
return fc3->forward(out);
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
2+
3+
project(tensors VERSION 1.0.0 LANGUAGES CXX)
4+
5+
if(NOT Torch_FOUND)
6+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../cmake")
7+
find_package(Torch REQUIRED PATHS "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch")
8+
endif()
9+
10+
set(EXECUTABLE_NAME tensors)
11+
12+
add_executable(${EXECUTABLE_NAME})
13+
target_sources(${EXECUTABLE_NAME} PRIVATE main.cpp)
14+
15+
target_link_libraries(${EXECUTABLE_NAME} ${TORCH_LIBRARIES})
16+
17+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES
18+
CXX_STANDARD 14
19+
CXX_STANDARD_REQUIRED YES
20+
)
21+
22+
if(MSVC)
23+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_CURRENT_BINARY_DIR})
24+
set_target_properties(${EXECUTABLE_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_CURRENT_BINARY_DIR})
25+
include(copy_torch_dlls)
26+
copy_torch_dlls(${EXECUTABLE_NAME})
27+
endif(MSVC)

0 commit comments

Comments
 (0)