basic recurrence
This commit is contained in:
parent
f6771b98cd
commit
9642ae8cd6
@ -25,5 +25,6 @@ using .Compiler: @net
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
include("layers/recurrent.jl")
|
||||
|
||||
end # module
|
||||
|
64
src/layers/recurrent.jl
Normal file
64
src/layers/recurrent.jl
Normal file
@ -0,0 +1,64 @@
|
||||
# Stateful recurrence
|
||||
|
||||
mutable struct Recur{T}
|
||||
cell::T
|
||||
state
|
||||
end
|
||||
|
||||
Recur(m) = Recur(m, hidden(m))
|
||||
|
||||
function (m::Recur)(xs...)
|
||||
h, y = m.cell(m.state, xs...)
|
||||
m.state = h
|
||||
return y
|
||||
end
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
struct RNNCell{D,V}
|
||||
d::D
|
||||
h::V
|
||||
end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, init = initn) =
|
||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
h = m.d([x; h])
|
||||
return h, h
|
||||
end
|
||||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
function Base.show(io::IO, m::RNNCell)
|
||||
print(io, "RNNCell(", m.d, ")")
|
||||
end
|
||||
|
||||
# LSTM
|
||||
|
||||
struct LSTMCell{M}
|
||||
Wxf::M; Wyf::M; bf::M
|
||||
Wxi::M; Wyi::M; bi::M
|
||||
Wxo::M; Wyo::M; bo::M
|
||||
Wxc::M; Wyc::M; bc::M
|
||||
hidden::M; cell::M
|
||||
end
|
||||
|
||||
LSTMCell(in, out; init = initn) =
|
||||
LSTMCell(track.(vcat([[init(out, in), init(out, out), init(out, 1)] for _ = 1:4]...))...,
|
||||
track(zeros(out, 1)), track(zeros(out, 1)))
|
||||
|
||||
function (m::LSTMCell)(h_, x)
|
||||
h, c = h_
|
||||
# Gates
|
||||
forget = σ.( m.Wxf * x .+ m.Wyf * h .+ m.bf )
|
||||
input = σ.( m.Wxi * x .+ m.Wyi * h .+ m.bi )
|
||||
output = σ.( m.Wxo * x .+ m.Wyo * h .+ m.bo )
|
||||
# State update and output
|
||||
c′ = tanh.( m.Wxc * x .+ m.Wyc * h .+ m.bc )
|
||||
c = forget .* c .+ input .* c′
|
||||
h = output .* tanh.(c)
|
||||
return (h, c), h
|
||||
end
|
||||
|
||||
hidden(m::LSTMCell) = (m.hidden, m.cell)
|
@ -19,6 +19,10 @@ back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
|
||||
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
|
||||
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
|
||||
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
|
Loading…
Reference in New Issue
Block a user