diff --git a/src/Flux.jl b/src/Flux.jl index 1103fdbb..7397ec19 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -25,5 +25,6 @@ using .Compiler: @net include("layers/stateless.jl") include("layers/basic.jl") +include("layers/recurrent.jl") end # module diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl new file mode 100644 index 00000000..db2d3fc5 --- /dev/null +++ b/src/layers/recurrent.jl @@ -0,0 +1,64 @@ +# Stateful recurrence + +mutable struct Recur{T} + cell::T + state +end + +Recur(m) = Recur(m, hidden(m)) + +function (m::Recur)(xs...) + h, y = m.cell(m.state, xs...) + m.state = h + return y +end + +# Vanilla RNN + +struct RNNCell{D,V} + d::D + h::V +end + +RNNCell(in::Integer, out::Integer, init = initn) = + RNNCell(Dense(in+out, out, init = initn), track(initn(out))) + +function (m::RNNCell)(h, x) + h = m.d([x; h]) + return h, h +end + +hidden(m::RNNCell) = m.h + +function Base.show(io::IO, m::RNNCell) + print(io, "RNNCell(", m.d, ")") +end + +# LSTM + +struct LSTMCell{M} + Wxf::M; Wyf::M; bf::M + Wxi::M; Wyi::M; bi::M + Wxo::M; Wyo::M; bo::M + Wxc::M; Wyc::M; bc::M + hidden::M; cell::M +end + +LSTMCell(in, out; init = initn) = + LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))..., + track(zeros(out, 1)), track(zeros(out, 1))) + +function (m::LSTMCell)(h_, x) + h, c = h_ + # Gates + forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf ) + input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi ) + output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo ) + # State update and output + c′ = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc ) + c = forget .* c .+ input .* c′ + h = output .* tanh.(c) + return (h, c), h +end + +hidden(m::LSTMCell) = (m.hidden, m.cell) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 37181bab..796b34e9 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -19,6 +19,10 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ) Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs)) Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) +Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))