From 70fbbf48fabe15cf941a64b824b7186f4a88646b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 8 Feb 2018 18:11:26 +0000 Subject: [PATCH] humble beginnings of compiler --- src/Flux.jl | 2 ++ src/jit/JIT.jl | 5 +++++ src/jit/trace.jl | 25 +++++++++++++++++++++++++ src/tracker/Tracker.jl | 23 ----------------------- 4 files changed, 32 insertions(+), 23 deletions(-) create mode 100644 src/jit/JIT.jl create mode 100644 src/jit/trace.jl diff --git a/src/Flux.jl b/src/Flux.jl index 30baf2bd..8ad4d1f9 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,6 +34,8 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalisation.jl") +include("jit/JIT.jl") + include("data/Data.jl") @require CuArrays include("cuda/cuda.jl") diff --git a/src/jit/JIT.jl b/src/jit/JIT.jl new file mode 100644 index 00000000..3283005a --- /dev/null +++ b/src/jit/JIT.jl @@ -0,0 +1,5 @@ +module JIT + +include("trace.jl") + +end diff --git a/src/jit/trace.jl b/src/jit/trace.jl new file mode 100644 index 00000000..5557741e --- /dev/null +++ b/src/jit/trace.jl @@ -0,0 +1,25 @@ +# This is hacky; we'll eventually reuse Cassette for better tracing. + +using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf +using DataFlow +using DataFlow: inputnode, constant + +vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) +vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) + +graph(x::Tracked, inputs...; cache = ObjectIdDict()) = + vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...) + +function graph(x, inputs...; cache = ObjectIdDict()) + haskey(cache, x) && return cache[x] + i = findfirst(inputs, x) + cache[x] = + i > 0 ? inputnode(i) : + istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) : + constant(x) +end + +function trace(f, args...) + inputs = param.(args) + graph(f(inputs...), inputs...) +end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index cb2547bc..f1fa7a7d 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -47,29 +47,6 @@ include("numeric.jl") param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) -using DataFlow -using DataFlow: inputnode, constant - -vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) -vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) - -_graph(x::Tracked, inputs...; cache = ObjectIdDict()) = - vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...) - -function _graph(x, inputs...; cache = ObjectIdDict()) - haskey(cache, x) && return cache[x] - i = findfirst(inputs, x) - cache[x] = - i > 0 ? inputnode(i) : - istracked(x) ? _graph(tracker(x), inputs...; cache = cache) : - constant(x) -end - -function graph(f, args...) - inputs = param.(args) - _graph(f(inputs...), inputs...) -end - import Adapt.adapt adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))