jit gpu support
This commit is contained in:
parent
261c6db371
commit
9ccbac8b80
@ -4,4 +4,11 @@ using CuArrays
|
||||
|
||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||
|
||||
import ..Flux.JIT: Shape, restructure
|
||||
|
||||
function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
|
||||
buf = buf[1:sizeof(sh)]
|
||||
reshape(reinterpret(T, buf), size(sh))
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -13,7 +13,7 @@ graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
|
||||
function graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
i = findfirst(y -> x === y, inputs)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
|
||||
@ -64,7 +64,7 @@ Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||
|
||||
function compile(f, args...)
|
||||
v = trace(f, args...)
|
||||
v, ps = liftparams(cacheall(v))
|
||||
v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
|
||||
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user