jit gpu support
This commit is contained in:
parent
261c6db371
commit
9ccbac8b80
@ -4,4 +4,11 @@ using CuArrays
|
|||||||
|
|
||||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||||
|
|
||||||
|
import ..Flux.JIT: Shape, restructure
|
||||||
|
|
||||||
|
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
|
||||||
|
buf = buf[1:sizeof(sh)]
|
||||||
|
reshape(reinterpret(T, buf), size(sh))
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -13,7 +13,7 @@ graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
|||||||
|
|
||||||
function graph(x, inputs...; cache = ObjectIdDict())
|
function graph(x, inputs...; cache = ObjectIdDict())
|
||||||
haskey(cache, x) && return cache[x]
|
haskey(cache, x) && return cache[x]
|
||||||
i = findfirst(inputs, x)
|
i = findfirst(y -> x === y, inputs)
|
||||||
cache[x] =
|
cache[x] =
|
||||||
i > 0 ? inputnode(i) :
|
i > 0 ? inputnode(i) :
|
||||||
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
||||||
@ -64,7 +64,7 @@ Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
|||||||
|
|
||||||
function compile(f, args...)
|
function compile(f, args...)
|
||||||
v = trace(f, args...)
|
v = trace(f, args...)
|
||||||
v, ps = liftparams(cacheall(v))
|
v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
|
||||||
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user