remove rnns

This commit is contained in:
Mike J Innes 2017-08-19 20:52:17 +01:00
parent a581856954
commit 318e503d9b
3 changed files with 12 additions and 51 deletions

View File

@ -23,7 +23,6 @@ include("layers/chain.jl")
include("layers/affine.jl")
include("layers/activation.jl")
include("layers/cost.jl")
include("layers/recurrent.jl")
include("data.jl")

View File

@ -1,49 +0,0 @@
@net type Recurrent
Wxy; Wyy; by
y
function (x)
y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
end
end
Recurrent(in, out; init = initn) =
Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
@net type GatedRecurrent
Wxr; Wyr; br
Wxu; Wyu; bu
Wxh; Wyh; bh
y
function (x)
reset = σ( x * Wxr .+ y{-1} * Wyr .+ br )
update = σ( x * Wxu .+ y{-1} * Wyu .+ bu )
y = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh )
y = (1 .- update) .* y .+ update .* y{-1}
end
end
GatedRecurrent(in, out; init = initn) =
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)...,
zeros(Float32, (1, out)))
@net type LSTM
Wxf; Wyf; bf
Wxi; Wyi; bi
Wxo; Wyo; bo
Wxc; Wyc; bc
y; state
function (x)
# Gates
forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
# State update and output
state = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
state = forget .* state{-1} .+ input .* state
y = output .* tanh(state)
end
end
LSTM(in, out; init = initn) =
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
zeros(Float32, (1, out)), zeros(Float32, (1, out)))

View File

@ -1,5 +1,5 @@
using DataFlow, MacroTools
using Flux: Param, Recurrent, squeeze, unsqueeze, stack
using Flux: squeeze, unsqueeze, stack
using Flux.Compiler: @net, graph
using DataFlow: Line, Frame
@ -21,6 +21,17 @@ Affine(in::Integer, out::Integer; init = Flux.initn) =
end
end
@net type Recurrent
Wxy; Wyy; by
y
function (x)
y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
end
end
Recurrent(in, out; init = Flux.initn) =
Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
syntax(x) = syntax(graph(x))