do this properly
This commit is contained in:
parent
d5d7242c53
commit
1a726033f4
|
@ -85,17 +85,19 @@ unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
|||
|
||||
@net type Recurrent
|
||||
Wxh; Whh; Why
|
||||
bh; by
|
||||
hidden
|
||||
|
||||
function (x)
|
||||
hidden = σ( Wxh*x + Whh*hidden )
|
||||
y = Why*hidden
|
||||
hidden = σ( x * Wxh + hidden * Whh + bh )
|
||||
y = hidden * Why + by
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
|
||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
||||
initn(hidden), initn(out), zeros(hidden))
|
||||
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
|
||||
# r = Recurrent()
|
||||
# r = Recurrent(10, 30, 20)
|
||||
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
||||
|
|
Loading…
Reference in New Issue