|
94 | 94 | coeffs::C
|
95 | 95 | TaylorTangent(coeffs) = $(Expr(:new, :(TaylorTangent{typeof(coeffs)}), :coeffs))
|
96 | 96 | end
|
| 97 | +Base.:(==)(a::TaylorTangent, b::TaylorTangent) = a.coeffs == b.coeffs |
| 98 | +Base.hash(tt::TaylorTangent, h::UInt64) = hash(tt.coeffs, h) |
97 | 99 |
|
98 | 100 | """
|
99 | 101 | struct TaylorTangent{C}
|
@@ -159,6 +161,9 @@ TangentBundle
|
159 | 161 | TangentBundle{N}(primal::B, tangent::P) where {N, B, P<:AbstractTangentSpace} =
|
160 | 162 | _TangentBundle(Val{N}(), primal, tangent)
|
161 | 163 |
|
| 164 | +Base.hash(tb::TangentBundle, h::UInt64) = hash(tb.primal, h) |
| 165 | +Base.:(==)(a::TangentBundle, b::TangentBundle) = (a.primal == b.primal) && (a.tangent == b.tangent) |
| 166 | + |
162 | 167 | const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
|
163 | 168 |
|
164 | 169 | check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
|
@@ -201,20 +206,25 @@ end
|
201 | 206 |
|
202 | 207 | const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
|
203 | 208 |
|
| 209 | + |
| 210 | +function TaylorBundle{N, B, P}(primal::B, coeffs::P) where {N, B, P} |
| 211 | + check_taylor_invariants(coeffs, primal, N) |
| 212 | + _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) |
| 213 | +end |
204 | 214 | function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
|
205 | 215 | check_taylor_invariants(coeffs, primal, N)
|
206 | 216 | _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
|
207 | 217 | end
|
| 218 | +function TaylorBundle{N}(primal, coeffs) where {N} |
| 219 | + check_taylor_invariants(coeffs, primal, N) |
| 220 | + _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) |
| 221 | +end |
208 | 222 |
|
209 | 223 | function check_taylor_invariants(coeffs, primal, N)
|
210 | 224 | @assert length(coeffs) == N
|
211 |
| - |
212 | 225 | end
|
213 | 226 | @ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
|
214 | 227 |
|
215 |
| -function TaylorBundle{N}(primal, coeffs) where {N} |
216 |
| - _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs)) |
217 |
| -end |
218 | 228 |
|
219 | 229 | function Base.show(io::IO, x::TaylorBundle{1})
|
220 | 230 | print(io, x.primal)
|
@@ -350,7 +360,7 @@ function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle)
|
350 | 360 | unbundle(atb), Δ->throw(Δ)
|
351 | 361 | end
|
352 | 362 |
|
353 |
| -function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...) |
| 363 | +function StructArrays.createinstance(T::Type{<:UniformBundle}, args...) |
354 | 364 | T(args[1], args[2])
|
355 | 365 | end
|
356 | 366 |
|
|
0 commit comments