recurrence proof of concept
This commit is contained in:
parent
d58fefb972
commit
abcb6d6351
|
@ -7,6 +7,7 @@ using MacroTools, Lazy, Flow
|
|||
include("model.jl")
|
||||
include("utils.jl")
|
||||
|
||||
include("compiler/graph.jl")
|
||||
include("compiler/diff.jl")
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
|
|
|
@ -11,10 +11,6 @@ function process_func(ex, params = [])
|
|||
return args, body
|
||||
end
|
||||
|
||||
immutable ModelInput
|
||||
name
|
||||
end
|
||||
|
||||
function makegraph(graph, args)
|
||||
@assert length(args) == 1
|
||||
mapconst(graph) do x
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# TODO: change the input approach
|
||||
immutable ModelInput
|
||||
name
|
||||
end
|
||||
|
||||
isinput(x) = isa(x, Constant) && isa(x.value, ModelInput) && isa(x.value.name, Integer)
|
||||
|
||||
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
|
||||
bumpinput(x) = x
|
||||
|
||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||
|
||||
function spliceinputs(v::IVertex, inputs::IVertex...)
|
||||
postwalk(v) do v
|
||||
isinput(value(v)) ?
|
||||
inputs[value(v).value.name] :
|
||||
v
|
||||
end
|
||||
end
|
|
@ -15,48 +15,70 @@ function liftloops!(ex, params)
|
|||
return ex
|
||||
end
|
||||
|
||||
bumpinput(i::ModelInput) = isa(i.name, Integer) ? ModelInput(i.name + 1) : i
|
||||
bumpinput(x) = x
|
||||
function hasloops(model)
|
||||
g = graph(model)
|
||||
g == nothing && return false
|
||||
iscyclic(g) && return true
|
||||
result = false
|
||||
map(m -> hasloops(m) && (result = true), g)
|
||||
return result
|
||||
end
|
||||
|
||||
bumpinputs(v::IVertex) = mapconst(bumpinput, v)
|
||||
|
||||
function unroll(delay::IVertex)
|
||||
prewalk(delay[1]) do v
|
||||
isa(value(v), Delay) ? constant(ModelInput(1)) : v
|
||||
function atomise(model)
|
||||
postwalk(graph(model)) do v
|
||||
hasloops(value(v)) || return v
|
||||
spliceinputs(atomise(value(v)), inputs(v)...)
|
||||
end
|
||||
end
|
||||
|
||||
function break!(model)
|
||||
iscyclic(graph(model)) || return model
|
||||
g = bumpinputs(graph(model))
|
||||
hinput(n) = vertex(getindex, constant(ModelInput(1)), constant(n))
|
||||
|
||||
function unroll!(delay::IVertex, n)
|
||||
prewalk!(delay[1]) do v
|
||||
v === delay ? hinput(n) : v
|
||||
end
|
||||
end
|
||||
|
||||
function break!(g::IVertex)
|
||||
g = bumpinputs(g)
|
||||
loops = []
|
||||
g = prewalk(g) do v
|
||||
g = prewalk!(g) do v
|
||||
isa(value(v), Delay) || return v
|
||||
push!(loops, unroll(v))
|
||||
constant(ModelInput(1))
|
||||
n = length(loops)+1
|
||||
push!(loops, unroll!(v, n))
|
||||
hinput(n)
|
||||
end
|
||||
cse(vertex(tuple, loops..., g))
|
||||
cse(vertex(tuple, vertex(tuple, loops...), g))
|
||||
end
|
||||
|
||||
# r = Recurrent(784, 10, 50)
|
||||
# r = Recurrent(10, 10)
|
||||
# r = Chain(Dense(10,10), Recurrent(10,10))
|
||||
# r = Chain(Recurrent(10,10),Recurrent(10,10))
|
||||
|
||||
# break!(r)
|
||||
# break!(atomise(r)) |> syntax |> prettify |> display
|
||||
|
||||
# @model type Recurrent
|
||||
# Wx; Wh; B
|
||||
# hidden
|
||||
#
|
||||
# function (x)
|
||||
# hidden = σ( Wx*x + Wh*hidden + B )
|
||||
# end
|
||||
# end
|
||||
#
|
||||
# Recurrent(in::Integer, out::Integer; init = initn) =
|
||||
# Recurrent(init(out, in), init(out, out), init(out), zeros(out))
|
||||
|
||||
@model type Recurrent
|
||||
Wxh; Whh; Bh
|
||||
Wxy; Why; By
|
||||
model
|
||||
hidden
|
||||
|
||||
function (x)
|
||||
hidden = σ( Wxh*x + Whh*hidden + Bh )
|
||||
y = σ( Wxy*x + Why*hidden + By )
|
||||
hidden = σ(model(vcat(x, hidden)))
|
||||
end
|
||||
end
|
||||
|
||||
Recurrent(in::Integer, out::Integer, hidden::Integer; init = initn) =
|
||||
Recurrent(init(hidden, in), init(hidden, hidden), init(hidden),
|
||||
init(out, in), init(out, hidden), init(hidden),
|
||||
zeros(hidden))
|
||||
Recurrent(in::Integer, out::Integer; init = initn) =
|
||||
Recurrent(Dense(in + out, out, init = init), zeros(out))
|
||||
|
||||
Base.show(io::IO, r::Recurrent) =
|
||||
print(io, "Flux.Recurrent(...)")
|
||||
|
|
|
@ -7,6 +7,8 @@ abstract Model
|
|||
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
||||
update!(m, η) = m
|
||||
|
||||
graph(m) = nothing
|
||||
|
||||
# Model parameters
|
||||
|
||||
type Param{T}
|
||||
|
|
Loading…
Reference in New Issue