diff --git a/src/Flux.jl b/src/Flux.jl index f055bf24..44864589 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -23,7 +23,6 @@ include("layers/chain.jl") include("layers/affine.jl") include("layers/activation.jl") include("layers/cost.jl") -include("layers/recurrent.jl") include("data.jl") diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl deleted file mode 100644 index 883addf7..00000000 --- a/src/layers/recurrent.jl +++ /dev/null @@ -1,49 +0,0 @@ -@net type Recurrent - Wxy; Wyy; by - y - function (x) - y = tanh( x * Wxy .+ y{-1} * Wyy .+ by ) - end -end - -Recurrent(in, out; init = initn) = - Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out)) - -@net type GatedRecurrent - Wxr; Wyr; br - Wxu; Wyu; bu - Wxh; Wyh; bh - y - function (x) - reset = σ( x * Wxr .+ y{-1} * Wyr .+ br ) - update = σ( x * Wxu .+ y{-1} * Wyu .+ bu ) - y′ = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh ) - y = (1 .- update) .* y′ .+ update .* y{-1} - end -end - -GatedRecurrent(in, out; init = initn) = - GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)..., - zeros(Float32, (1, out))) - -@net type LSTM - Wxf; Wyf; bf - Wxi; Wyi; bi - Wxo; Wyo; bo - Wxc; Wyc; bc - y; state - function (x) - # Gates - forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf ) - input = σ( x * Wxi .+ y{-1} * Wyi .+ bi ) - output = σ( x * Wxo .+ y{-1} * Wyo .+ bo ) - # State update and output - state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc ) - state = forget .* state{-1} .+ input .* state′ - y = output .* tanh(state) - end -end - -LSTM(in, out; init = initn) = - LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)..., - zeros(Float32, (1, out)), zeros(Float32, (1, out))) diff --git a/test/compiler.jl b/test/compiler.jl index b26fd288..204940fc 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,5 +1,5 @@ using DataFlow, MacroTools -using Flux: Param, Recurrent, squeeze, unsqueeze, stack +using Flux: squeeze, unsqueeze, stack using Flux.Compiler: @net, graph using DataFlow: Line, Frame @@ -21,6 +21,17 @@ Affine(in::Integer, out::Integer; init = Flux.initn) = end end +@net type Recurrent + Wxy; Wyy; by + y + function (x) + y = tanh( x * Wxy .+ y{-1} * Wyy .+ by ) + end +end + +Recurrent(in, out; init = Flux.initn) = + Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out)) + syntax(v::Vertex) = prettify(DataFlow.syntax(v)) syntax(x) = syntax(graph(x))