-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnzymeRules #103
Draft
sethaxen
wants to merge
4
commits into
JuliaMath:master
Choose a base branch
from
sethaxen:enzyme
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Add EnzymeRules #103
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
module AbstractFFTsEnzymeCoreExt | ||
|
||
using AbstractFFTs | ||
using AbstractFFTs.LinearAlgebra | ||
using EnzymeCore | ||
using EnzymeCore.EnzymeRules | ||
|
||
###################### | ||
# Forward-mode rules # | ||
###################### | ||
|
||
const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}} | ||
|
||
# since FFTs are linear, implement all forward-model rules generically at a low-level | ||
|
||
function EnzymeRules.forward( | ||
func::Const{typeof(mul!)}, | ||
RT::Type{<:Const}, | ||
y::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, | ||
p::Const{<:AbstractFFTs.Plan{T}}, | ||
x::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, | ||
) where {T} | ||
val = func.val(y.val, p.val, x.val) | ||
if x isa Duplicated && y isa Duplicated | ||
dval = func.val(y.dval, p.val, x.dval) | ||
elseif x isa Duplicated && y isa Duplicated | ||
dval = map(y.dval, x.dval) do dy, dx | ||
return func.val(dy, p.val, dx) | ||
end | ||
end | ||
return nothing | ||
end | ||
|
||
function EnzymeRules.forward( | ||
func::Const{typeof(*)}, | ||
RT::Type{ | ||
<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed} | ||
}, | ||
p::Const{<:AbstractFFTs.Plan}, | ||
x::DuplicatedOrBatchDuplicated{<:StridedArray}, | ||
) | ||
RT <: Const && return func.val(p.val, x.val) | ||
if x isa Duplicated | ||
dval = func.val(p.val, x.dval) | ||
RT <: DuplicatedNoNeed && return dval | ||
val = func.val(p.val, x.val) | ||
RT <: Duplicated && return Duplicated(val, dval) | ||
else # x isa BatchDuplicated | ||
dval = map(x.dval) do dx | ||
return func.val(p.val, dx) | ||
end | ||
RT <: BatchDuplicatedNoNeed && return dval | ||
val = func.val(p.val, x.val) | ||
RT <: BatchDuplicated && return BatchDuplicated(val, dval) | ||
end | ||
end | ||
|
||
end # module |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wish the type
T
can be restricted to a finite set, e.g. BLAS number types, otherwise, it may produce incorrect gradients for user defined extensions. Generally speaking, I feel "generic" AD is not a good practise.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pushforward of a linear operator is always itself. And so far as I know, every definition of an FFT is a linear operator. So I can see no reasons why this rule should be problematic for forward-mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, I may want to extended FFT with tropical numbers, which is not a real number. It is linear, but does not have an inverse. Then your rule would give me incorrect gradients without throwing an error. I have seen too many incorrect gradients in previous AD frameworks such as Zygote when handling complex numbers.
I agree it is good to have a generic backward routine there, but please constraint the interfaces to concrete types when porting it to an AD engine. It should not be so difficult for users to extend the list of supported types in the future. Defining fft rules on BLAS types would be good enough to cover most using cases. For those non-BLAS types, honestly we can not make any assumption for them. Julia community needs an AD engine with provable correctness, I think it is also one of the goals of Enzyme.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since Julia does not have a good trait system, I think it is in general impossible to restrict users to input what the functions are designed for. This is what I meant there lacks provable correctness.
It has been a big issue that none of the Julia libraries (except Enzyme) can provide reliable gradients. They claim too much on untested using cases, like complex numbers and tropical numbers. There has been a belief that "it is cool if the code works in cases that it is not expected to work". But no, untested rules are not reliable, they can break on any future change even it works now. Rules must be concrete and tested, they are easy to extend, but hard to debug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By that argument, no AD rules should be defined here anyways, since downstream a user could define a custom Plan that doesn't do any kind of FFT at all. Then even with BLAS number types and strides arrays, any rule we write here would be wrong.
The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.
Taken to its logical conclusion, wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A big YES. I do not think many people need the backward rules for non-BLAS types. You may want to support e.g. double float that defined in
DoubleFloat.jl
. I would argue in these using cases, users can port the generic rule to the AD framework with little effort. The rule can be generic, but when porting it to the AD framework, it should be concrete.We have to decide between support more data types and ensure the correctness. I really wish there can be a trait system that user can tell the compiler "this element type is a field", then users can use the rule with more confidence. Facts obvious to you, like "fft should work on field rather than other rings" may not be obvious to others.
To differentiate a long code, I will let the code fly and see where it falls. I will add new rules to the AD engine to keep it flying. It is not a problem for me if a rule does not exist. So when using a new element type, like complex number, symbolic type, finite field algebra or the Tropical number type as mentioned above, I will probably not check whether the property of each function is as documented.
A warning will be thrown when overloading an existing function. Also, pirating is not difficult to avoid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In any case, if we have ChainRules I think we should have the corresponding EnzymeRules.
If users make the questionable choice of overriding
fft
to compute an unrelated function, then it is up to them to override the EnzymeRules/ChainRules as well.