-
Notifications
You must be signed in to change notification settings - Fork 100
Generic flatten (2d and 3d) #202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,3 +4,6 @@ license = "MIT" | |
| author = "Milan Curcic" | ||
| maintainer = "[email protected]" | ||
| copyright = "Copyright 2018-2025, neural-fortran contributors" | ||
|
|
||
| [preprocess] | ||
| [preprocess.cpp] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,20 @@ module nf_flatten_layer | |
| integer, allocatable :: input_shape(:) | ||
| integer :: output_size | ||
|
|
||
| real, allocatable :: gradient(:,:,:) | ||
| real, allocatable :: gradient_2d(:,:) | ||
| real, allocatable :: gradient_3d(:,:,:) | ||
| real, allocatable :: output(:) | ||
|
|
||
| contains | ||
|
|
||
| procedure :: backward | ||
| procedure :: forward | ||
| procedure :: backward_2d | ||
| procedure :: backward_3d | ||
| generic :: backward => backward_2d, backward_3d | ||
|
|
||
| procedure :: forward_2d | ||
| procedure :: forward_3d | ||
| generic :: forward => forward_2d, forward_3d | ||
|
||
|
|
||
| procedure :: init | ||
|
|
||
| end type flatten_layer | ||
|
|
@@ -39,26 +46,47 @@ end function flatten_layer_cons | |
|
|
||
| interface | ||
|
|
||
| pure module subroutine backward(self, input, gradient) | ||
| !! Apply the backward pass to the flatten layer. | ||
| pure module subroutine backward_2d(self, input, gradient) | ||
| !! Apply the backward pass to the flatten layer for 2D input. | ||
| !! This is a reshape operation from 1-d gradient to 2-d input. | ||
| class(flatten_layer), intent(in out) :: self | ||
| !! Flatten layer instance | ||
| real, intent(in) :: input(:,:) | ||
| !! Input from the previous layer | ||
| real, intent(in) :: gradient(:) | ||
| !! Gradient from the next layer | ||
| end subroutine backward_2d | ||
|
|
||
| pure module subroutine backward_3d(self, input, gradient) | ||
| !! Apply the backward pass to the flatten layer for 3D input. | ||
| !! This is a reshape operation from 1-d gradient to 3-d input. | ||
| class(flatten_layer), intent(in out) :: self | ||
| !! Flatten layer instance | ||
| real, intent(in) :: input(:,:,:) | ||
| !! Input from the previous layer | ||
| real, intent(in) :: gradient(:) | ||
| !! Gradient from the next layer | ||
| end subroutine backward | ||
| end subroutine backward_3d | ||
|
|
||
| pure module subroutine forward_2d(self, input) | ||
| !! Propagate forward the layer for 2D input. | ||
| !! Calling this subroutine updates the values of a few data components | ||
| !! of `flatten_layer` that are needed for the backward pass. | ||
| class(flatten_layer), intent(in out) :: self | ||
| !! Dense layer instance | ||
| real, intent(in) :: input(:,:) | ||
| !! Input from the previous layer | ||
| end subroutine forward_2d | ||
|
|
||
| pure module subroutine forward(self, input) | ||
| !! Propagate forward the layer. | ||
| pure module subroutine forward_3d(self, input) | ||
| !! Propagate forward the layer for 3D input. | ||
| !! Calling this subroutine updates the values of a few data components | ||
| !! of `flatten_layer` that are needed for the backward pass. | ||
| class(flatten_layer), intent(in out) :: self | ||
| !! Dense layer instance | ||
| real, intent(in) :: input(:,:,:) | ||
| !! Input from the previous layer | ||
| end subroutine forward | ||
| end subroutine forward_3d | ||
|
|
||
| module subroutine init(self, input_shape) | ||
| !! Initialize the layer data structures. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I thought about that but decided not to make the code even less SOLID
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But here we have a choice between SOLID and less boilerplate, I think I agree that the second one is better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and, most importantly for me, this approach allows for a unified API (only one
flatten()for the user).