humble beginnings of compiler
This commit is contained in:
parent
fc157a8c59
commit
70fbbf48fa
@ -34,6 +34,8 @@ include("layers/conv.jl")
|
|||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/normalisation.jl")
|
include("layers/normalisation.jl")
|
||||||
|
|
||||||
|
include("jit/JIT.jl")
|
||||||
|
|
||||||
include("data/Data.jl")
|
include("data/Data.jl")
|
||||||
|
|
||||||
@require CuArrays include("cuda/cuda.jl")
|
@require CuArrays include("cuda/cuda.jl")
|
||||||
|
5
src/jit/JIT.jl
Normal file
5
src/jit/JIT.jl
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module JIT
|
||||||
|
|
||||||
|
include("trace.jl")
|
||||||
|
|
||||||
|
end
|
25
src/jit/trace.jl
Normal file
25
src/jit/trace.jl
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||||
|
|
||||||
|
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||||
|
using DataFlow
|
||||||
|
using DataFlow: inputnode, constant
|
||||||
|
|
||||||
|
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||||
|
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||||
|
|
||||||
|
graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||||
|
vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
|
||||||
|
|
||||||
|
function graph(x, inputs...; cache = ObjectIdDict())
|
||||||
|
haskey(cache, x) && return cache[x]
|
||||||
|
i = findfirst(inputs, x)
|
||||||
|
cache[x] =
|
||||||
|
i > 0 ? inputnode(i) :
|
||||||
|
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
||||||
|
constant(x)
|
||||||
|
end
|
||||||
|
|
||||||
|
function trace(f, args...)
|
||||||
|
inputs = param.(args)
|
||||||
|
graph(f(inputs...), inputs...)
|
||||||
|
end
|
@ -47,29 +47,6 @@ include("numeric.jl")
|
|||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
using DataFlow
|
|
||||||
using DataFlow: inputnode, constant
|
|
||||||
|
|
||||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
|
||||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
|
||||||
|
|
||||||
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
|
||||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
|
||||||
|
|
||||||
function _graph(x, inputs...; cache = ObjectIdDict())
|
|
||||||
haskey(cache, x) && return cache[x]
|
|
||||||
i = findfirst(inputs, x)
|
|
||||||
cache[x] =
|
|
||||||
i > 0 ? inputnode(i) :
|
|
||||||
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
|
|
||||||
constant(x)
|
|
||||||
end
|
|
||||||
|
|
||||||
function graph(f, args...)
|
|
||||||
inputs = param.(args)
|
|
||||||
_graph(f(inputs...), inputs...)
|
|
||||||
end
|
|
||||||
|
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||||
|
Loading…
Reference in New Issue
Block a user