do this properly
This commit is contained in:
parent
d5d7242c53
commit
1a726033f4
@ -85,17 +85,19 @@ unroll(model, n) = Unrolled(model, unrollgraph(model, n), n)
|
|||||||
|
|
||||||
@net type Recurrent
|
@net type Recurrent
|
||||||
Wxh; Whh; Why
|
Wxh; Whh; Why
|
||||||
|
bh; by
|
||||||
hidden
|
hidden
|
||||||
|
|
||||||
function (x)
|
function (x)
|
||||||
hidden = σ( Wxh*x + Whh*hidden )
|
hidden = σ( x * Wxh + hidden * Whh + bh )
|
||||||
y = Why*hidden
|
y = hidden * Why + by
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Recurrent() = Recurrent((rand(i,i) for i = 1:4)...)
|
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||||
|
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
||||||
|
initn(hidden), initn(out), zeros(hidden))
|
||||||
|
|
||||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||||
|
|
||||||
# r = Recurrent()
|
# r = Recurrent(10, 30, 20)
|
||||||
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
# unrollgraph(r,5) |> cse |> syntax′ |> prettify |> display
|
||||||
|
Loading…
Reference in New Issue
Block a user