Add optimisers #320
Draft
jatkinson1000 wants to merge 4 commits into main from optim

Commits (4)
- dcdcf6b  jatkinson1000: Add optimisers example. README, requirements, and python version of t…
- 3e78599  jatkinson1000: Add a Fortran equivalent of the optimisers example WIP.
- ef302d9  jatkinson1000: Add boilerplate ftorch_optim.F90 module and C++ ready for optimisers.
- 6a28918  jatkinson1000: Add optim module to fortitude linting in static analysis workflow.

Files changed

CMakeLists.txt (new file, 39 lines)
cmake_minimum_required(VERSION 3.15...3.31)
# policy CMP0076 - target_sources source files are relative to file where
# target_sources is run
cmake_policy(SET CMP0076 NEW)

set(PROJECT_NAME OptimisersExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE
      Debug
      CACHE STRING "" FORCE)
endif()

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example
add_executable(optimisers optimisers.f90)
target_link_libraries(optimisers PRIVATE FTorch::ftorch)

# Integration testing
if(CMAKE_BUILD_TESTS)
  include(CTest)

  # 1. Check the Python Optimisers script runs successfully
  add_test(NAME pyoptim COMMAND ${Python_EXECUTABLE}
           ${PROJECT_SOURCE_DIR}/optimisers.py)

  # 2. Check the Fortran Optimisers script runs successfully
  add_test(
    NAME foptim
    COMMAND optimisers
    WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
  set_tests_properties(foptim PROPERTIES PASS_REGULAR_EXPRESSION
                       "Optimisers example ran successfully")
endif()

README (new file, 75 lines)
# Example n - Optimisers

**This example is currently under development.** Eventually, it will demonstrate
the use of optimisers in FTorch by leveraging PyTorch's optim module.

By exposing optimisers in Fortran, FTorch will be able to compute optimisation
steps to update models as part of a training process.
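
The workflow this is building towards is the standard PyTorch training cycle: zero the stored gradients, run a forward pass, back-propagate a loss, and step the optimiser. A minimal, self-contained PyTorch sketch of that cycle (illustrative only, not FTorch code):

```python
import torch

# A single parameter to optimise and an SGD optimiser acting on it
w = torch.tensor([0.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(100):
    opt.zero_grad()                 # clear gradients from the previous step
    loss = ((w - 3.0) ** 2).mean()  # forward pass and loss
    loss.backward()                 # autograd populates w.grad
    opt.step()                      # update w using the stored gradient
```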

## Description

A Python demo is copied from the PyTorch documentation as `optimisers.py`, which
shows how to use an optimiser in PyTorch.

The demo will be replicated in Fortran as `optimisers.f90`, to show how to do the
same thing using FTorch.

## Dependencies

Running this example requires:

- CMake
- Fortran compiler
- FTorch (installed as described in the main package)
- Python 3

## Running

To run this example, install FTorch as described in the main documentation.
Then, from this directory, create a virtual environment and install the necessary
Python modules:
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

Run the Python version of the demo with
```
python3 optimisers.py
```
This trains a tensor to scale, elementwise, a vector of ones to the vector `[1, 2, 3, 4]`.
It uses the torch SGD optimiser to adjust the values of the scaling tensor at each step,
outputting values of interest to screen in the form:
```console
========================
Epoch: 0
Output:
tensor([1., 1., 1., 1.], grad_fn=<MulBackward0>)
loss:
3.5
tensor gradient:
tensor([ 0.0000, -0.5000, -1.0000, -1.5000])
tensor:
tensor([1.0000, 1.5000, 2.0000, 2.5000], requires_grad=True)
...
```
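
The numbers in the first epoch above can be checked by hand: the output starts as `[1, 1, 1, 1]`, so the MSE loss is `(0 + 1 + 4 + 9) / 4 = 3.5`, the gradient with respect to the scaling tensor is `2 * (output - target) / 4 = [0, -0.5, -1, -1.5]`, and one SGD step with `lr=1.0` then gives `[1, 1.5, 2, 2.5]`. A short standalone check of this, assuming the same settings as `optimisers.py`:

```python
import torch

scaling = torch.ones(4, requires_grad=True)
target = torch.tensor([1.0, 2.0, 3.0, 4.0])

loss = ((torch.ones(4) * scaling - target) ** 2).mean()
loss.backward()
print(loss.item())   # 3.5
print(scaling.grad)  # tensor([ 0.0000, -0.5000, -1.0000, -1.5000])

with torch.no_grad():
    scaling -= 1.0 * scaling.grad  # one SGD step with lr=1.0
print(scaling)       # tensor([1.0000, 1.5000, 2.0000, 2.5000], requires_grad=True)
```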

To run the Fortran version of the demo we need to compile with (for example)
```
mkdir build
cd build
cmake .. -DCMAKE_PREFIX_PATH=<path/to/your/installation/of/library/> -DCMAKE_BUILD_TYPE=Release
cmake --build .
```

(Note that the Fortran compiler can be chosen explicitly with the `-DCMAKE_Fortran_COMPILER` flag,
and should match the compiler that was used to locally build FTorch.)

To run the compiled code, simply use
```
./optimisers
```
Currently, the example constructs Torch Tensors and iterates over a training loop,
computing a loss with each iteration.
It does not yet implement an optimiser or step to update the scaling tensor.

optimisers.f90 (new file, 106 lines)
program example

  ! Import precision info from iso
  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import c_int64_t
  use, intrinsic :: iso_c_binding, only: c_int64_t

  ! Import our library for interfacing with PyTorch's Autograd module
  use ftorch, only: assignment(=), operator(-), operator(*), operator(/), operator(**), &
                    torch_kCPU, torch_kFloat32, &
                    torch_tensor, torch_tensor_from_array, &
                    torch_tensor_ones, torch_tensor_empty, &
                    torch_tensor_print, torch_delete

  implicit none

  ! Set working precision for reals
  integer, parameter :: wp = sp

  ! Set up Fortran data structures
  integer, parameter :: ndims = 1
  integer, parameter :: n = 4
  real(wp), dimension(n), target :: input_data, output_data, target_data
  integer :: tensor_layout(ndims) = [1]

  ! Set up Torch data structures
  integer(c_int64_t), dimension(1), parameter :: tensor_shape = [4]
  type(torch_tensor) :: input_vec, output_vec, target_vec, scaling_tensor, loss, torch_4p0

  ! Set up training parameters
  integer :: i
  integer, parameter :: n_train = 15
  integer, parameter :: n_print = 1

  ! Initialise Torch Tensors from input/target arrays as in the Python example
  input_data = [1.0_wp, 1.0_wp, 1.0_wp, 1.0_wp]
  target_data = [1.0_wp, 2.0_wp, 3.0_wp, 4.0_wp]
  call torch_tensor_from_array(input_vec, input_data, tensor_layout, torch_kCPU)
  call torch_tensor_from_array(target_vec, target_data, tensor_layout, torch_kCPU)

  ! Initialise the scaling tensor as ones, as in the Python example
  call torch_tensor_ones(scaling_tensor, ndims, tensor_shape, &
                         torch_kFloat32, torch_kCPU, requires_grad=.true.)

  ! Initialise a scaling factor of 4.0 for use in tensor operations
  call torch_tensor_from_array(torch_4p0, [4.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.)

  ! Initialise an optimiser and apply it to scaling_tensor
  ! TODO

  ! Conduct training loop
  do i = 1, n_train+1
    ! Zero any previously stored gradients ready for a new iteration
    ! TODO: implement equivalent to optimizer.zero_grad()

    ! Forward pass: multiply the input of ones by the tensor (elementwise)
    call torch_tensor_from_array(output_vec, output_data, tensor_layout, torch_kCPU)
    output_vec = input_vec * scaling_tensor

    ! Create an empty loss tensor and populate it with the mean square error (MSE) between output and target
    ! Then perform a backward step on the loss to propagate gradients using autograd
    !
    ! We could use the following lines to do this by explicitly specifying a
    ! gradient of ones to start the process:
    call torch_tensor_empty(loss, ndims, tensor_shape, &
                            torch_kFloat32, torch_kCPU)
    loss = ((output_vec - target_vec) ** 2) / torch_4p0
    ! TODO: add in backpropagation functionality for loss.backward(gradient=torch.ones(4))
    !
    ! However, we can avoid explicitly passing an initial gradient and instead do this
    ! implicitly by aggregating the loss vector into a scalar value:
    ! TODO: Requires addition of `.mean()` to the FTorch tensor API
    ! loss = ((output_vec - target_vec) ** 2).mean()
    ! loss.backward()

    ! Step the optimiser to update the values in `scaling_tensor`
    ! TODO: Add step functionality to optimisers for optimizer.step()

    if (modulo(i, n_print) == 0) then
      write(*,*) "================================================"
      write(*,*) "Epoch: ", i
      write(*,*)
      write(*,*) "Output:", output_data
      write(*,*)
      write(*,*) "loss:"
      call torch_tensor_print(loss)
      write(*,*)
      write(*,*) "tensor gradient: TODO: scaling_tensor.grad"
      write(*,*)
      write(*,*) "scaling_tensor:"
      call torch_tensor_print(scaling_tensor)
      write(*,*)
    end if

    ! Clean up created tensors
    call torch_delete(output_vec)
    call torch_delete(loss)

  end do

  write(*,*) "Training complete."

  write (*,*) "Optimisers example ran successfully"

end program example
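
The two loss formulations mentioned in the comments are interchangeable here because each element of `scaling_tensor` only affects the corresponding element of the loss vector, so dividing by 4 elementwise and reducing with `.mean()` yield the same gradients. A brief PyTorch check of that equivalence (illustrative only; the FTorch equivalents are still TODO above):

```python
import torch

target = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Explicit initial gradient of ones on the elementwise loss...
a = torch.ones(4, requires_grad=True)
loss_vec = ((torch.ones(4) * a - target) ** 2) / 4.0
loss_vec.backward(gradient=torch.ones(4))

# ...matches reducing the loss to a scalar with .mean() and a plain backward()
b = torch.ones(4, requires_grad=True)
loss = ((torch.ones(4) * b - target) ** 2).mean()
loss.backward()

print(torch.allclose(a.grad, b.grad))  # True
```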

optimisers.py (new file, 55 lines)
"""Optimisers demo.""" | ||
|
||
import torch | ||
|
||
# We define: | ||
# - the input as as a vector of ones, | ||
# - the target as a vector where each element is the index value, | ||
# - a tensor to transform from input to target by elementwise multiplication | ||
# initialised as a vector of ones | ||
# This is a contrived example, but provides a simple demo of optimiser functionality | ||
input_vec = torch.ones(4) | ||
target_vec = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
scaling_tensor = torch.ones(4, requires_grad=True) | ||
|
||
# Set the optimiser as torch's stochastic gradient descent (SGD) | ||
# The parameters to tune will be the values of `tensor`, and we also set a learning rate | ||
# Since this is a simple elemetwise example we can get away with a large learning rate | ||
optimizer = torch.optim.SGD([scaling_tensor], lr=1.0) | ||
|
||
# Training loop | ||
# Run n_iter times printing every n_print steps | ||
n_iter = 15 | ||
n_print = 1 | ||
for epoch in range(n_iter + 1): | ||
# Zero any previously stored gradients ready for a new iteration | ||
optimizer.zero_grad() | ||
|
||
# Forward pass: multiply the input of ones by the tensor (elementwise) | ||
output = input_vec * scaling_tensor | ||
|
||
# Create a loss tensor as computed mean square error (MSE) between target and input | ||
# Then perform backward step on loss to propogate gradients using autograd | ||
# | ||
# We could use the following 2 lines to do this by explicitly specifying a | ||
# gradient of ones to start the process: | ||
# loss = ((output - target) ** 2) / 4.0 | ||
# loss.backward(gradient=torch.ones(4)) | ||
# | ||
# However, we can avoid explicitly passing an initial gradient and instead do this | ||
# implicitly by aggregating the loss vector into a scalar value: | ||
loss = ((output - target_vec) ** 2).mean() | ||
loss.backward() | ||
|
||
# Step the optimiser to update the values in `tensor` | ||
optimizer.step() | ||
|
||
if (epoch) % n_print == 0: | ||
print(f"========================") | ||
print(f"Epoch: {epoch}") | ||
print(f"\tOutput:\n\t\t{output}") | ||
print(f"\tloss:\n\t\t{loss}") | ||
print(f"\ttensor gradient:\n\t\t{scaling_tensor.grad}") | ||
print(f"\tscaling_tensor:\n\t\t{scaling_tensor}") | ||
|
||
print("Training complete.") |

requirements.txt (new file, 2 lines)
torch
numpy

ftorch_optim.F90 (new file, 24 lines)
!| Optimisers module for FTorch.
!
!  * License
!    FTorch is released under an MIT license.
!    See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
!    file for details.

module ftorch_optim

  use, intrinsic :: iso_c_binding, only: c_associated, c_null_ptr, c_ptr
  use, intrinsic :: iso_fortran_env, only: int32

  use ftorch, only: ftorch_int

  implicit none

  public

  ! ============================================================================
  ! ---
  ! ============================================================================

end module ftorch_optim

Review comments

Comment: At some point it might make sense to separate out modules for `tensor`, `model`, `optim`.
Reply: Yeah, since each optimiser (Adam, SGD, etc.) seems to be its own function I thought I'd do this here, for the Fortran at least.
I had been hoping that there was a general optimiser function specified by an enum, but alas no.
Others feel OK for now, but definitely something I was thinking about, especially as module functions might grow soon.
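
(For reference, on the PyTorch side each optimiser is a separate class, but they all inherit the common `torch.optim.Optimizer` interface of `zero_grad()` and `step()`; a minimal illustration:)

```python
import torch

params = [torch.ones(4, requires_grad=True)]

# Separate classes per algorithm, one shared interface
for opt in (torch.optim.SGD(params, lr=1.0), torch.optim.Adam(params, lr=0.1)):
    opt.zero_grad()
    (params[0] ** 2).sum().backward()
    opt.step()
```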
Reply: Agreed. This was more in anticipation of when the torch model/module API starts growing.