Skip to content

Commit 2bc92ce

Browse files
authored
fix: add Prep supertype to the public interface (#875)
* fix: add Prep supertype to the public interface * Add test * Fix prefix * Fix * Import subtypes
1 parent 7a87b5f commit 2bc92ce

File tree

9 files changed

+58
-24
lines changed

9 files changed

+58
-24
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ Apart from the conditions above, this repository follows the [ColPrac](https://g
1212
Its code is formatted using [Runic.jl](https://github.com/fredrikekre/Runic.jl).
1313
As part of continuous integration, a set of formal tests is run using [pre-commit](https://pre-commit.com/).
1414
We invite you to install pre-commit so that these checks are performed locally before you open or update a pull request.
15-
You can refer to the [dev guide](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/dev_guide/) for details on the package structure and the testing pipeline.
15+
You can refer to the relevant page of the development documentation for details on the package structure and the testing pipeline.

DifferentiationInterface/docs/make.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ makedocs(;
3535
],
3636
"FAQ" => ["faq/limitations.md", "faq/differentiability.md"],
3737
"api.md",
38-
"dev_guide.md",
38+
"Development" => [
39+
"dev/internals.md",
40+
"dev/contributing.md",
41+
],
3942
],
4043
plugins = [links],
4144
)

DifferentiationInterface/docs/src/api.md

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,8 @@ DifferentiationInterface.AutoForwardFromPrimitive
139139
DifferentiationInterface.AutoReverseFromPrimitive
140140
```
141141

142-
## Internals
142+
### Preparation type
143143

144-
The following is not part of the public API.
145-
146-
```@autodocs
147-
Modules = [DifferentiationInterface]
148-
Public = false
149-
Filter = t -> !(Symbol(t) in [:outer, :inner])
144+
```@docs
145+
DifferentiationInterface.Prep
150146
```

DifferentiationInterface/docs/src/dev_guide.md renamed to DifferentiationInterface/docs/src/dev/contributing.md

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Dev guide
1+
# Contributing
22

33
This page is important reading if you want to contribute to DifferentiationInterface.jl.
44
It is not part of the public API and the content below may become outdated, in which case you should refer to the source code as the ground truth.
@@ -7,26 +7,27 @@ It is not part of the public API and the content below may become outdated, in w
77

88
The package is structured around 8 [operators](@ref Operators):
99

10-
- [`derivative`](@ref)
11-
- [`second_derivative`](@ref)
12-
- [`gradient`](@ref)
13-
- [`jacobian`](@ref)
14-
- [`hessian`](@ref)
15-
- [`pushforward`](@ref)
16-
- [`pullback`](@ref)
17-
- [`hvp`](@ref)
10+
- [`derivative`](@ref)
11+
- [`second_derivative`](@ref)
12+
- [`gradient`](@ref)
13+
- [`jacobian`](@ref)
14+
- [`hessian`](@ref)
15+
- [`pushforward`](@ref)
16+
- [`pullback`](@ref)
17+
- [`hvp`](@ref)
1818

1919
Most operators have 4 variants, which look like this in the first order: `operator`, `operator!`, `value_and_operator`, `value_and_operator!`.
2020

2121
## New operator
2222

2323
To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above).
24-
For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).
24+
For some operators, you will also need to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).
2525

2626
The method `prepare_operator_nokwarg` must output a `prep` object of the correct type.
27-
For instance, `prepare_gradient(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
28-
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
27+
For instance, `prepare_gradient_nokwarg(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
28+
Assuming you don't need any preparation for said operator, you can use the trivial preparation types that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
2929
Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there.
30+
Take inspiration from existing operators on how to enforce the signature `SIG`.
3031

3132
## New backend
3233

@@ -36,18 +37,18 @@ Your AD package needs to be registered first.
3637
### Core code
3738

3839
In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines.
39-
You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.
40+
You also have to define [`ADTypes.mode`](@extref), [`DifferentiationInterface.check_available`](@ref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.
4041

4142
!!! info
42-
43+
4344
In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4445
However, putting it in the DifferentiationInterface.jl PR is a good first step for debugging.
4546

4647
In a [package extension](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) named `DifferentiationInterfaceSuperDiffExt`, you need to implement at least [`pushforward`](@ref) or [`pullback`](@ref) (and their variants).
4748
The exact requirements depend on the differentiation mode you chose:
4849

4950
| backend mode | pushforward necessary | pullback necessary |
50-
|:------------------------------------------------- |:--------------------- |:------------------ |
51+
| :------------------------------------------------ | :-------------------- | :----------------- |
5152
| [`ADTypes.ForwardMode`](@extref ADTypes) | yes | no |
5253
| [`ADTypes.ReverseMode`](@extref ADTypes) | no | yes |
5354
| [`ADTypes.ForwardOrReverseMode`](@extref ADTypes) | yes | yes |
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Internals
2+
3+
The following names are not part of the public API.
4+
5+
```@autodocs
6+
Modules = [DifferentiationInterface]
7+
Public = false
8+
Filter = t -> !(Symbol(t) in [:outer, :inner, :Prep, :AutoForwardFromPrimitive, :AutoReverseFromPrimitive])
9+
```

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ export AutoSparse
129129

130130
@public inner, outer
131131
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
132+
@public Prep
132133

133134
include("init.jl")
134135

DifferentiationInterface/src/utils/prep.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
Prep
3+
4+
Abstract supertype for all preparation results (outputs of `prepare_operator` functions).
5+
6+
!!! warning
7+
The public API does not make any guarantees about the type parameters or field layout of `Prep`, the only guarantee is that this type exists.
8+
"""
19
abstract type Prep{SIG} end
210

311
"""
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using DifferentiationInterface: Prep
2+
using InteractiveUtils: subtypes
3+
using Test
4+
5+
@test subtypes(Prep) == [
6+
DifferentiationInterface.DerivativePrep,
7+
DifferentiationInterface.GradientPrep,
8+
DifferentiationInterface.HVPPrep,
9+
DifferentiationInterface.HessianPrep,
10+
DifferentiationInterface.JacobianPrep,
11+
DifferentiationInterface.PullbackPrep,
12+
DifferentiationInterface.PushforwardPrep,
13+
DifferentiationInterface.SecondDerivativePrep,
14+
]

DifferentiationInterface/test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
55
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
66
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
77
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
8+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
89
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
910
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,6 +24,7 @@ ComponentArrays = "0.15.27"
2324
DataFrames = "1.7.0"
2425
Dates = "1"
2526
ExplicitImports = "1.10.1"
27+
InteractiveUtils = "1"
2628
JET = "0.9,0.10"
2729
JLArrays = "0.2.0"
2830
Pkg = "1"

0 commit comments

Comments
 (0)