basic recurrence

This commit is contained in:
Mike J Innes 2017-09-03 02:12:44 -04:00
parent f6771b98cd
commit 9642ae8cd6
3 changed files with 69 additions and 0 deletions

View File

@ -25,5 +25,6 @@ using .Compiler: @net
include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/recurrent.jl")
end # module

64
src/layers/recurrent.jl Normal file
View File

@ -0,0 +1,64 @@
# Stateful recurrence
mutable struct Recur{T}
cell::T
state
end
Recur(m) = Recur(m, hidden(m))
function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
m.state = h
return y
end
# Vanilla RNN
struct RNNCell{D,V}
d::D
h::V
end
RNNCell(in::Integer, out::Integer, init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
function (m::RNNCell)(h, x)
h = m.d([x; h])
return h, h
end
hidden(m::RNNCell) = m.h
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end
# LSTM
struct LSTMCell{M}
Wxf::M; Wyf::M; bf::M
Wxi::M; Wyi::M; bi::M
Wxo::M; Wyo::M; bo::M
Wxc::M; Wyc::M; bc::M
hidden::M; cell::M
end
LSTMCell(in, out; init = initn) =
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
track(zeros(out, 1)), track(zeros(out, 1)))
function (m::LSTMCell)(h_, x)
h, c = h_
# Gates
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
# State update and output
c = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
c = forget .* c .+ input .* c
h = output .* tanh.(c)
return (h, c), h
end
hidden(m::LSTMCell) = (m.hidden, m.cell)

View File

@ -19,6 +19,10 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))