recurrence proof of concept
This commit is contained in:
parent
d58fefb972
commit
abcb6d6351
@ -7,6 +7,7 @@ using MacroTools, Lazy, Flow
|
|||||||
include("model.jl")
|
include("model.jl")
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
|
include("compiler/graph.jl")
|
||||||
include("compiler/diff.jl")
|
include("compiler/diff.jl")
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
|
@ -11,10 +11,6 @@ function process_func(ex, params = [])
|
|||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
|
||||||
immutable ModelInput
|
|
||||||
name
|
|
||||||
end
|
|
||||||
|
|
||||||
function makegraph(graph, args)
|
function makegraph(graph, args)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
mapconst(graph) do x
|
mapconst(graph) do x
|
||||||
|
19
src/compiler/graph.jl
Normal file
19
src/compiler/graph.jl
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# TODO: change the input approach
|
||||||
|
# Marker value identifying a model input inside a graph.
# `name` is left untyped; elsewhere in this file it is tested with
# `isa(name, Integer)`, so both integer and non-integer names are expected.
immutable ModelInput
    name::Any
end
|
||||||
|
|
||||||
|
# True only when `x` is a `Constant` whose `value` is a `ModelInput` carrying an
# `Integer` name; every other argument yields `false`.
function isinput(x)
    isa(x, Constant) || return false
    payload = x.value
    return isa(payload, ModelInput) && isa(payload.name, Integer)
end
|
||||||
|
|
||||||
|
# Shift an integer-named input up by one: `ModelInput(k)` becomes
# `ModelInput(k + 1)`. Inputs whose name is not an `Integer` are returned as-is.
function bumpinput(i::ModelInput)
    isa(i.name, Integer) || return i
    return ModelInput(i.name + 1)
end

# Fallback method: anything that is not a `ModelInput` passes through unchanged.
bumpinput(x) = x
|
||||||
|
|
||||||
|
# Apply `bumpinput` across the graph `v` by way of `mapconst`.
function bumpinputs(v::IVertex)
    return mapconst(bumpinput, v)
end
|
||||||
|
|
||||||
|
# Walk `v` with `postwalk`, replacing every vertex whose value satisfies
# `isinput` with `inputs[k]`, where `k` is that input's integer name.
# All other vertices are returned untouched.
function spliceinputs(v::IVertex, inputs::IVertex...)
    postwalk(v) do node
        val = value(node)
        # Not an integer-named input: keep the vertex as it is.
        isinput(val) || return node
        return inputs[val.value.name]
    end
end
|
@ -15,48 +15,70 @@ function liftloops!(ex, params)
|
|||||||
return ex
|
return ex
|
||||||
end
|
end
|
||||||
|
|
||||||
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
|
# Report whether `model` contains a loop, either directly (its graph is cyclic)
# or in any value reached by mapping over its graph.
# Returns `false` when `graph(model)` is `nothing`.
function hasloops(model)
    g = graph(model)
    # Identity test: `==` can be overloaded for graph types, `===` cannot.
    g === nothing && return false
    iscyclic(g) && return true
    result = false
    # `map` is used purely to visit each value in `g`; its return value is
    # discarded and the outcome is accumulated in `result`.
    map(m -> hasloops(m) && (result = true), g)
    return result
end
|
||||||
|
|
||||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
function atomise(model)
|
||||||
|
postwalk(graph(model)) do v
|
||||||
function unroll(delay::IVertex)
|
hasloops(value(v)) || return v
|
||||||
prewalk(delay[1]) do v
|
spliceinputs(atomise(value(v)), inputs(v)...)
|
||||||
isa(value(v), Delay) ? constant(ModelInput(1)) : v
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function break!(model)
|
# Build a `getindex` vertex whose arguments are `constant(ModelInput(1))` and
# `constant(n)`.
function hinput(n)
    source = constant(ModelInput(1))
    return vertex(getindex, source, constant(n))
end
|
||||||
iscyclic(graph(model)) || return model
|
|
||||||
g = bumpinputs(graph(model))
|
# Walk `delay[1]` with `prewalk!`, swapping any vertex that is identical (`===`)
# to `delay` itself for `hinput(n)`; every other vertex is kept.
function unroll!(delay::IVertex, n)
    body = delay[1]
    prewalk!(body) do node
        if node === delay
            hinput(n)
        else
            node
        end
    end
end
|
||||||
|
|
||||||
|
function break!(g::IVertex)
|
||||||
|
g = bumpinputs(g)
|
||||||
loops = []
|
loops = []
|
||||||
g = prewalk(g) do v
|
g = prewalk!(g) do v
|
||||||
isa(value(v), Delay) || return v
|
isa(value(v), Delay) || return v
|
||||||
push!(loops, unroll(v))
|
n = length(loops)+1
|
||||||
constant(ModelInput(1))
|
push!(loops, unroll!(v, n))
|
||||||
|
hinput(n)
|
||||||
end
|
end
|
||||||
cse(vertex(tuple, loops..., g))
|
cse(vertex(tuple, vertex(tuple, loops...), g))
|
||||||
end
|
end
|
||||||
|
|
||||||
# r = Recurrent(784, 10, 50)
|
# r = Recurrent(10, 10)
|
||||||
|
# r = Chain(Dense(10,10), Recurrent(10,10))
|
||||||
|
# r = Chain(Recurrent(10,10),Recurrent(10,10))
|
||||||
|
|
||||||
# break!(r)
|
# break!(atomise(r)) |> syntax |> prettify |> display
|
||||||
|
|
||||||
|
# @model type Recurrent
|
||||||
|
# Wx; Wh; B
|
||||||
|
# hidden
|
||||||
|
#
|
||||||
|
# function (x)
|
||||||
|
# hidden = σ( Wx*x + Wh*hidden + B )
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
#
|
||||||
|
# Recurrent(in::Integer, out::Integer; init = initn) =
|
||||||
|
# Recurrent(init(out, in), init(out, out), init(out), zeros(out))
|
||||||
|
|
||||||
@model type Recurrent
|
@model type Recurrent
|
||||||
Wxh; Whh; Bh
|
model
|
||||||
Wxy; Why; By
|
|
||||||
hidden
|
hidden
|
||||||
|
|
||||||
function (x)
|
function (x)
|
||||||
hidden = σ( Wxh*x + Whh*hidden + Bh )
|
hidden = σ(model(vcat(x, hidden)))
|
||||||
y = σ( Wxy*x + Why*hidden + By )
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Recurrent(in::Integer, out::Integer, hidden::Integer; init = initn) =
|
Recurrent(in::Integer, out::Integer; init = initn) =
|
||||||
Recurrent(init(hidden, in), init(hidden, hidden), init(hidden),
|
Recurrent(Dense(in + out, out, init = init), zeros(out))
|
||||||
init(out, in), init(out, hidden), init(hidden),
|
|
||||||
zeros(hidden))
|
|
||||||
|
|
||||||
# Print a fixed placeholder for any `Recurrent`; the contents of `r` are not shown.
function Base.show(io::IO, r::Recurrent)
    print(io, "Flux.Recurrent(...)")
end
|
||||||
|
@ -7,6 +7,8 @@ abstract Model
|
|||||||
# Fallback for models that define no backward pass: always raises an error
# naming the concrete type of `m`. The gradient argument `∇` is ignored.
function back!(m::Model, ∇)
    error("Backprop not implemented for $(typeof(m))")
end
|
||||||
# Default `update!`: returns `m` unchanged and ignores `η`.
function update!(m, η)
    return m
end
|
||||||
|
|
||||||
|
# Default `graph`: returns `nothing` for any argument.
function graph(m)
    return nothing
end
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
|
|
||||||
type Param{T}
|
type Param{T}
|
||||||
|
Loading…
Reference in New Issue
Block a user