more interesting recurrent model
This commit is contained in:
parent
652c26728e
commit
2a58b23085
|
@ -71,21 +71,19 @@ function unroll(model, n)
|
|||
@> group(state, group(outputs...)) detuple
|
||||
end
|
||||
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
# r = Recurrent(10,10)
|
||||
# unroll(r,5) |> cse |> syntax′ |> prettify |> display
|
||||
|
||||
@net type Recurrent
|
||||
Wx; Wh; B
|
||||
Wxh; Whh; Why
|
||||
hidden
|
||||
|
||||
function (x)
|
||||
hidden = σ( Wx*x + Wh*hidden + B )
|
||||
hidden = σ( Wxh*x + Whh*hidden )
|
||||
y = Why*hidden
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in::Integer, out::Integer; init = initn) =
|
||||
Recurrent(init(out, in), init(out, out), init(out), zeros(out))
|
||||
Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
|
||||
|
||||
Base.show(io::IO, r::Recurrent) =
|
||||
print(io, "Flux.Recurrent(...)")
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
|
||||
# r = Recurrent()
|
||||
# unroll(r,10) |> cse |> syntax′ |> prettify |> display
|
||||
|
|
Loading…
Reference in New Issue