do this properly

This commit is contained in:
Mike J Innes 2016-10-26 15:49:35 +01:00
parent d5d7242c53
commit 1a726033f4

View File

@ -85,17 +85,19 @@ unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
@net type Recurrent @net type Recurrent
Wxh; Whh; Why Wxh; Whh; Why
bh; by
hidden hidden
function (x) function (x)
hidden = σ( Wxh*x + Whh*hidden ) hidden = σ( x * Wxh + hidden * Whh + bh )
y = Why*hidden y = hidden * Why + by
end end
end end
Recurrent() = Recurrent((rand(i,i) for i = 1:4)...) Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
initn(hidden), initn(out), zeros(hidden))
# syntax(x) = syntax(Flow.dl(x), bindconst = true) # syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent() # r = Recurrent(10, 30, 20)
# unrollgraph(r,5) |> cse |> syntax |> prettify |> display # unrollgraph(r,5) |> cse |> syntax |> prettify |> display