compile layers

This commit is contained in:
Mike J Innes 2018-02-27 22:40:51 +00:00
parent 5a32976cbf
commit a401f08cda
4 changed files with 38 additions and 2 deletions

View File

@ -1,3 +1,5 @@
using ..Tracker: TrackedArray
struct Shape{T,N}
dims::NTuple{N,Int}
end
@ -24,6 +26,7 @@ shape(x) = typeof(x)
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
shape(x::TrackedArray) = shape(x.data)
bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))

View File

@ -1,7 +1,7 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Flux.Tracker, DataFlow
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant
@ -27,6 +27,16 @@ end
# Graph manipulation
function liftparams(v)
ps = []
v = prewalk(DataFlow.bumpinputs(v)) do v
isconstant(v) && istracked(v.value.value) || return v
push!(ps, v.value.value)
DataFlow.vcall(getindex, inputnode(1), length(ps))
end
return v, ps
end
function cacheall(v, buf = () -> UInt8[])
prewalk(v) do v
iscall(v) && isconstant(v[1]) || return v
@ -40,7 +50,17 @@ function eval_func(v, n)
v |> syntax |> eval
end
struct Compiled{F,T<:Tuple}
func::F
params::T
end
(c::Compiled)(args...) =
Tracker.track(Tracker.Call(c, args...),
c.func(Tracker.data.(c.params), args...))
function compile(f, args...)
v = trace(f, args...)
eval_func(cacheall(v), length(args))
v, ps = liftparams(cacheall(v))
Compiled(eval_func(v, length(args)+1), (ps...,))
end

12
test/jit.jl Normal file
View File

@ -0,0 +1,12 @@
using Flux, Base.Test
using Flux.JIT: compile
@testset "JIT" begin
m = Dense(10, 5)
f = compile(m, rand(10))
x = rand(10)
@test m(x) == f(x)
end

View File

@ -10,6 +10,7 @@ include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
include("jit.jl")
if Base.find_in_path("CuArrays") nothing
include("cuda/cuda.jl")