From 89c4a6df312a37e2371fa88761c5e14aac6feda0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 29 Oct 2016 00:13:32 +0100 Subject: [PATCH] this is no longer test code --- src/compiler/loops.jl | 19 ------------------- src/layers/recurrent.jl | 13 +++++++++++++ 2 files changed, 13 insertions(+), 19 deletions(-) create mode 100644 src/layers/recurrent.jl diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 46a20152..6d3a7fba 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -83,22 +83,3 @@ end graph(u::Unrolled) = u.graph unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n) - -@net type Recurrent - Wxh; Whh; Why - bh; by - hidden - function (x) - hidden = σ( x * Wxh + hidden * Whh + bh ) - y = hidden * Why + by - end -end - -Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) = - Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)), - initn(hidden), initn(out), zeros(Float32, hidden)) - -# syntax′(x) = syntax(Flow.dl(x), bindconst = true) - -# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10)) -# unrollgraph(r,5)[1] |> syntax′ |> prettify |> clipboard diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl new file mode 100644 index 00000000..8882025b --- /dev/null +++ b/src/layers/recurrent.jl @@ -0,0 +1,13 @@ +@net type Recurrent + Wxh; Whh; Why + bh; by + hidden + function (x) + hidden = σ( x * Wxh + hidden * Whh + bh ) + y = hidden * Why + by + end +end + +Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) = + Recurrent(init((in, hidden)), init((hidden, hidden)), init((hidden, out)), + init(hidden), init(out), zeros(Float32, hidden))