tweaks
This commit is contained in:
parent
ce01ecb058
commit
467e829b64
|
@ -52,33 +52,24 @@ function break!(g::IVertex)
|
|||
end
|
||||
|
||||
# r = Recurrent(10, 10)
|
||||
# r = Chain(Dense(10,10), Recurrent(10,10))
|
||||
# r = Chain(Recurrent(10,10), Dense(10,10))
|
||||
# r = Chain(Recurrent(10,10),Recurrent(10,10))
|
||||
|
||||
#
|
||||
# atomise(r) |> syntax |> prettify |> display
|
||||
#
|
||||
# break!(atomise(r)) |> syntax |> prettify |> display
|
||||
|
||||
# @model type Recurrent
|
||||
# Wx; Wh; B
|
||||
# hidden
|
||||
#
|
||||
# function (x)
|
||||
# hidden = σ( Wx*x + Wh*hidden + B )
|
||||
# end
|
||||
# end
|
||||
#
|
||||
# Recurrent(in::Integer, out::Integer; init = initn) =
|
||||
# Recurrent(init(out, in), init(out, out), init(out), zeros(out))
|
||||
|
||||
@model type Recurrent
|
||||
model
|
||||
Wx; Wh; B
|
||||
hidden
|
||||
|
||||
function (x)
|
||||
hidden = σ(model(vcat(x, hidden)))
|
||||
hidden = σ( Wx*x + Wh*hidden + B )
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in::Integer, out::Integer; init = initn) =
|
||||
Recurrent(Dense(in + out, out, init = init), zeros(out))
|
||||
Recurrent(init(out, in), init(out, out), init(out), zeros(out))
|
||||
|
||||
Base.show(io::IO, r::Recurrent) =
|
||||
print(io, "Flux.Recurrent(...)")
|
||||
|
|
Loading…
Reference in New Issue