From 45c5502f903fb2c65f044a502277cd58cf5a9102 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 16:18:41 +0100 Subject: [PATCH 1/6] obviate mapconst --- src/compiler/code.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 772dc01e..d3a1f688 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,4 +1,4 @@ -import DataFlow: mapconst, cse +import DataFlow: cse using MacroTools: @q export @net @@ -6,7 +6,7 @@ export @net function graphdef(ex, params = []) @capture(shortdef(ex), (args__,) -> body_) body = @> body MacroTools.flatten liftloops graphm DataFlow.il - body = mapconst(x -> x in params ? :(self.$x) : x, body) + body = map(x -> x in params ? :(self.$x) : x, body) return args, body end @@ -53,7 +53,7 @@ end import Lazy: groupby -reifyparams(v::IVertex) = mapconst(x -> x isa Param ? x.x : x, v) +reifyparams(v::IVertex) = map(x -> x isa Param ? x.x : x, v) # TODO: type hints for parameters @@ -69,14 +69,14 @@ function process_type(ex) $(build_type(T, params)) $(esc(:((self::$T)($(args...)) = $(build_forward(body, args))))) $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);) - $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args, params)))) + $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params)))) nothing end end function process_anon(ex) args, body = graphdef(ex) - :(Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args)[1]))))) + :(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1]))))) end function process_def(ex) From 3532c7174f9a167740154018b1d95e8c86af6079 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 17:32:10 +0100 Subject: [PATCH 2/6] early throw on cyclic graphs --- src/compiler/interp.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index f3fe4a21..2e307006 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -15,6 +15,7 @@ end function interp(ctx, f, xs...) g = graph(f) + g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph") @icatch(ctx, g ≠ nothing ? interpret(ctx, reifyparams(g), xs...) : f(xs...)) From f7eb5179b1995b6d0bc782c6eda60782689d6ab6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 17:39:08 +0100 Subject: [PATCH 3/6] fix basic interpreters --- src/compiler/code.jl | 4 ++-- src/compiler/interp.jl | 7 ++----- src/compiler/shape.jl | 15 ++++++++------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index d3a1f688..f9192a90 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -12,7 +12,7 @@ end function makegraph(graph, args, params = []) graph = prewalk(graph) do v - value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ? + isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ? inputnode(i) : v end @@ -42,7 +42,7 @@ end function deref_params(v) map(v) do x - x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x + @capture(x, self.p_) ? :(Flux.state(self.$p)) : x end end diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index 2e307006..6b9a022d 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -1,5 +1,5 @@ function astuple(xs::Vertex) - isconstant(xs) && value(xs).value isa Tuple ? value(xs).value : + isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) : xs isa Vertex && value(xs) == tuple ? inputs(xs) : nothing end @@ -21,10 +21,7 @@ function interp(ctx, f, xs...) f(xs...)) end -interp(ctx::Context, c::Constant{<:Param}) = c.value.x -interp(ctx::Context, c::Constant) = c.value - function interpmodel(m, args...) - ctx = Context(mux(iline, ilambda, iargs, ituple, interp)) + ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp)) @ithrow interp(ctx, m, args...) end diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl index a944cc62..7de80e61 100644 --- a/src/compiler/shape.jl +++ b/src/compiler/shape.jl @@ -8,6 +8,10 @@ end DataFlow.tocall(h::Hint, x) = :($x::$(h.typ)) +arghint(p::Param) = arghint(state(p)) +arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_)) +arghint(x) = constant(x) + function gethint(v::IVertex) while value(v) isa Union{Line,Frame} v = v[1] end value(v) isa Hint && return value(v).typ @@ -17,19 +21,16 @@ end ihint(f, ctx::Context, h::Hint, x) = vertex(h, x) ihint(f, args...) = f(args...) -hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value)) -hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_)) -hintify(ctx, c::Constant) = vertex(c) - -interpshape = mux(ilinev, ihint, iargs, hintify) - function hintify(ctx, f, xs...) - sh = infer(f, gethint.(xs)...) + xs = arghint.(xs) + sh = infer(f, map(gethint, xs)...) sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) : !any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) : vertex(f, xs...) end +interpshape = mux(ilinev, iconst, ihint, iargs, hintify) + function shapesv(f, args...) (g = graph(f)) == nothing && return ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)] From 7a2a72a74a178eb3a7097c6ae62c7eefb2d0ecf5 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 18:15:47 +0100 Subject: [PATCH 4/6] fix tensorflow --- src/backend/tensorflow/graph.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index c320c63a..5482ea64 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -1,5 +1,5 @@ using Base: @get! -using DataFlow: Constant, constant, Split +using DataFlow: constant, Split using DataFlow.Interpreter using DataFlow.Interpreter: stack using TensorFlow: RawTensor, TFException @@ -53,33 +53,34 @@ graph(p::MaxPool, x) = graph(op::Op, xs...) = op.f(xs...) function graph(ctx::Context, model, args...) - node = graph(model, interpv(ctx, args)...) + node = graph(model, args...) node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx)) return node end interp(ctx, c::Conv2D, x) = - nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID") + nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID") -interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = - haskey(ctx[:params], p.value) ? - ctx[:params][p.value] : - (ctx[:params][p.value] = +param(ctx, p::Flux.Param{<:AArray}) = + haskey(ctx[:params], p) ? + ctx[:params][p] : + (ctx[:params][p] = ctx[:variables] ? - Variable(Float32.(p.value.x)) : + Variable(Float32.(p.x)) : placeholder(Float32)) -interp(ctx, p::Constant) = p.value +param(ctx, x) = x function interp(ctx, model, args...) + args = param.(ctx, args) g = Flux.graph(model) g == nothing && return graph(ctx, model, args...) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") - interpret(ctx, g, interpv(ctx, args)...) + interpret(ctx, g, args...) end function tograph(model, args...; variables = false) - ctx = Context(mux(iline, ilambda, interp), + ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp), params = ObjectIdDict(), stacks = Dict(), variables = variables) out = interp(ctx, model, map(constant, args)...) return ctx[:params], ctx[:stacks], out From c7f8d86f9e255e6c5e9a58e9148fea0b2b0d877d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 18:24:14 +0100 Subject: [PATCH 5/6] fix mxnet --- src/backend/mxnet/graph.jl | 20 +++++++++++++------- src/backend/mxnet/model.jl | 9 --------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index fe9ccd07..313519a0 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) graph(::Input, x) = x +struct AlterParam + param + load + store +end + +Base.size(p::AlterParam) = size(p.load(p.param.x)) +Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x)) + graph(ctx::Context, d::Affine, x) = !ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) : register(ctx, @@ -74,19 +83,16 @@ end register(ctx::Context, node) = node -function var(ctx::Context, p) +function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam}) id = gensym() ctx[:params][id] = p return mx.Variable(id) end -graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value) - -graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value) - -graph(ctx::Context, p::Constant) = p.value +var(ctx::Context, x) = x function graph(ctx::Context, model, args...) + args = var.(ctx, args) g = Flux.graph(model) g == nothing && return register(ctx, @icatch ctx graph(model, args...)) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") @@ -96,7 +102,7 @@ end graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...) function tograph(model, args...; feedforward = false) - ctx = Context(mux(iline, ilambda, iargs, ituple, graph′), + ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′), params = Dict(), stacks = Dict(), feedforward = feedforward) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 507c01c2..7d3c7895 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -1,14 +1,5 @@ using Flux: collectt, shapecheckt -struct AlterParam - param - load - store -end - -Base.size(p::AlterParam) = size(p.load(p.param.x)) -Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x)) - function copyargs!(as, bs) for id in intersect(keys(as), keys(bs)) copy!(as[id], bs[id]) From 7ce782b959de8ec98f2e91b5caef816f7e628cef Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 1 Jun 2017 16:51:26 +0100 Subject: [PATCH 6/6] checkout dataflow for now --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9927c195..9ef6c40c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ julia: # uncomment the following lines to override the default test script script: - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi - - julia -e 'Pkg.clone(pwd()); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)' + - julia -e 'Pkg.clone(pwd()); Pkg.checkout("DataFlow"); Pkg.build("Flux"); Pkg.test("Flux"; coverage=true)' after_success: # - julia -e 'Pkg.add("Documenter")' - julia -e 'Pkg.clone("https://github.com/MikeInnes/Documenter.jl")'