humble beginnings of compiler
This commit is contained in:
parent
fc157a8c59
commit
70fbbf48fa
@ -34,6 +34,8 @@ include("layers/conv.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/normalisation.jl")
|
||||
|
||||
include("jit/JIT.jl")
|
||||
|
||||
include("data/Data.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
|
5
src/jit/JIT.jl
Normal file
5
src/jit/JIT.jl
Normal file
@ -0,0 +1,5 @@
|
||||
module JIT
|
||||
|
||||
include("trace.jl")
|
||||
|
||||
end
|
25
src/jit/trace.jl
Normal file
25
src/jit/trace.jl
Normal file
@ -0,0 +1,25 @@
|
||||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||
|
||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
|
||||
function graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
||||
constant(x)
|
||||
end
|
||||
|
||||
function trace(f, args...)
|
||||
inputs = param.(args)
|
||||
graph(f(inputs...), inputs...)
|
||||
end
|
@ -47,29 +47,6 @@ include("numeric.jl")
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
|
||||
function _graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
|
||||
constant(x)
|
||||
end
|
||||
|
||||
function graph(f, args...)
|
||||
inputs = param.(args)
|
||||
_graph(f(inputs...), inputs...)
|
||||
end
|
||||
|
||||
import Adapt.adapt
|
||||
|
||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
Loading…
Reference in New Issue
Block a user