constructors

This commit is contained in:
Mike J Innes 2017-09-05 19:25:34 -04:00
parent b023da1b7d
commit c95e9376a5
2 changed files with 13 additions and 2 deletions

View File

@ -7,7 +7,7 @@ module Flux
using Juno
using Lazy: @forward
export Chain, Dense,
export Chain, Dense, RNN, LSTM,
SGD
using NNlib

View File

@ -16,6 +16,8 @@ function (m::Recur)(xs...)
return y
end
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
# Vanilla RNN
struct RNNCell{D,V}
@ -23,7 +25,7 @@ struct RNNCell{D,V}
h::V
end
RNNCell(in::Integer, out::Integer, init = initn) =
RNNCell(in::Integer, out::Integer; init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
function (m::RNNCell)(h, x)
@ -37,6 +39,8 @@ function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM
struct LSTMCell{D1,D2,V}
@ -66,3 +70,10 @@ function (m::LSTMCell)(h_, x)
end
hidden(m::LSTMCell) = (m.h, m.c)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))