compile layers
This commit is contained in:
parent
5a32976cbf
commit
a401f08cda
@ -1,3 +1,5 @@
|
|||||||
|
using ..Tracker: TrackedArray
|
||||||
|
|
||||||
struct Shape{T,N}
|
struct Shape{T,N}
|
||||||
dims::NTuple{N,Int}
|
dims::NTuple{N,Int}
|
||||||
end
|
end
|
||||||
@ -24,6 +26,7 @@ shape(x) = typeof(x)
|
|||||||
shape(x::Shape) = x
|
shape(x::Shape) = x
|
||||||
shape(x::Tuple) = shape.(x)
|
shape(x::Tuple) = shape.(x)
|
||||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||||
|
shape(x::TrackedArray) = shape(x.data)
|
||||||
|
|
||||||
bytes(s::Shape) = sizeof(s)
|
bytes(s::Shape) = sizeof(s)
|
||||||
bytes(x::Tuple) = sum(bytes.(x))
|
bytes(x::Tuple) = sum(bytes.(x))
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||||
|
|
||||||
|
using ..Flux.Tracker, DataFlow
|
||||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||||
using DataFlow
|
|
||||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||||
inputnode, constant
|
inputnode, constant
|
||||||
|
|
||||||
@ -27,6 +27,16 @@ end
|
|||||||
|
|
||||||
# Graph manipulation
|
# Graph manipulation
|
||||||
|
|
||||||
|
function liftparams(v)
|
||||||
|
ps = []
|
||||||
|
v = prewalk(DataFlow.bumpinputs(v)) do v
|
||||||
|
isconstant(v) && istracked(v.value.value) || return v
|
||||||
|
push!(ps, v.value.value)
|
||||||
|
DataFlow.vcall(getindex, inputnode(1), length(ps))
|
||||||
|
end
|
||||||
|
return v, ps
|
||||||
|
end
|
||||||
|
|
||||||
function cacheall(v, buf = () -> UInt8[])
|
function cacheall(v, buf = () -> UInt8[])
|
||||||
prewalk(v) do v
|
prewalk(v) do v
|
||||||
iscall(v) && isconstant(v[1]) || return v
|
iscall(v) && isconstant(v[1]) || return v
|
||||||
@ -40,7 +50,17 @@ function eval_func(v, n)
|
|||||||
v |> syntax |> eval
|
v |> syntax |> eval
|
||||||
end
|
end
|
||||||
|
|
||||||
|
struct Compiled{F,T<:Tuple}
|
||||||
|
func::F
|
||||||
|
params::T
|
||||||
|
end
|
||||||
|
|
||||||
|
(c::Compiled)(args...) =
|
||||||
|
Tracker.track(Tracker.Call(c, args...),
|
||||||
|
c.func(Tracker.data.(c.params), args...))
|
||||||
|
|
||||||
function compile(f, args...)
|
function compile(f, args...)
|
||||||
v = trace(f, args...)
|
v = trace(f, args...)
|
||||||
eval_func(cacheall(v), length(args))
|
v, ps = liftparams(cacheall(v))
|
||||||
|
Compiled(eval_func(v, length(args)+1), (ps...,))
|
||||||
end
|
end
|
||||||
|
12
test/jit.jl
Normal file
12
test/jit.jl
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
using Flux, Base.Test
|
||||||
|
using Flux.JIT: compile
|
||||||
|
|
||||||
|
@testset "JIT" begin
|
||||||
|
|
||||||
|
m = Dense(10, 5)
|
||||||
|
f = compile(m, rand(10))
|
||||||
|
x = rand(10)
|
||||||
|
|
||||||
|
@test m(x) == f(x)
|
||||||
|
|
||||||
|
end
|
@ -10,6 +10,7 @@ include("layers/normalisation.jl")
|
|||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
include("jit.jl")
|
||||||
|
|
||||||
if Base.find_in_path("CuArrays") ≠ nothing
|
if Base.find_in_path("CuArrays") ≠ nothing
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user