constructors
This commit is contained in:
parent
b023da1b7d
commit
c95e9376a5
@ -7,7 +7,7 @@ module Flux
|
|||||||
using Juno
|
using Juno
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Dense,
|
export Chain, Dense, RNN, LSTM,
|
||||||
SGD
|
SGD
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
|
@ -16,6 +16,8 @@ function (m::Recur)(xs...)
|
|||||||
return y
|
return y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
# Vanilla RNN
|
# Vanilla RNN
|
||||||
|
|
||||||
struct RNNCell{D,V}
|
struct RNNCell{D,V}
|
||||||
@ -23,7 +25,7 @@ struct RNNCell{D,V}
|
|||||||
h::V
|
h::V
|
||||||
end
|
end
|
||||||
|
|
||||||
RNNCell(in::Integer, out::Integer, init = initn) =
|
RNNCell(in::Integer, out::Integer; init = initn) =
|
||||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
@ -37,6 +39,8 @@ function Base.show(io::IO, m::RNNCell)
|
|||||||
print(io, "RNNCell(", m.d, ")")
|
print(io, "RNNCell(", m.d, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
|
||||||
|
|
||||||
# LSTM
|
# LSTM
|
||||||
|
|
||||||
struct LSTMCell{D1,D2,V}
|
struct LSTMCell{D1,D2,V}
|
||||||
@ -66,3 +70,10 @@ function (m::LSTMCell)(h_, x)
|
|||||||
end
|
end
|
||||||
|
|
||||||
hidden(m::LSTMCell) = (m.h, m.c)
|
hidden(m::LSTMCell) = (m.h, m.c)
|
||||||
|
|
||||||
|
Base.show(io::IO, m::LSTMCell) =
|
||||||
|
print(io, "LSTMCell(",
|
||||||
|
size(m.forget.W, 2) - size(m.forget.W, 1), ", ",
|
||||||
|
size(m.forget.W, 1), ')')
|
||||||
|
|
||||||
|
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
|
||||||
|
Loading…
Reference in New Issue
Block a user