constructors
This commit is contained in:
parent
b023da1b7d
commit
c95e9376a5
@ -7,7 +7,7 @@ module Flux
|
||||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense,
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD
|
||||
|
||||
using NNlib
|
||||
|
@ -16,6 +16,8 @@ function (m::Recur)(xs...)
|
||||
return y
|
||||
end
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
struct RNNCell{D,V}
|
||||
@ -23,7 +25,7 @@ struct RNNCell{D,V}
|
||||
h::V
|
||||
end
|
||||
|
||||
RNNCell(in::Integer, out::Integer, init = initn) =
|
||||
RNNCell(in::Integer, out::Integer; init = initn) =
|
||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
@ -37,6 +39,8 @@ function Base.show(io::IO, m::RNNCell)
|
||||
print(io, "RNNCell(", m.d, ")")
|
||||
end
|
||||
|
||||
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
||||
|
||||
# LSTM
|
||||
|
||||
struct LSTMCell{D1,D2,V}
|
||||
@ -66,3 +70,10 @@ function (m::LSTMCell)(h_, x)
|
||||
end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
Base.show(io::IO, m::LSTMCell) =
|
||||
print(io, "LSTMCell(",
|
||||
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
||||
size(m.forget.W, 1), ')')
|
||||
|
||||
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||
|
Loading…
Reference in New Issue
Block a user