jit softmax

This commit is contained in:
Mike J Innes 2018-02-28 22:07:35 +00:00
parent 7606b1a399
commit ccef9f4dd4
4 changed files with 26 additions and 8 deletions

View File

@ -240,7 +240,7 @@ end
# Interface # Interface
import ..Flux: Flux, relu import ..Flux: Flux, relu
import ..Flux.Tracker: TrackedArray import ..Tracker: TrackedArray
using CUDAnative using CUDAnative
using CuArrays: @cuindex, cudims using CuArrays: @cuindex, cudims

View File

@ -1,5 +1,7 @@
module JIT module JIT
using MacroTools
include("shapes.jl") include("shapes.jl")
include("trace.jl") include("trace.jl")
include("lib.jl") include("lib.jl")

View File

@ -3,6 +3,9 @@
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T = shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
Shape{T}(size(A,1)) Shape{T}(size(A,1))
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
Shape{T}(size(A,1),size(B,2))
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B) A_mul_B!(C, A, B)
@ -10,3 +13,10 @@ shape(::typeof(broadcast), f, xs...) =
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...) Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...) inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
# NNlib
using NNlib
shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)

View File

@ -1,7 +1,7 @@
# This is hacky; we'll eventually reuse Cassette for better tracing. # This is hacky; we'll eventually reuse Cassette for better tracing.
using ..Flux.Tracker, DataFlow using ..Tracker, DataFlow
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax, using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant inputnode, constant
@ -45,12 +45,10 @@ function cacheall(v, buf = () -> UInt8[])
end end
end end
function eval_func(v, n) code(v, n) = syntax(vertex(Lambda(n, v)))
v = vertex(Lambda(n, v))
v |> syntax |> eval
end
struct Compiled{F,T<:Tuple} struct Compiled{F,T<:Tuple}
model
func::F func::F
params::T params::T
end end
@ -59,8 +57,16 @@ end
Tracker.track(Tracker.Call(c, args...), Tracker.track(Tracker.Call(c, args...),
c.func(Tracker.data.(c.params), args...)) c.func(Tracker.data.(c.params), args...))
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
function compile(f, args...) function compile(f, args...)
v = trace(f, args...) v = trace(f, args...)
v, ps = liftparams(cacheall(v)) v, ps = liftparams(cacheall(v))
Compiled(eval_func(v, length(args)+1), (ps...,)) Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end
function source(f, args...)
v = trace(f, args...)
v, ps = liftparams(v)
code(v, length(args)+1) |> prettify
end end