remove rnns
This commit is contained in:
parent
a581856954
commit
318e503d9b
@ -23,7 +23,6 @@ include("layers/chain.jl")
|
||||
include("layers/affine.jl")
|
||||
include("layers/activation.jl")
|
||||
include("layers/cost.jl")
|
||||
include("layers/recurrent.jl")
|
||||
|
||||
include("data.jl")
|
||||
|
||||
|
@ -1,49 +0,0 @@
|
||||
@net type Recurrent
|
||||
Wxy; Wyy; by
|
||||
y
|
||||
function (x)
|
||||
y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in, out; init = initn) =
|
||||
Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
|
||||
|
||||
@net type GatedRecurrent
|
||||
Wxr; Wyr; br
|
||||
Wxu; Wyu; bu
|
||||
Wxh; Wyh; bh
|
||||
y
|
||||
function (x)
|
||||
reset = σ( x * Wxr .+ y{-1} * Wyr .+ br )
|
||||
update = σ( x * Wxu .+ y{-1} * Wyu .+ bu )
|
||||
y′ = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh )
|
||||
y = (1 .- update) .* y′ .+ update .* y{-1}
|
||||
end
|
||||
end
|
||||
|
||||
GatedRecurrent(in, out; init = initn) =
|
||||
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)...,
|
||||
zeros(Float32, (1, out)))
|
||||
|
||||
@net type LSTM
|
||||
Wxf; Wyf; bf
|
||||
Wxi; Wyi; bi
|
||||
Wxo; Wyo; bo
|
||||
Wxc; Wyc; bc
|
||||
y; state
|
||||
function (x)
|
||||
# Gates
|
||||
forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
|
||||
input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
|
||||
output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
|
||||
# State update and output
|
||||
state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
|
||||
state = forget .* state{-1} .+ input .* state′
|
||||
y = output .* tanh(state)
|
||||
end
|
||||
end
|
||||
|
||||
LSTM(in, out; init = initn) =
|
||||
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
|
||||
zeros(Float32, (1, out)), zeros(Float32, (1, out)))
|
@ -1,5 +1,5 @@
|
||||
using DataFlow, MacroTools
|
||||
using Flux: Param, Recurrent, squeeze, unsqueeze, stack
|
||||
using Flux: squeeze, unsqueeze, stack
|
||||
using Flux.Compiler: @net, graph
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
@ -21,6 +21,17 @@ Affine(in::Integer, out::Integer; init = Flux.initn) =
|
||||
end
|
||||
end
|
||||
|
||||
@net type Recurrent
|
||||
Wxy; Wyy; by
|
||||
y
|
||||
function (x)
|
||||
y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in, out; init = Flux.initn) =
|
||||
Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
|
||||
|
||||
syntax(v::Vertex) = prettify(DataFlow.syntax(v))
|
||||
syntax(x) = syntax(graph(x))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user