compile layers
This commit is contained in:
parent
5a32976cbf
commit
a401f08cda
|
@ -1,3 +1,5 @@
|
|||
using ..Tracker: TrackedArray
|
||||
|
||||
struct Shape{T,N}
|
||||
dims::NTuple{N,Int}
|
||||
end
|
||||
|
@ -24,6 +26,7 @@ shape(x) = typeof(x)
|
|||
shape(x::Shape) = x
|
||||
shape(x::Tuple) = shape.(x)
|
||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||
shape(x::TrackedArray) = shape(x.data)
|
||||
|
||||
bytes(s::Shape) = sizeof(s)
|
||||
bytes(x::Tuple) = sum(bytes.(x))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||
|
||||
using ..Flux.Tracker, DataFlow
|
||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow
|
||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||
inputnode, constant
|
||||
|
||||
|
@ -27,6 +27,16 @@ end
|
|||
|
||||
# Graph manipulation
|
||||
|
||||
function liftparams(v)
|
||||
ps = []
|
||||
v = prewalk(DataFlow.bumpinputs(v)) do v
|
||||
isconstant(v) && istracked(v.value.value) || return v
|
||||
push!(ps, v.value.value)
|
||||
DataFlow.vcall(getindex, inputnode(1), length(ps))
|
||||
end
|
||||
return v, ps
|
||||
end
|
||||
|
||||
function cacheall(v, buf = () -> UInt8[])
|
||||
prewalk(v) do v
|
||||
iscall(v) && isconstant(v[1]) || return v
|
||||
|
@ -40,7 +50,17 @@ function eval_func(v, n)
|
|||
v |> syntax |> eval
|
||||
end
|
||||
|
||||
struct Compiled{F,T<:Tuple}
|
||||
func::F
|
||||
params::T
|
||||
end
|
||||
|
||||
(c::Compiled)(args...) =
|
||||
Tracker.track(Tracker.Call(c, args...),
|
||||
c.func(Tracker.data.(c.params), args...))
|
||||
|
||||
function compile(f, args...)
|
||||
v = trace(f, args...)
|
||||
eval_func(cacheall(v), length(args))
|
||||
v, ps = liftparams(cacheall(v))
|
||||
Compiled(eval_func(v, length(args)+1), (ps...,))
|
||||
end
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
using Flux, Base.Test
|
||||
using Flux.JIT: compile
|
||||
|
||||
@testset "JIT" begin
|
||||
|
||||
m = Dense(10, 5)
|
||||
f = compile(m, rand(10))
|
||||
x = rand(10)
|
||||
|
||||
@test m(x) == f(x)
|
||||
|
||||
end
|
|
@ -10,6 +10,7 @@ include("layers/normalisation.jl")
|
|||
include("layers/stateless.jl")
|
||||
include("optimise.jl")
|
||||
include("data.jl")
|
||||
include("jit.jl")
|
||||
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
include("cuda/cuda.jl")
|
||||
|
|
Loading…
Reference in New Issue