From 9ccbac8b80cc3dc50f0fdbf02063d811add7b539 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 7 Mar 2018 19:18:27 +0000 Subject: [PATCH] jit gpu support --- src/cuda/cuda.jl | 7 +++++++ src/jit/trace.jl | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index eaa3fe00..1ee10908 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -4,4 +4,11 @@ using CuArrays CuArrays.cudnn_available() && include("cudnn.jl") +import ..Flux.JIT: Shape, restructure + +function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T + buf = buf[1:sizeof(sh)] + reshape(reinterpret(T, buf), size(sh)) +end + end diff --git a/src/jit/trace.jl b/src/jit/trace.jl index 01bf0afc..8266096f 100644 --- a/src/jit/trace.jl +++ b/src/jit/trace.jl @@ -13,7 +13,7 @@ graph(x::Tracked, inputs...; cache = ObjectIdDict()) = function graph(x, inputs...; cache = ObjectIdDict()) haskey(cache, x) && return cache[x] - i = findfirst(inputs, x) + i = findfirst(y -> x === y, inputs) cache[x] = i > 0 ? inputnode(i) : istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) : @@ -64,7 +64,7 @@ Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")") function compile(f, args...) v = trace(f, args...) - v, ps = liftparams(cacheall(v)) + v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU Compiled(f, eval(code(v, length(args)+1)), (ps...,)) end