From c95e9376a5a7df7bca4aa5b22130c9fde365db06 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Tue, 5 Sep 2017 19:25:34 -0400
Subject: [PATCH] constructors

---
Review note: RNNCell now takes `init` as a keyword, but the unchanged body
line still passes `initn` to Dense and to the initial hidden state, so a
caller-supplied `init` has no effect in this revision.

 src/Flux.jl             |  2 +-
 src/layers/recurrent.jl | 13 ++++++++++++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 206a58fa..1b4cbbc7 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ module Flux
 using Juno
 using Lazy: @forward
 
-export Chain, Dense,
+export Chain, Dense, RNN, LSTM,
   SGD
 
 using NNlib
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index c067a302..1ac7f9f5 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -16,6 +16,8 @@ function (m::Recur)(xs...)
   return y
 end
 
+Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
+
 # Vanilla RNN
 
 struct RNNCell{D,V}
@@ -23,7 +25,7 @@ struct RNNCell{D,V}
   h::V
 end
 
-RNNCell(in::Integer, out::Integer, init = initn) =
+RNNCell(in::Integer, out::Integer; init = initn) =
   RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
 
 function (m::RNNCell)(h, x)
@@ -37,6 +39,8 @@ function Base.show(io::IO, m::RNNCell)
   print(io, "RNNCell(", m.d, ")")
 end
 
+RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
+
 # LSTM
 
 struct LSTMCell{D1,D2,V}
@@ -66,3 +70,10 @@ function (m::LSTMCell)(h_, x)
 end
 
 hidden(m::LSTMCell) = (m.h, m.c)
+
+Base.show(io::IO, m::LSTMCell) =
+  print(io, "LSTMCell(",
+        size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
+        size(m.forget.W, 1), ')')
+
+LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))