we'll do this differently

commit 9d7164f15f (parent 2ff7843bca)
src/Flux.jl

@@ -36,8 +36,6 @@ include("layers/normalisation.jl")
 
 include("data/Data.jl")
 
-include("jit/JIT.jl")
-
 @require CuArrays include("cuda/cuda.jl")
 
 end # module
src/jit/JIT.jl (deleted)

@@ -1,9 +0,0 @@
-module JIT
-
-using MacroTools
-
-include("shapes.jl")
-include("trace.jl")
-include("lib.jl")
-
-end
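Taken together, the three files deleted below implemented a small trace-and-specialise pipeline: infer output shapes without allocating, then execute primitives into reused buffers. A rough control-flow sketch, using only names that appear later in this diff:

    # Rough pipeline of the removed JIT (names from the deleted files below):
    #   trace(f, args...)   # run f on tracked inputs, record a DataFlow graph
    #   cacheall(v, buf)    # wrap each call vertex in a buffer-reusing Cached
    #   liftparams(v)       # hoist trained parameters out as an explicit input
    #   eval(code(v, n))    # turn the graph back into a callable function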
src/jit/lib.jl (deleted)

@@ -1,40 +0,0 @@
-# Primitive definitions
-
-shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
-  Shape{T}(size(A,1))
-
-shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
-  Shape{T}(size(A,1),size(B,2))
-
-inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
-  A_mul_B!(C, A, B)
-
-shape(::typeof(broadcast), f, xs...) =
-  Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
-
-inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
-
-shape(::typeof(reshape), x::Shape{T}, i...) where T =
-  Shape{T}(Base._reshape_uncolon(x, i))
-
-inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
-
-# NNlib
-
-using NNlib
-using ..Tracker: _conv, _maxpool
-
-shape(::typeof(softmax), x) = x
-inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
-
-shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
-  Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
-
-inplace!(::typeof(_conv), y, x, w, stride, pad) =
-  NNlib.conv!(y, x, w, stride = stride, pad = pad)
-
-shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
-  Shape{T}(NNlib.pdims(size(x), k, pad, k))
-
-inplace!(::typeof(_maxpool), y, x, k, pad) =
-  NNlib.maxpool!(y, x, k, pad = pad)
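Each primitive above is registered twice: a `shape` method that infers the output's `Shape` (eltype plus dims) without touching any data, and an `inplace!` method that writes the result into a caller-provided buffer. A hypothetical registration for an extra primitive, written against this deleted interface purely to illustrate the pattern (`identity` is my example, not part of the original file):

    # Hypothetical example of the shape/inplace! pattern (not in the original):
    # the shape rule is allocation-free; the inplace! rule fills buffer y.
    shape(::typeof(identity), x::Shape{T}) where T = x
    inplace!(::typeof(identity), y, x) = copy!(y, x)  # Julia 0.6-era copy!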
src/jit/shapes.jl (deleted)

@@ -1,56 +0,0 @@
-using ..Tracker: TrackedArray
-
-struct Shape{T,N}
-  dims::NTuple{N,Int}
-end
-
-VecShape{T} = Shape{T,1}
-MatShape{T} = Shape{T,2}
-
-Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
-Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
-
-Base.size(s::Shape) = s.dims
-Base.size(s::Shape, n) = s.dims[n]
-Base.ndims(s::Shape{T,N}) where {T,N} = N
-Base.length(s::Shape) = prod(s.dims)
-Base.eltype(s::Shape{T}) where T = T
-
-Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))
-
-function Base.show(io::IO, s::Shape{T}) where T
-  print(io, "Shape{$T}(")
-  join(io, s.dims, ", ")
-  print(io, ")")
-end
-
-shape(x) = x
-shape(x::Shape) = x
-shape(x::Tuple) = shape.(x)
-shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
-shape(x::TrackedArray) = shape(x.data)
-
-bytes(s::Shape) = sizeof(s)
-bytes(x::Tuple) = sum(bytes.(x))
-
-# Recover structure from byte buffers
-# Make sure to hold on to the parent buffer for the lifetime of the data.
-
-function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
-  buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
-  reshape(reinterpret(T, buf), size(sh))
-end
-
-# Execution with caches
-
-mutable struct Cached{F,A}
-  f::F
-  buffer::A
-end
-
-function (c::Cached)(args...)
-  sh = shape(c.f, shape(args)...)
-  bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
-  y = restructure(sh, c.buffer)
-  inplace!(c.f, y, args...)
-end
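A `Shape` is a data-free stand-in for an array: it records only eltype and dimensions, so the `shape` rules in lib.jl can run without allocating, and `Cached` can size its byte buffer before writing into it. A small usage sketch, assuming only the definitions above (the `Cached(*, ...)` line is a hypothetical illustration):

    s = Shape{Float32}(10, 5)  # prints as Shape{Float32}(10, 5)
    size(s)                    # (10, 5)
    length(s)                  # 50
    sizeof(s)                  # 200 — 50 elements * 4 bytes each

    # Cached grows a UInt8 buffer on demand and reinterprets it as the output:
    c = Cached(*, UInt8[])     # hypothetical: a buffer-reusing mat-vec product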
src/jit/trace.jl (deleted)

@@ -1,75 +0,0 @@
-# This is hacky; we'll eventually reuse Cassette for better tracing.
-
-using ..Tracker, DataFlow
-using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
-using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
-  inputnode, constant
-
-vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
-vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
-
-graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
-  vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)
-
-function graph(x, inputs...; cache = ObjectIdDict())
-  haskey(cache, x) && return cache[x]
-  i = findfirst(y -> x === y, inputs)
-  cache[x] =
-    i > 0 ? inputnode(i) :
-    istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
-    constant(x)
-end
-
-function trace(f, args...)
-  inputs = param.(args)
-  graph(f(inputs...), inputs...)
-end
-
-# Graph manipulation
-
-function liftparams(v)
-  ps = []
-  v = prewalk(DataFlow.bumpinputs(v)) do v
-    isconstant(v) && istracked(v.value.value) || return v
-    push!(ps, v.value.value)
-    DataFlow.vcall(getindex, inputnode(1), length(ps))
-  end
-  return v, ps
-end
-
-function cacheall(v, buf = () -> UInt8[])
-  prewalk(v) do v
-    iscall(v) && isconstant(v[1]) || return v
-    f = v[1].value.value
-    return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
-  end
-end
-
-code(v, n) = syntax(vertex(Lambda(n, v)))
-
-struct Compiled{F,T<:Tuple}
-  model
-  func::F
-  params::T
-end
-
-# TODO when we support derivatives
-# (c::Compiled)(args...) =
-#   Tracker.track(Tracker.Call(c, args...),
-#                 c.func(Tracker.data.(c.params), args...))
-
-(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
-
-Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
-
-function compile(f, args...)
-  v = trace(f, args...)
-  v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
-  Compiled(f, eval(code(v, length(args)+1)), (ps...,))
-end
-
-function source(f, args...)
-  v = trace(f, args...)
-  v, ps = liftparams(v)
-  code(v, length(args)+1) |> prettify
-end
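The public entry point was `compile`, which the deleted test below exercises; a minimal usage sketch of the removed API, mirroring that test:

    using Flux
    using Flux.JIT: compile

    m = Dense(10, 5)
    f = compile(m, rand(10))  # trace m once on an example input
    f(rand(10))               # plain-array result, reusing internal buffers

`source(f, args...)` returned the prettified generated code for inspection instead of compiling it.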
test/jit.jl (deleted)

@@ -1,12 +0,0 @@
-using Flux, Base.Test
-using Flux.JIT: compile
-
-@testset "JIT" begin
-
-m = Dense(10, 5)
-f = compile(m, rand(10))
-x = rand(10)
-
-@test Tracker.data(m(x)) == f(x)
-
-end
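The test unwraps one side only because the two calls return different types; a sketch of why, assuming the Tracker semantics of that era (the compiled path strips tracking, per `(c::Compiled)(args...)` in trace.jl above):

    x = rand(10)
    m(x)                 # TrackedArray: carries autodiff bookkeeping
    f(x)                 # plain Array: compiled code calls Tracker.data internally
    Tracker.data(m(x))   # unwrap the TrackedArray so == compares raw values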
test/runtests.jl

@@ -10,7 +10,6 @@ include("layers/normalisation.jl")
 include("layers/stateless.jl")
 include("optimise.jl")
 include("data.jl")
-include("jit.jl")
 
 if Base.find_in_path("CuArrays") ≠ nothing
   include("cuda/cuda.jl")
Loading…
Reference in New Issue
Block a user