we'll do this differently

Mike Innes 2018-04-14 01:57:46 +01:00
parent 2ff7843bca
commit 9d7164f15f
7 changed files with 0 additions and 195 deletions

View File

@@ -36,8 +36,6 @@ include("layers/normalisation.jl")
include("data/Data.jl")
include("jit/JIT.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

View File

@@ -1,9 +0,0 @@
module JIT
using MacroTools
include("shapes.jl")
include("trace.jl")
include("lib.jl")
end

View File

@@ -1,40 +0,0 @@
# Primitive definitions

shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
  Shape{T}(size(A,1))

shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
  Shape{T}(size(A,1),size(B,2))

inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
  A_mul_B!(C, A, B)

shape(::typeof(broadcast), f, xs...) =
  Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)

inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)

shape(::typeof(reshape), x::Shape{T}, i...) where T =
  Shape{T}(Base._reshape_uncolon(x, i))

inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)

# NNlib

using NNlib
using ..Tracker: _conv, _maxpool

shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)

shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
  Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))

inplace!(::typeof(_conv), y, x, w, stride, pad) =
  NNlib.conv!(y, x, w, stride = stride, pad = pad)

shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
  Shape{T}(NNlib.pdims(size(x), k, pad, k))

inplace!(::typeof(_maxpool), y, x, k, pad) =
  NNlib.maxpool!(y, x, k, pad = pad)
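
For reference, a minimal sketch of how a shape/inplace! pair was meant to be used (hypothetical names; Julia 0.6-era constructors, matching the rest of this file):

A = rand(5, 10); x = rand(10)
sh = shape(*, shape(A), shape(x))   # infer the result shape: Shape{Float64}(5)
y = Array{Float64}(size(sh))        # allocate the output once (Julia 0.6 constructor)
inplace!(*, y, A, x)                # A_mul_B!(y, A, x) writes A*x into y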

View File

@@ -1,56 +0,0 @@
using ..Tracker: TrackedArray
struct Shape{T,N}
  dims::NTuple{N,Int}
end

VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}

Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)

Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.ndims(s::Shape{T,N}) where {T,N} = N
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T

Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))

function Base.show(io::IO, s::Shape{T}) where T
  print(io, "Shape{$T}(")
  join(io, s.dims, ", ")
  print(io, ")")
end

shape(x) = x
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
shape(x::TrackedArray) = shape(x.data)

bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))

# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.
function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
  buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
  reshape(reinterpret(T, buf), size(sh))
end

# Execution with caches

mutable struct Cached{F,A}
  f::F
  buffer::A
end

function (c::Cached)(args...)
  sh = shape(c.f, shape(args)...)
  bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
  y = restructure(sh, c.buffer)
  inplace!(c.f, y, args...)
end
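
A minimal sketch of the Cached wrapper in action (hypothetical usage, assuming the definitions above and the * primitive from lib.jl):

c = Cached(*, UInt8[])   # cache the output buffer of a matrix-vector product
A = rand(5, 10)
c(A, rand(10))           # first call grows the byte buffer to 40 bytes
c(A, rand(10))           # later calls reuse the same buffer via restructure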

View File

@@ -1,75 +0,0 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.

using ..Tracker, DataFlow
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
  inputnode, constant

vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)

graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
  vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)

function graph(x, inputs...; cache = ObjectIdDict())
  haskey(cache, x) && return cache[x]
  i = findfirst(y -> x === y, inputs)
  cache[x] =
    i > 0 ? inputnode(i) :
    istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
    constant(x)
end

function trace(f, args...)
  inputs = param.(args)
  graph(f(inputs...), inputs...)
end
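
As a sketch, tracing a simple layer yields a DataFlow graph of its forward pass (hypothetical usage; the comment describes the rough result):

m = Dense(10, 5)        # Flux layer computing W*x .+ b
v = trace(m, rand(10))  # DataFlow vertex; the tracked W and b appear as constants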
# Graph manipulation

function liftparams(v)
  ps = []
  v = prewalk(DataFlow.bumpinputs(v)) do v
    isconstant(v) && istracked(v.value.value) || return v
    push!(ps, v.value.value)
    DataFlow.vcall(getindex, inputnode(1), length(ps))
  end
  return v, ps
end

function cacheall(v, buf = () -> UInt8[])
  prewalk(v) do v
    iscall(v) && isconstant(v[1]) || return v
    f = v[1].value.value
    return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
  end
end

code(v, n) = syntax(vertex(Lambda(n, v)))

struct Compiled{F,T<:Tuple}
  model
  func::F
  params::T
end

# TODO when we support derivatives
# (c::Compiled)(args...) =
#   Tracker.track(Tracker.Call(c, args...),
#                 c.func(Tracker.data.(c.params), args...))

(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)

Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")

function compile(f, args...)
  v = trace(f, args...)
  v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
  Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end

function source(f, args...)
  v = trace(f, args...)
  v, ps = liftparams(v)
  code(v, length(args)+1) |> prettify
end
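
A hypothetical sketch of the two entry points above (the printed expression is illustrative, not exact output; liftparams moves the lifted parameters into input 1, so arguments start at input 2):

m = Dense(10, 5)
f = compile(m, rand(10))  # Compiled(Dense(10, 5)); runs through Cached buffers
source(m, rand(10))       # traced code, roughly :((x1, x2) -> x1[1]*x2 .+ x1[2])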

View File

@@ -1,12 +0,0 @@
using Flux, Base.Test
using Flux.JIT: compile

@testset "JIT" begin

m = Dense(10, 5)
f = compile(m, rand(10))
x = rand(10)
@test Flux.Tracker.data(m(x)) == f(x)

end

View File

@@ -10,7 +10,6 @@ include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
include("jit.jl")
if Base.find_in_path("CuArrays") ≠ nothing
  include("cuda/cuda.jl")