Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: handle zero_tangent from cyclic data structures v2 Via premapping #655

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,38 @@ end

@generated function zero_tangent(primal)
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero.
zfield_exprs = map(fieldnames(primal)) do fname
fval = :(
if isdefined(primal, $(QuoteNode(fname)))
zero_tangent(getfield(primal, $(QuoteNode(fname))))
else
# This is going to be potentially bad, but that's what they get for not giving us a primal
# This will never me mutated inplace, rather it will alway be replaced with an actual value first
ZeroTangent()
end
)
Expr(:kw, fname, fval)

# easy case exit early, can't hold references, can't be a reference.
if isbitstype(primal)
zfield_exprs = map(fieldnames(primal)) do fname
fval = :(zero_tangent(getfield(primal, $(QuoteNode(fname)))))
Expr(:kw, fname, fval)
end
return :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...))))
end
return if has_mutable_tangent(primal)
any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent
fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype))

# hard case need to be prepared for references to this, or that are contained within this
quote
counts = $count_references(primal)
any_mask = $(Expr(:tuple, Expr(:parameters, map(fieldnames(primal), fieldtypes(primal)) do fname, ftype
# If it is is unassigned, or if it doesn't have a concrete type, or we have multiple reference to it
# then let it take any value for its tangent
fdef = :(
!isdefined(primal, $(QuoteNode(fname))) ||
!isconcretetype($ftype) ||
get(counts, $(QuoteNode(fname)), 0) > 1
)
Expr(:kw, fname, fdef)
end
end...)))

# Construct tangents

# Go back and fill in tangents that were not ready
end

## TODO rewrite below
has_mutable_tangent(primal)
any_mask =
:($MutableTangent{$primal}(
$(Expr(:tuple, Expr(:parameters, any_mask...))),
$(Expr(:tuple, Expr(:parameters, zfield_exprs...))),
Expand Down Expand Up @@ -171,6 +185,36 @@ function zero_tangent(x::Array{P,N}) where {P,N}
return y
end

###############################################
count_references(x) = count_references(IdDict{Any, Int}(), x)
function count_references!(counts::IdDict{Any, Int}, x)
isbits(x) && return counts # can't be a refernece and can't hold a reference
counts[x] = get(counts, x, 0) + 1 # Increment *before* recursing
if counts[x] == 1 # Only recurse the first time
for ii in fieldcount(typeof(x))
field = getfield(x, ii)
count_references!(counts, field)
end
end
return counts
end

function count_references!(counts::IdDict{Any, Int}, x::Array)
counts[x] = get(counts, x, 0) + 1 # increment before recursing
isbitstype(eltype(x)) && return counts # no need to look inside, it can't hold references
if counts[x] == 1 # only recurse the first time
for ele in x
count_references!(counts, ele)
end
end
return counts
end

count_references!(counts::IdDict{Any, Int}, ::DataType) = counts

###############################################


# Sad heauristic methods we need because of unassigned values
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
Expand Down
18 changes: 9 additions & 9 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,24 +303,24 @@ end
lk = Link(1.5)
lk.next = lk

@test_broken d = zero_tangent(lk)
@test_broken d.data == 0.0
@test_broken d.next === d
d = zero_tangent(lk)
@test d.data == 0.0
@test d.next === d

struct CarryingArray
x::Vector
end
ca = CarryingArray(Any[1.5])
push!(ca.x, ca)
@test_broken d_ca = zero_tangent(ca)
@test_broken d_ca[1] == 0.0
@test_broken d_ca[2] === _ca
@test d_ca = zero_tangent(ca)
@test d_ca[1] == 0.0
@test d_ca[2] === _ca

# Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing
xs = Any[1.5]
push!(xs, xs)
@test_broken d_xs = zero_tangent(xs)
@test_broken d_xs[1] == 0.0
@test_broken d_xs[2] == d_xs
@test d_xs = zero_tangent(xs)
@test d_xs[1] == 0.0
@test d_xs[2] == d_xs
end
end
Loading