diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index e6ff4068..bcadcf4f 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -240,7 +240,7 @@ end # Interface import ..Flux: Flux, relu -import ..Flux.Tracker: TrackedArray +import ..Tracker: TrackedArray using CUDAnative using CuArrays: @cuindex, cudims diff --git a/src/jit/JIT.jl b/src/jit/JIT.jl index 6085ce5e..06a7da6b 100644 --- a/src/jit/JIT.jl +++ b/src/jit/JIT.jl @@ -1,5 +1,7 @@ module JIT +using MacroTools + include("shapes.jl") include("trace.jl") include("lib.jl") diff --git a/src/jit/lib.jl b/src/jit/lib.jl index 5a9fe776..42c8cac3 100644 --- a/src/jit/lib.jl +++ b/src/jit/lib.jl @@ -3,6 +3,9 @@ shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T = Shape{T}(size(A,1)) +shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T = + Shape{T}(size(A,1),size(B,2)) + inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = A_mul_B!(C, A, B) @@ -10,3 +13,10 @@ shape(::typeof(broadcast), f, xs...) = Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...) inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...) + +# NNlib + +using NNlib + +shape(::typeof(softmax), x) = x +inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x) diff --git a/src/jit/trace.jl b/src/jit/trace.jl index ae9766b4..565d7724 100644 --- a/src/jit/trace.jl +++ b/src/jit/trace.jl @@ -1,7 +1,7 @@ # This is hacky; we'll eventually reuse Cassette for better tracing. -using ..Flux.Tracker, DataFlow -using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf +using ..Tracker, DataFlow +using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax, inputnode, constant @@ -45,12 +45,10 @@ function cacheall(v, buf = () -> UInt8[]) end end -function eval_func(v, n) - v = vertex(Lambda(n, v)) - v |> syntax |> eval -end +code(v, n) = syntax(vertex(Lambda(n, v))) struct Compiled{F,T<:Tuple} + model func::F params::T end @@ -59,8 +57,16 @@ end Tracker.track(Tracker.Call(c, args...), c.func(Tracker.data.(c.params), args...)) +Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")") + function compile(f, args...) v = trace(f, args...) v, ps = liftparams(cacheall(v)) - Compiled(eval_func(v, length(args)+1), (ps...,)) + Compiled(f, eval(code(v, length(args)+1)), (ps...,)) +end + +function source(f, args...) + v = trace(f, args...) + v, ps = liftparams(v) + code(v, length(args)+1) |> prettify end