jit softmax
This commit is contained in:
parent
7606b1a399
commit
ccef9f4dd4
@ -240,7 +240,7 @@ end
|
|||||||
# Interface
|
# Interface
|
||||||
|
|
||||||
import ..Flux: Flux, relu
|
import ..Flux: Flux, relu
|
||||||
import ..Flux.Tracker: TrackedArray
|
import ..Tracker: TrackedArray
|
||||||
using CUDAnative
|
using CUDAnative
|
||||||
using CuArrays: @cuindex, cudims
|
using CuArrays: @cuindex, cudims
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
module JIT
|
module JIT
|
||||||
|
|
||||||
|
using MacroTools
|
||||||
|
|
||||||
include("shapes.jl")
|
include("shapes.jl")
|
||||||
include("trace.jl")
|
include("trace.jl")
|
||||||
include("lib.jl")
|
include("lib.jl")
|
||||||
|
@ -3,6 +3,9 @@
|
|||||||
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
|
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
|
||||||
Shape{T}(size(A,1))
|
Shape{T}(size(A,1))
|
||||||
|
|
||||||
|
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
|
||||||
|
Shape{T}(size(A,1),size(B,2))
|
||||||
|
|
||||||
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
||||||
A_mul_B!(C, A, B)
|
A_mul_B!(C, A, B)
|
||||||
|
|
||||||
@ -10,3 +13,10 @@ shape(::typeof(broadcast), f, xs...) =
|
|||||||
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
|
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
|
||||||
|
|
||||||
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
||||||
|
|
||||||
|
# NNlib
|
||||||
|
|
||||||
|
using NNlib
|
||||||
|
|
||||||
|
shape(::typeof(softmax), x) = x
|
||||||
|
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||||
|
|
||||||
using ..Flux.Tracker, DataFlow
|
using ..Tracker, DataFlow
|
||||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||||
inputnode, constant
|
inputnode, constant
|
||||||
|
|
||||||
@ -45,12 +45,10 @@ function cacheall(v, buf = () -> UInt8[])
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function eval_func(v, n)
|
code(v, n) = syntax(vertex(Lambda(n, v)))
|
||||||
v = vertex(Lambda(n, v))
|
|
||||||
v |> syntax |> eval
|
|
||||||
end
|
|
||||||
|
|
||||||
struct Compiled{F,T<:Tuple}
|
struct Compiled{F,T<:Tuple}
|
||||||
|
model
|
||||||
func::F
|
func::F
|
||||||
params::T
|
params::T
|
||||||
end
|
end
|
||||||
@ -59,8 +57,16 @@ end
|
|||||||
Tracker.track(Tracker.Call(c, args...),
|
Tracker.track(Tracker.Call(c, args...),
|
||||||
c.func(Tracker.data.(c.params), args...))
|
c.func(Tracker.data.(c.params), args...))
|
||||||
|
|
||||||
|
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||||
|
|
||||||
function compile(f, args...)
|
function compile(f, args...)
|
||||||
v = trace(f, args...)
|
v = trace(f, args...)
|
||||||
v, ps = liftparams(cacheall(v))
|
v, ps = liftparams(cacheall(v))
|
||||||
Compiled(eval_func(v, length(args)+1), (ps...,))
|
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
||||||
|
end
|
||||||
|
|
||||||
|
function source(f, args...)
|
||||||
|
v = trace(f, args...)
|
||||||
|
v, ps = liftparams(v)
|
||||||
|
code(v, length(args)+1) |> prettify
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user