diff --git a/Project.toml b/Project.toml
index 5187a82..a2e4518 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
 version = "0.1.2"
 
 [deps]
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 91438b0..893dc46 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -13,4 +13,6 @@ include("chain.jl")
 
 include("compact.jl")
 
+include("new_recur.jl")
+
 end # module Fluxperimental
diff --git a/src/new_recur.jl b/src/new_recur.jl
new file mode 100644
index 0000000..824644f
--- /dev/null
+++ b/src/new_recur.jl
@@ -0,0 +1,140 @@
+import Flux: ChainRulesCore
+import Compat: stack
+
+##### Helper scan function which can likely be put into NNlib. #####
+"""
+    scan_full
+
+Recreates the functionality of `jax.lax.scan` in Julia. Takes a function, an initial carry, and a sequence, then returns the full output of the sequence together with the final carry. See `scan_partial` to return only the final output of the sequence.
+"""
+function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
+  # Recurrence operation used in the fold. Takes the state of the
+  # fold and the next input, returns the new state.
+  function recurrence_op((carry, outputs), input)
+    carry, out = func(carry, input)
+    return carry, vcat(outputs, [out])
+  end
+  # Fold left to right.
+  return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
+end
+
+function scan_full(func, init_carry, x_block)
+  # x_block is an AbstractArray, and we want to scan over its last dimension.
+  xs_ = Flux.eachlastdim(x_block)
+
+  # This is needed due to a bug in eachlastdim, which produces a vector in a
+  # gradient context but a generator otherwise.
+  xs = if xs_ isa Base.Generator
+    collect(xs_)  # eachlastdim produces a generator in a non-gradient environment
+  else
+    xs_
+  end
+  scan_full(func, init_carry, xs)
+end
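+
+# A rough usage sketch of `scan_full` (comments only, added for illustration; the
+# `double_step` function below is hypothetical): `func` mirrors a Flux cell, taking
+# `(carry, input)` and returning `(new_carry, output)`.
+#
+#   double_step(carry, x) = (carry .+ x, 2 .* carry .+ x)
+#   carry, ys = scan_full(double_step, [0.0], [[1.0], [2.0], [3.0]])
+#   # carry == [6.0], ys == [[1.0], [4.0], [9.0]]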
+""" +function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) + x_init, x_rest = Iterators.peel(xs) + (carry, y) = func(init_carry, x_init) + for x in x_rest + (carry, y) = func(carry, x) + end + carry, y +end + +function scan_partial(func, init_carry, x_block) + # x_block is an abstractarray and we want to scan over the last dimension. + xs_ = Flux.eachlastdim(x_block) + + # this is needed due to a bug in eachlastdim which produces a vector in a + # gradient context, but a generator otherwise. + xs = if xs_ isa Base.Generator + collect(xs_) # eachlastdim produces a generator in non-gradient environment + else + xs_ + end + scan_partial(func, init_carry, xs) +end + + +""" + NewRecur +New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This structure has no internal state, and instead returns: + +```julia +l = NewRNN(1,2) +xs # Some input array Input x BatchSize x Time +init_carry # the initial carry of the cell. +l(xs) # -> returns the output of the RNN, uses cell.state0 as init_carry. +l(init_carry, xs) # -> returns (final_carry, output), where the size ofoutput is determined by RET_SEQUENCE. +``` +""" +struct NewRecur{RET_SEQUENCE, T} + cell::T + # state::S + function NewRecur(cell; return_sequence::Bool=false) + new{return_sequence, typeof(cell)}(cell) + end + function NewRecur{true}(cell) + new{true, typeof(cell)}(cell) + end + function NewRecur{false}(cell) + new{false, typeof(cell)}(cell) + end +end + +Flux.@functor NewRecur +Flux.trainable(a::NewRecur) = (; cell = a.cell) +Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") +NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) + +(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur") +(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur") + +function (l::NewRecur)(xs::AbstractArray) + results = l(l.cell.state0, xs) + results[2] # Only return the output here. +end + +function (l::NewRecur{false})(init_carry, xs) + results = scan_partial(l.cell, init_carry, xs) + results[1], results[2] +end + +function (l::NewRecur{true})(init_carry, xs) + results = scan_full(l.cell, init_carry, xs) + results[1], stack(results[2], dims=3) +end diff --git a/test/new_recur.jl b/test/new_recur.jl new file mode 100644 index 0000000..cb5cf2a --- /dev/null +++ b/test/new_recur.jl @@ -0,0 +1,188 @@ +@testset "NewRecur RNN" begin + @testset "Forward Pass" begin + # tanh is needed for forward check to determine ordering of inputs. + cell = Flux.RNNCell(1, 1, tanh) + layer = Fluxperimental.NewRecur(cell; return_sequence=true) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = reshape([2.0f0, 3.0f0], 1, 1, 2) + + # Lets make sure th output is correct + h = cell.state0 + h, out = cell(h, [2.0f0]) + h, out = cell(h, [3.0f0]) + + @test eltype(layer(x)) <: Float32 + @test size(layer(x)) == (1, 1, 2) + @test layer(x)[1, 1, 2] ≈ out[1,1] + + @test length(layer(cell.state0, x)) == 2 # should return a tuple. Maybe better test is needed. 
diff --git a/test/new_recur.jl b/test/new_recur.jl
new file mode 100644
index 0000000..cb5cf2a
--- /dev/null
+++ b/test/new_recur.jl
@@ -0,0 +1,188 @@
+@testset "NewRecur RNN" begin
+  @testset "Forward Pass" begin
+    # tanh is needed for the forward check to determine the ordering of inputs.
+    cell = Flux.RNNCell(1, 1, tanh)
+    layer = Fluxperimental.NewRecur(cell; return_sequence=true)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = reshape([2.0f0, 3.0f0], 1, 1, 2)
+
+    # Let's make sure the output is correct.
+    h = cell.state0
+    h, out = cell(h, [2.0f0])
+    h, out = cell(h, [3.0f0])
+
+    @test eltype(layer(x)) <: Float32
+    @test size(layer(x)) == (1, 1, 2)
+    @test layer(x)[1, 1, 2] ≈ out[1, 1]
+
+    @test length(layer(cell.state0, x)) == 2  # should return a tuple. Maybe a better test is needed.
+    @test layer(cell.state0, x)[2][1, 1, 2] ≈ out[1, 1]
+
+    @test_throws MethodError layer([2.0f0])
+    @test_throws MethodError layer([2.0f0;; 3.0f0])
+  end
+
+  @testset "gradients-implicit" begin
+    cell = Flux.RNNCell(1, 1, identity)
+    layer = Flux.Recur(cell)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal gradients
+    primal =
+      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+      x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
+    ps = Flux.params(nm_layer)
+    x_block = reshape(vcat(x...), 1, 1, length(x))
+    e, g = Flux.withgradient(ps) do
+      out = nm_layer(x_block)
+      sum(out[1, 1, 2])
+    end
+
+    @test primal[1] ≈ e
+    @test ∇Wi ≈ g[ps[1]]
+    @test ∇Wh ≈ g[ps[2]]
+    @test ∇b ≈ g[ps[3]]
+    @test ∇state0 ≈ g[ps[4]]
+  end
+
+  @testset "gradients-explicit" begin
+    cell = Flux.RNNCell(1, 1, identity)
+    layer = Flux.Recur(cell)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal gradients
+    primal =
+      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+      x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    x_block = reshape(vcat(x...), 1, 1, length(x))
+    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
+    e, g = Flux.withgradient(nm_layer) do layer
+      out = layer(x_block)
+      sum(out[1, 1, 2])
+    end
+    grads = g[1][:cell]
+
+    @test primal[1] ≈ e
+    @test ∇Wi ≈ grads[:Wi]
+    @test ∇Wh ≈ grads[:Wh]
+    @test ∇b ≈ grads[:b]
+    @test ∇state0 ≈ grads[:state0]
+  end
+end
+
+@testset "New Recur RNN Partial Sequence" begin
+  @testset "Forward Pass" begin
+    cell = Flux.RNNCell(1, 1, identity)
+    layer = Fluxperimental.NewRecur(cell)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = reshape([2.0f0, 3.0f0], 1, 1, 2)
+
+    h = cell.state0
+    h, out = cell(h, [2.0f0])
+    h, out = cell(h, [3.0f0])
+
+    @test eltype(layer(x)) <: Float32
+    @test size(layer(x)) == (1, 1)
+    @test layer(x)[1, 1] ≈ out[1, 1]
+
+    @test length(layer(cell.state0, x)) == 2
+    @test layer(cell.state0, x)[2][1, 1] ≈ out[1, 1]
+
+    @test_throws MethodError layer([2.0f0])
+    @test_throws MethodError layer([2.0f0;; 3.0f0])
+  end
+
+  @testset "gradients-implicit" begin
+    cell = Flux.RNNCell(1, 1, identity)
+    layer = Flux.Recur(cell)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal gradients
+    primal =
+      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+      x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
+    ps = Flux.params(nm_layer)
+    x_block = reshape(vcat(x...), 1, 1, length(x))
+    e, g = Flux.withgradient(ps) do
+      out = nm_layer(x_block)
+      sum(out)
+    end
+
+    @test primal[1] ≈ e
+    @test ∇Wi ≈ g[ps[1]]
+    @test ∇Wh ≈ g[ps[2]]
+    @test ∇b ≈ g[ps[3]]
+    @test ∇state0 ≈ g[ps[4]]
+  end
+
+  @testset "gradients-explicit" begin
+    cell = Flux.RNNCell(1, 1, identity)
+    layer = Flux.Recur(cell)
+    layer.cell.Wi .= 5.0
+    layer.cell.Wh .= 4.0
+    layer.cell.b .= 0.0f0
+    layer.cell.state0 .= 7.0
+    x = [[2.0f0], [3.0f0]]
+
+    # theoretical primal gradients
+    primal =
+      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
+      x[2] .* layer.cell.Wi
+    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
+    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
+    ∇b = layer.cell.Wh .+ 1
+    ∇state0 = layer.cell.Wh .^ 2
+
+    x_block = reshape(vcat(x...), 1, 1, length(x))
+    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
+    e, g = Flux.withgradient(nm_layer) do layer
+      out = layer(x_block)
+      sum(out)
+    end
+    grads = g[1][:cell]
+
+    @test primal[1] ≈ e
+    @test ∇Wi ≈ grads[:Wi]
+    @test ∇Wh ≈ grads[:Wh]
+    @test ∇b ≈ grads[:b]
+    @test ∇state0 ≈ grads[:state0]
+  end
+end
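+
+# Added illustration (not part of the original coverage above; `cell_fn` is a
+# hand-written, hypothetical cell-like function): `scan_partial` should agree with
+# the final step of `scan_full`.
+@testset "scan_full/scan_partial consistency" begin
+  cell_fn(carry, x) = (carry .+ x, carry .* x)
+  xs = [[1.0f0], [2.0f0], [3.0f0]]
+  carry_full, ys = Fluxperimental.scan_full(cell_fn, [0.0f0], xs)
+  carry_last, y_last = Fluxperimental.scan_partial(cell_fn, [0.0f0], xs)
+  @test carry_full ≈ carry_last
+  @test ys[end] ≈ y_last
+end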
diff --git a/test/runtests.jl b/test/runtests.jl
index 55315cc..5291a8a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,4 +8,6 @@ using Flux, Fluxperimental
 
   include("compact.jl")
 
+  include("new_recur.jl")
+
 end