constructors

This commit is contained in:
Mike J Innes 2017-09-05 19:25:34 -04:00
parent b023da1b7d
commit c95e9376a5
2 changed files with 13 additions and 2 deletions

View File

@ -7,7 +7,7 @@ module Flux
using Juno using Juno
using Lazy: @forward using Lazy: @forward
export Chain, Dense, export Chain, Dense, RNN, LSTM,
SGD SGD
using NNlib using NNlib

View File

@ -16,6 +16,8 @@ function (m::Recur)(xs...)
return y return y
end end
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
# Vanilla RNN # Vanilla RNN
struct RNNCell{D,V} struct RNNCell{D,V}
@ -23,7 +25,7 @@ struct RNNCell{D,V}
h::V h::V
end end
RNNCell(in::Integer, out::Integer, init = initn) = RNNCell(in::Integer, out::Integer; init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out))) RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
function (m::RNNCell)(h, x) function (m::RNNCell)(h, x)
@ -37,6 +39,8 @@ function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")") print(io, "RNNCell(", m.d, ")")
end end
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
# LSTM # LSTM
struct LSTMCell{D1,D2,V} struct LSTMCell{D1,D2,V}
@ -66,3 +70,10 @@ function (m::LSTMCell)(h_, x)
end end
hidden(m::LSTMCell) = (m.h, m.c) hidden(m::LSTMCell) = (m.h, m.c)
Base.show(io::IO, m::LSTMCell) =
print(io, "LSTMCell(",
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
size(m.forget.W, 1), ')')
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))