jit softmax

This commit is contained in:
Mike J Innes 2018-02-28 22:07:35 +00:00
parent 7606b1a399
commit ccef9f4dd4
4 changed files with 26 additions and 8 deletions

View File

@ -240,7 +240,7 @@ end
# Interface
import ..Flux: Flux, relu
import ..Flux.Tracker: TrackedArray
import ..Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims

View File

@ -1,5 +1,7 @@
module JIT
using MacroTools
include("shapes.jl")
include("trace.jl")
include("lib.jl")

View File

@ -3,6 +3,9 @@
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
Shape{T}(size(A,1))
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
Shape{T}(size(A,1),size(B,2))
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B)
@ -10,3 +13,10 @@ shape(::typeof(broadcast), f, xs...) =
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
# NNlib
using NNlib
shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)

View File

@ -1,7 +1,7 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Flux.Tracker, DataFlow
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using ..Tracker, DataFlow
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant
@ -45,12 +45,10 @@ function cacheall(v, buf = () -> UInt8[])
end
end
function eval_func(v, n)
v = vertex(Lambda(n, v))
v |> syntax |> eval
end
code(v, n) = syntax(vertex(Lambda(n, v)))
struct Compiled{F,T<:Tuple}
model
func::F
params::T
end
@ -59,8 +57,16 @@ end
Tracker.track(Tracker.Call(c, args...),
c.func(Tracker.data.(c.params), args...))
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
function compile(f, args...)
v = trace(f, args...)
v, ps = liftparams(cacheall(v))
Compiled(eval_func(v, length(args)+1), (ps...,))
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end
function source(f, args...)
v = trace(f, args...)
v, ps = liftparams(v)
code(v, length(args)+1) |> prettify
end