this is no longer test code

This commit is contained in:
Mike J Innes 2016-10-29 00:13:32 +01:00
parent 65e22210f6
commit 89c4a6df31
2 changed files with 13 additions and 19 deletions

View File

@ -83,22 +83,3 @@ end
graph(u::Unrolled) = u.graph graph(u::Unrolled) = u.graph
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n) unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
@net type Recurrent
Wxh; Whh; Why
bh; by
hidden
function (x)
hidden = σ( x * Wxh + hidden * Whh + bh )
y = hidden * Why + by
end
end
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
initn(hidden), initn(out), zeros(Float32, hidden))
# syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
# unrollgraph(r,5)[1] |> syntax |> prettify |> clipboard

13
src/layers/recurrent.jl Normal file
View File

@ -0,0 +1,13 @@
@net type Recurrent
Wxh; Whh; Why
bh; by
hidden
function (x)
hidden = σ( x * Wxh + hidden * Whh + bh )
y = hidden * Why + by
end
end
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
Recurrent(init((in, hidden)), init((hidden, hidden)), init((hidden, out)),
init(hidden), init(out), zeros(Float32, hidden))