jit gpu support

This commit is contained in:
Mike J Innes 2018-03-07 19:18:27 +00:00
parent 261c6db371
commit 9ccbac8b80
2 changed files with 9 additions and 2 deletions

View File

@ -4,4 +4,11 @@ using CuArrays
CuArrays.cudnn_available() && include("cudnn.jl") CuArrays.cudnn_available() && include("cudnn.jl")
import ..Flux.JIT: Shape, restructure
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
buf = buf[1:sizeof(sh)]
reshape(reinterpret(T, buf), size(sh))
end
end end

View File

@ -13,7 +13,7 @@ graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
function graph(x, inputs...; cache = ObjectIdDict()) function graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x] haskey(cache, x) && return cache[x]
i = findfirst(inputs, x) i = findfirst(y -> x === y, inputs)
cache[x] = cache[x] =
i > 0 ? inputnode(i) : i > 0 ? inputnode(i) :
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) : istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
@ -64,7 +64,7 @@ Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
function compile(f, args...) function compile(f, args...)
v = trace(f, args...) v = trace(f, args...)
v, ps = liftparams(cacheall(v)) v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
Compiled(f, eval(code(v, length(args)+1)), (ps...,)) Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end end