humble beginnings of compiler

This commit is contained in:
Mike J Innes 2018-02-08 18:11:26 +00:00
parent fc157a8c59
commit 70fbbf48fa
4 changed files with 32 additions and 23 deletions

View File

@ -34,6 +34,8 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("jit/JIT.jl")
include("data/Data.jl")
@require CuArrays include("cuda/cuda.jl")

5
src/jit/JIT.jl Normal file
View File

@ -0,0 +1,5 @@
module JIT
include("trace.jl")
end

25
src/jit/trace.jl Normal file
View File

@ -0,0 +1,25 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow
using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
function graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
constant(x)
end
function trace(f, args...)
inputs = param.(args)
graph(f(inputs...), inputs...)
end

View File

@ -47,29 +47,6 @@ include("numeric.jl")
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
using DataFlow
using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
function _graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
constant(x)
end
function graph(f, args...)
inputs = param.(args)
_graph(f(inputs...), inputs...)
end
import Adapt.adapt
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))