this is no longer test code
This commit is contained in:
parent
65e22210f6
commit
89c4a6df31
@ -83,22 +83,3 @@ end
|
|||||||
graph(u::Unrolled) = u.graph
|
graph(u::Unrolled) = u.graph
|
||||||
|
|
||||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
|
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
|
||||||
|
|
||||||
@net type Recurrent
|
|
||||||
Wxh; Whh; Why
|
|
||||||
bh; by
|
|
||||||
hidden
|
|
||||||
function (x)
|
|
||||||
hidden = σ( x * Wxh + hidden * Whh + bh )
|
|
||||||
y = hidden * Why + by
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
|
||||||
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
|
||||||
initn(hidden), initn(out), zeros(Float32, hidden))
|
|
||||||
|
|
||||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
|
||||||
|
|
||||||
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
|
|
||||||
# unrollgraph(r,5)[1] |> syntax′ |> prettify |> clipboard
|
|
||||||
|
13
src/layers/recurrent.jl
Normal file
13
src/layers/recurrent.jl
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
@net type Recurrent
|
||||||
|
Wxh; Whh; Why
|
||||||
|
bh; by
|
||||||
|
hidden
|
||||||
|
function (x)
|
||||||
|
hidden = σ( x * Wxh + hidden * Whh + bh )
|
||||||
|
y = hidden * Why + by
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||||
|
Recurrent(init((in, hidden)), init((hidden, hidden)), init((hidden, out)),
|
||||||
|
init(hidden), init(out), zeros(Float32, hidden))
|
Loading…
Reference in New Issue
Block a user