diff --git a/NEWS.md b/NEWS.md index 87333f8717..882c3f8224 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,10 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.14 +* New layer `RepeatVector` which works like + RepeatVector in keras + ## v0.14.13 * New macro `Flux.@layer` which should be used in place of `@functor`. This also adds `show` methods for pretty printing. diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ef81c30872..539d4e71db 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -864,3 +864,33 @@ EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean) function Base.show(io::IO, m::EmbeddingBag) print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") end + + +""" + RepeatVector(n::Int) + +Repeat the input `n` times along the last dimension. + +# Examples +```jldoctest +julia> rv = RepeatVector(3) +julia> rv([1, 2, 3]) +3×3 Matrix{Int64}: + 1 1 1 + 2 2 2 + 3 3 3 +``` +""" +struct RepeatVector + n::Int +end + +@layer RepeatVector + +function (rv::RepeatVector)(x::AbstractArray{T}) where {T} + expanded = reshape(x, (size(x)..., 1)) + repeated = repeat(expanded, outer = (1, rv.n, 1)) + return repeated +end + +Base.show(io::IO, rv::RepeatVector) = print(io, "RepeatVector($(rv.n))") \ No newline at end of file