recurrence proof of concept

This commit is contained in:
Mike J Innes 2016-08-31 02:37:53 +01:00
parent d58fefb972
commit abcb6d6351
5 changed files with 69 additions and 29 deletions

View File

@ -7,6 +7,7 @@ using MacroTools, Lazy, Flow
include("model.jl")
include("utils.jl")
include("compiler/graph.jl")
include("compiler/diff.jl")
include("compiler/code.jl")
include("compiler/loops.jl")

View File

@ -11,10 +11,6 @@ function process_func(ex, params = [])
return args, body
end
# Immutable wrapper around a single untyped `name` field. Elsewhere in this
# text it is stored inside constants, e.g. `constant(ModelInput(1))`.
immutable ModelInput
name
end
function makegraph(graph, args)
@assert length(args) == 1
mapconst(graph) do x

19
src/compiler/graph.jl Normal file
View File

@ -0,0 +1,19 @@
# TODO: change the input approach
# Immutable wrapper around a single untyped `name` field.
# `isinput` only recognises instances whose `name` is an `Integer`, and
# `bumpinput` only renumbers those; any other kind of name is left untouched.
immutable ModelInput
name
end
# Return `true` only when `x` is a `Constant` whose value is a `ModelInput`
# carrying an integer name; every other argument yields `false`.
function isinput(x)
    isa(x, Constant) || return false
    payload = x.value
    isa(payload, ModelInput) || return false
    return isa(payload.name, Integer)
end
# Shift an integer-named `ModelInput` up by one. A `ModelInput` with a
# non-integer name is returned as is.
function bumpinput(i::ModelInput)
    isa(i.name, Integer) || return i
    return ModelInput(i.name + 1)
end

# Anything that is not a `ModelInput` passes through unchanged.
function bumpinput(x)
    return x
end
# Run `bumpinput` over the graph rooted at `v` by way of `mapconst`.
function bumpinputs(v::IVertex)
    return mapconst(bumpinput, v)
end
# Walk `v` with `postwalk`. Each vertex whose value satisfies `isinput` is
# swapped for the element of `inputs` selected by that input's integer name;
# every other vertex is handed back untouched.
function spliceinputs(v::IVertex, inputs::IVertex...)
    postwalk(v) do node
        payload = value(node)
        if isinput(payload)
            inputs[payload.value.name]
        else
            node
        end
    end
end

View File

@ -15,48 +15,70 @@ function liftloops!(ex, params)
return ex
end
# Shift an integer-named `ModelInput` up by one. A `ModelInput` with a
# non-integer name is returned as is.
function bumpinput(i::ModelInput)
    isa(i.name, Integer) || return i
    return ModelInput(i.name + 1)
end

# Anything that is not a `ModelInput` passes through unchanged.
function bumpinput(x)
    return x
end
# Report whether `model` contains a cycle, either directly in its own graph or
# in anything reached by mapping over that graph.
#
# Returns `false` when `graph(model)` is `nothing`, `true` when the graph itself
# is cyclic, and otherwise `true` iff `hasloops` holds for at least one element
# that `map` visits in the graph.
function hasloops(model)
    g = graph(model)
    # Identity test: `==` can be overloaded by graph types, `===` cannot.
    g === nothing && return false
    iscyclic(g) && return true
    result = false
    # `map` is used purely for traversal; the mapped result is discarded.
    map(m -> hasloops(m) && (result = true), g)
    return result
end
# Run `bumpinput` over the graph rooted at `v` by way of `mapconst`.
function bumpinputs(v::IVertex)
    return mapconst(bumpinput, v)
end
function unroll(delay::IVertex)
prewalk(delay[1]) do v
isa(value(v), Delay) ? constant(ModelInput(1)) : v
# Walk the graph of `model` with `postwalk`. A vertex whose value has loops
# (per `hasloops`) is replaced by that value's own atomised graph, with the
# vertex's inputs spliced in through `spliceinputs`; other vertices are kept.
function atomise(model)
    postwalk(graph(model)) do node
        if hasloops(value(node))
            spliceinputs(atomise(value(node)), inputs(node)...)
        else
            node
        end
    end
end
function break!(model)
iscyclic(graph(model)) || return model
g = bumpinputs(graph(model))
# Build the vertex `getindex(ModelInput(1), n)` from two constants.
function hinput(n)
    source = constant(ModelInput(1))
    index = constant(n)
    return vertex(getindex, source, index)
end
# Walk `delay[1]` with `prewalk!`, substituting `hinput(n)` for every vertex
# that is identical (`===`) to `delay` itself; all other vertices are kept.
function unroll!(delay::IVertex, n)
    prewalk!(delay[1]) do node
        if node === delay
            hinput(n)
        else
            node
        end
    end
end
# Rewrite graph `g` so that `Delay` vertices are replaced by explicit inputs,
# collecting what each delay wrapped in `loops`, and return the
# common-subexpression-eliminated (`cse`) tuple of the loops and the rewritten graph.
#
# NOTE(review): this span appears to interleave two revisions of the function —
# there are two `do` openers (`prewalk` / `prewalk!`) for one `end`, two
# different `push!(loops, ...)` bodies, and two `cse(...)` results. Confirm
# against the real file which lines are current before relying on it.
function break!(g::IVertex)
# Renumber integer-named inputs via `bumpinputs` before the walk below
# introduces `ModelInput(1)`.
g = bumpinputs(g)
loops = []
g = prewalk(g) do v
g = prewalk!(g) do v
# Only `Delay` vertices are rewritten; everything else is returned as is.
isa(value(v), Delay) || return v
push!(loops, unroll(v))
constant(ModelInput(1))
# Position this delay will occupy in `loops` once pushed.
n = length(loops)+1
push!(loops, unroll!(v, n))
hinput(n)
end
cse(vertex(tuple, loops..., g))
cse(vertex(tuple, vertex(tuple, loops...), g))
end
# r = Recurrent(784, 10, 50)
# r = Recurrent(10, 10)
# r = Chain(Dense(10,10), Recurrent(10,10))
# r = Chain(Recurrent(10,10),Recurrent(10,10))
# break!(r)
# break!(atomise(r)) |> syntax |> prettify |> display
# @model type Recurrent
# Wx; Wh; B
# hidden
#
# function (x)
# hidden = σ( Wx*x + Wh*hidden + B )
# end
# end
#
# Recurrent(in::Integer, out::Integer; init = initn) =
# Recurrent(init(out, in), init(out, out), init(out), zeros(out))
# Recurrent layer declared through the `@model` macro: a list of fields followed
# by the forward function of the input `x`.
#
# NOTE(review): the field list and body appear to combine two revisions — one
# with separate weight/bias fields (`Wxh` ... `By`) and two explicit update
# lines, and one with a single wrapped `model` applied to `vcat(x, hidden)`.
# `hidden` is assigned twice below; confirm which revision is intended.
@model type Recurrent
Wxh; Whh; Bh
Wxy; Why; By
model
hidden
function (x)
hidden = σ( Wxh*x + Whh*hidden + Bh )
y = σ( Wxy*x + Why*hidden + By )
hidden = σ(model(vcat(x, hidden)))
end
end
# Construct a `Recurrent` layer from its sizes: `in` inputs, `out` outputs and
# `hidden` state units. `init` builds each weight array from its dimensions.
#
# Arguments are passed in field order `Wxh, Whh, Bh, Wxy, Why, By, hidden`.
# The output bias `By` is added to `Wxy*x + Why*hidden`, whose weights have
# `out` rows, so it is sized `out` (it was previously sized `hidden`, which only
# matched when `out == hidden`). The state starts as `zeros(hidden)`.
Recurrent(in::Integer, out::Integer, hidden::Integer; init = initn) =
    Recurrent(init(hidden, in), init(hidden, hidden), init(hidden),
              init(out, in), init(out, hidden), init(out),
              zeros(hidden))
# Construct a `Recurrent` layer that wraps a `Dense(in + out, out)` model and
# starts from a zero state of length `out`. `init` is forwarded to `Dense`.
function Recurrent(in::Integer, out::Integer; init = initn)
    cell = Dense(in + out, out, init = init)
    state = zeros(out)
    return Recurrent(cell, state)
end
# Print a fixed, compact placeholder for a `Recurrent` layer instead of
# dumping its fields.
function Base.show(io::IO, r::Recurrent)
    print(io, "Flux.Recurrent(...)")
end

View File

@ -7,6 +7,8 @@ abstract Model
# Fallback `back!` for any `Model`: always raises an error naming the concrete
# type of `m`.
# NOTE(review): nothing follows the comma in the parameter list — a second
# parameter may have been lost from this text; confirm against the original file.
back!(m::Model, ) = error("Backprop not implemented for $(typeof(m))")
# Default update step: a no-op that ignores `η` and hands `m` back unchanged.
function update!(m, η)
    return m
end
# Fallback: an arbitrary value has no graph, signalled by `nothing`.
function graph(m)
    return nothing
end
# Model parameters
type Param{T}