this is no longer test code
This commit is contained in:
parent
65e22210f6
commit
89c4a6df31
@ -83,22 +83,3 @@ end
|
||||
graph(u::Unrolled) = u.graph
|
||||
|
||||
unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
|
||||
|
||||
@net type Recurrent
|
||||
Wxh; Whh; Why
|
||||
bh; by
|
||||
hidden
|
||||
function (x)
|
||||
hidden = σ( x * Wxh + hidden * Whh + bh )
|
||||
y = hidden * Why + by
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||
Recurrent(initn((in, hidden)), initn((hidden, hidden)), initn((hidden, out)),
|
||||
initn(hidden), initn(out), zeros(Float32, hidden))
|
||||
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
|
||||
# r = Chain(Recurrent(10, 30, 20), Recurrent(20, 40, 10))
|
||||
# unrollgraph(r,5)[1] |> syntax′ |> prettify |> clipboard
|
||||
|
13
src/layers/recurrent.jl
Normal file
13
src/layers/recurrent.jl
Normal file
@ -0,0 +1,13 @@
|
||||
@net type Recurrent
|
||||
Wxh; Whh; Why
|
||||
bh; by
|
||||
hidden
|
||||
function (x)
|
||||
hidden = σ( x * Wxh + hidden * Whh + bh )
|
||||
y = hidden * Why + by
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in::Integer, hidden::Integer, out::Integer; init = initn) =
|
||||
Recurrent(init((in, hidden)), init((hidden, hidden)), init((hidden, out)),
|
||||
init(hidden), init(out), zeros(Float32, hidden))
|
Loading…
Reference in New Issue
Block a user