we'll do this differently
This commit is contained in:
parent
2ff7843bca
commit
9d7164f15f
|
@ -36,8 +36,6 @@ include("layers/normalisation.jl")
|
|||
|
||||
include("data/Data.jl")
|
||||
|
||||
include("jit/JIT.jl")
|
||||
|
||||
@require CuArrays include("cuda/cuda.jl")
|
||||
|
||||
end # module
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
module JIT
|
||||
|
||||
using MacroTools
|
||||
|
||||
include("shapes.jl")
|
||||
include("trace.jl")
|
||||
include("lib.jl")
|
||||
|
||||
end
|
|
@ -1,40 +0,0 @@
|
|||
# Primitive definitions
|
||||
|
||||
# Matrix * vector: the result is a vector whose length is the first
# dimension of `A`.
function shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T
    return Shape{T}(size(A, 1))
end

# Matrix * matrix: the result takes its first dimension from `A` and its
# second from `B`.
function shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T
    rows, cols = size(A, 1), size(B, 2)
    return Shape{T}(rows, cols)
end

# In-place product: forwards to `A_mul_B!`, which writes `A * B` into `C`.
function inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray)
    return A_mul_B!(C, A, B)
end
|
||||
|
||||
# Shape of `broadcast(f, xs...)`: the element type is taken from the first
# operand only, and the dimensions come from `Base.Broadcast.broadcast_shape`
# applied to the operands' sizes.
function shape(::typeof(broadcast), f, xs...)
    T = eltype(xs[1])
    dims = Base.Broadcast.broadcast_shape(size.(xs)...)
    return Shape{T}(dims...)
end

# In-place broadcast: writes `f` applied over `xs` into `y` via `broadcast!`.
function inplace!(::typeof(broadcast), y, f, xs...)
    return broadcast!(f, y, xs...)
end
|
||||
|
||||
# Shape of `reshape(x, i...)`: the requested dimensions `i` are resolved
# against `x` by `Base._reshape_uncolon`; the element type is unchanged.
function shape(::typeof(reshape), x::Shape{T}, i...) where T
    dims = Base._reshape_uncolon(x, i)
    return Shape{T}(dims)
end

# In-place reshape: `y` is the output buffer, so the data of `x` is copied
# into it; the dimension arguments are not used here.
function inplace!(::typeof(reshape), y, x, i...)
    return copy!(y, x)
end
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
using ..Tracker: _conv, _maxpool
|
||||
|
||||
# softmax leaves the shape of its input unchanged.
function shape(::typeof(softmax), x)
    return x
end

# In-place softmax: forwards to `NNlib.softmax!` with `y` as the output.
function inplace!(::typeof(softmax), y, x)
    return NNlib.softmax!(y, x)
end
|
||||
|
||||
# Output shape of a convolution, as computed by `NNlib.cdims` from the input
# and weight sizes. Note the argument order: this method receives
# `stride, pad` but hands them to `cdims` as `pad, stride`.
function shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T
    dims = NNlib.cdims(size(x), size(w), pad, stride)
    return Shape{T}(dims)
end

# In-place convolution: forwards to `NNlib.conv!` with `y` as the output.
function inplace!(::typeof(_conv), y, x, w, stride, pad)
    return NNlib.conv!(y, x, w; stride = stride, pad = pad)
end
|
||||
|
||||
# Output shape of a max-pool, as computed by `NNlib.pdims`.
# `k` is passed twice (2nd and 4th argument) — presumably as both window and
# stride; confirm against the NNlib version in use.
function shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T
    dims = NNlib.pdims(size(x), k, pad, k)
    return Shape{T}(dims)
end

# In-place max-pool: forwards to `NNlib.maxpool!` with `y` as the output.
function inplace!(::typeof(_maxpool), y, x, k, pad)
    return NNlib.maxpool!(y, x, k; pad = pad)
end
|
|
@ -1,56 +0,0 @@
|
|||
using ..Tracker: TrackedArray
|
||||
|
||||
# Description of an array without its storage: an element type `T` (carried
# only as a type parameter) and `N` integer dimensions.
struct Shape{T,N}
    dims::NTuple{N,Int}
end

# Aliases for rank-1 and rank-2 shapes.
VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}

# Outer constructors that infer the rank `N` from the dimensions, given either
# as separate integers or as one tuple.
function Shape{T}(dims::Vararg{Integer,N}) where {T,N}
    return Shape{T,N}(dims)
end

function Shape{T}(dims::NTuple{N,Integer}) where {T,N}
    return Shape{T,N}(dims)
end

# Array-like queries, answered from the stored dimensions and type parameters.
function Base.size(s::Shape)
    return s.dims
end

function Base.size(s::Shape, n)
    return s.dims[n]
end

Base.ndims(::Shape{T,N}) where {T,N} = N

function Base.length(s::Shape)
    return prod(s.dims)
end

Base.eltype(::Shape{T}) where T = T

# Number of bytes an array of this shape occupies: element size times the
# number of elements.
function Base.sizeof(s::Shape{T}) where T
    return sizeof(T) * prod(size(s))
end

# Prints as `Shape{T}(d1, d2, ...)`.
function Base.show(io::IO, s::Shape{T}) where T
    return print(io, "Shape{$T}(", join(s.dims, ", "), ")")
end
|
||||
|
||||
# `shape(x)` replaces array values by their `Shape` and leaves everything else
# as it is.

# Fallback: values that are not arrays pass through unchanged.
function shape(x)
    return x
end

# A `Shape` is already a shape.
function shape(x::Shape)
    return x
end

# Tuples are converted element by element.
function shape(x::Tuple)
    return map(shape, x)
end

# Arrays become a `Shape` carrying their element type and size.
function shape(x::AbstractArray)
    return Shape{eltype(x)}(size(x)...)
end

# Tracked arrays are described by the shape of their `data` field.
function shape(x::TrackedArray)
    return shape(x.data)
end
|
||||
|
||||
# Bytes needed to store an array of shape `s`.
function bytes(s::Shape)
    return sizeof(s)
end

# Total bytes needed for a tuple: the sum of `bytes` over its elements.
function bytes(x::Tuple)
    return sum(map(bytes, x))
end
|
||||
|
||||
# Recover structure from byte buffers.
#
# Returns an array with element type `T` and dimensions `size(sh)` that views
# the first `sizeof(sh)` bytes of `buf`. The result is built from a raw pointer
# and does not keep `buf` alive, so make sure to hold on to the parent buffer
# for the lifetime of the data.
#
# Throws an `ArgumentError` when `buf` holds fewer than `sizeof(sh)` bytes;
# wrapping past the end of the buffer would read and write out-of-bounds memory.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
    nbytes = sizeof(sh)
    length(buf) >= nbytes ||
        throw(ArgumentError("buffer holds $(length(buf)) bytes but $nbytes are required"))
    raw = unsafe_wrap(Array, pointer(buf), nbytes)
    return reshape(reinterpret(T, raw), size(sh))
end
|
||||
|
||||
# Execution with caches
|
||||
|
||||
# A function `f` paired with a reusable output buffer. The struct is mutable
# because the buffer is replaced when a call needs more room.
mutable struct Cached{F,A}
    f::F
    buffer::A
end

# Call `c.f` on `args`, writing the result into the cached buffer:
#   1. work out the output shape from the shapes of the arguments,
#   2. replace the buffer by a larger one if it cannot hold the output,
#   3. view the buffer as an array of the output shape,
#   4. run the in-place version of `c.f` into that array and return its result.
function (c::Cached)(args...)
    outshape = shape(c.f, shape(args)...)
    needed = bytes(outshape)
    if needed > length(c.buffer)
        c.buffer = similar(c.buffer, needed)
    end
    out = restructure(outshape, c.buffer)
    return inplace!(c.f, out, args...)
end
|
|
@ -1,75 +0,0 @@
|
|||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||
|
||||
using ..Tracker, DataFlow
|
||||
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||
inputnode, constant
|
||||
|
||||
# Build a DataFlow call vertex applying `f` (wrapped as a constant) to `args`.
function vcall(f, args...)
    return vertex(DataFlow.Call(), constant(f), args...)
end

# A `Broadcasted` wrapper becomes an explicit `broadcast` call whose first
# argument is the wrapped function `f.f`.
function vcall(f::Broadcasted, args...)
    return vcall(broadcast, constant(f.f), args...)
end
|
||||
|
||||
# Graph for a tracked value: convert each argument of the recorded call
# (`x.f.args`) into a graph node, sharing `cache` across the recursion, then
# emit a call to the recorded function (`x.f.func`) on those nodes.
function graph(x::Tracked, inputs...; cache = ObjectIdDict())
    argnodes = map(arg -> graph(arg, inputs...; cache = cache), x.f.args)
    return vcall(x.f.func, argnodes...)
end
|
||||
|
||||
# Graph node for an arbitrary value `x`, memoised in `cache`:
#   - if `x` is identical (`===`) to one of `inputs`, the matching input node;
#   - if `x` is tracked and not a leaf, the graph of its tracker;
#   - otherwise `x` itself as a constant.
function graph(x, inputs...; cache = ObjectIdDict())
    if haskey(cache, x)
        return cache[x]
    end
    idx = findfirst(y -> x === y, inputs)
    node = if idx > 0
        inputnode(idx)
    elseif istracked(x) && !isleaf(x)
        graph(tracker(x), inputs...; cache = cache)
    else
        constant(x)
    end
    cache[x] = node
    return node
end
|
||||
|
||||
# Trace `f` on `args`: wrap every argument with `param`, run `f` on the wrapped
# values, and build the graph of the result with those values as the inputs.
function trace(f, args...)
    tracked = map(param, args)
    output = f(tracked...)
    return graph(output, tracked...)
end
|
||||
|
||||
# Graph manipulation
|
||||
|
||||
# Replace every tracked constant in the graph by `getindex(inputnode(1), k)`,
# where `k` is the constant's position in the collected list. The graph is
# first passed through `DataFlow.bumpinputs`, presumably to free input slot 1
# for the parameter collection — confirm against DataFlow.
# Returns the rewritten graph and the list of lifted values.
function liftparams(v)
    ps = []
    lifted = prewalk(DataFlow.bumpinputs(v)) do node
        if !(isconstant(node) && istracked(node.value.value))
            return node
        end
        push!(ps, node.value.value)
        return DataFlow.vcall(getindex, inputnode(1), length(ps))
    end
    return lifted, ps
end
|
||||
|
||||
# Wrap the callee of every call vertex whose function is a constant in a
# `Cached`, giving each one its own buffer obtained from `buf()`. Other
# vertices are left untouched.
function cacheall(v, buf = () -> UInt8[])
    return prewalk(v) do node
        if !(iscall(node) && isconstant(node[1]))
            return node
        end
        callee = node[1].value.value
        return vertex(Call(), constant(Cached(callee, buf())), node[2:end]...)
    end
end
|
||||
|
||||
# Turn graph `v` into the syntax of a lambda; `n` is passed to `Lambda`
# (presumably the number of arguments — see DataFlow).
function code(v, n)
    lambda = vertex(Lambda(n, v))
    return syntax(lambda)
end
|
||||
|
||||
# A compiled model: the original `model` (kept for display), the generated
# function `func`, and the tuple of lifted parameters `params`.
struct Compiled{F,T<:Tuple}
    model
    func::F
    params::T
end

# TODO when we support derivatives
# (c::Compiled)(args...) =
#   Tracker.track(Tracker.Call(c, args...),
#                 c.func(Tracker.data.(c.params), args...))

# Run the compiled function on plain data: tracking is stripped from the
# parameters, which are passed together as the first argument, and from each
# call argument.
function (c::Compiled)(args...)
    ps = map(Tracker.data, c.params)
    xs = map(Tracker.data, args)
    return c.func(ps, xs...)
end

# Prints as `Compiled(<model>)`.
function Base.show(io::IO, c::Compiled)
    return print(io, "Compiled(", c.model, ")")
end
|
||||
|
||||
# Compile `f` for arguments like `args`: trace it, give every call a cached
# output buffer, lift tracked constants into a parameter list, and `eval` the
# resulting lambda. The generated function takes the parameters as an extra
# first argument, hence `length(args) + 1`.
function compile(f, args...)
    traced = trace(f, args...)
    # Buffers are made `similar` to the first argument so they share its array
    # type, and start with one element: no empty arrays on GPU.
    newbuf = () -> similar(args[1], UInt8, 1)
    v, ps = liftparams(cacheall(traced, newbuf))
    func = eval(code(v, length(args) + 1))
    return Compiled(f, func, (ps...,))
end
|
||||
|
||||
# Return the prettified source of the lambda that `compile` would build for
# `f` on `args`, without the caching step. The lifted parameter list is
# discarded; only the rewritten graph is used.
function source(f, args...)
    lifted = first(liftparams(trace(f, args...)))
    return prettify(code(lifted, length(args) + 1))
end
|
12
test/jit.jl
12
test/jit.jl
|
@ -1,12 +0,0 @@
|
|||
using Flux, Base.Test
|
||||
using Flux.JIT: compile
|
||||
|
||||
@testset "JIT" begin
    # The compiled model must give the same output as the original model on a
    # fresh input of the same size as the one used for compilation.
    model = Dense(10, 5)
    compiled = compile(model, rand(10))
    input = rand(10)

    @test Tracker.data(model(input)) == compiled(input)
end
|
|
@ -10,7 +10,6 @@ include("layers/normalisation.jl")
|
|||
include("layers/stateless.jl")
|
||||
include("optimise.jl")
|
||||
include("data.jl")
|
||||
include("jit.jl")
|
||||
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
include("cuda/cuda.jl")
|
||||
|
|
Loading…
Reference in New Issue