do this properly

This commit is contained in:
Mike J Innes 2016-10-26 15:49:35 +01:00
parent d5d7242c53
commit 1a726033f4
1 changed files with 7 additions and 5 deletions

View File

@ -85,17 +85,19 @@ unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
@net type Recurrent
Wxh; Whh; Why
bh; by
hidden
function (x)
hidden = σ( Wxh*x + Whh*hidden )
y = Why*hidden
hidden = σ( x * Wxh + hidden * Whh + bh )
y = hidden * Why + by
end
end
Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
initn(hidden), initn(out), zeros(hidden))
# syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent()
# r = Recurrent(10, 30, 20)
# unrollgraph(r,5) |> cse |> syntax |> prettify |> display