jit softmax
This commit is contained in:
parent
7606b1a399
commit
ccef9f4dd4
@ -240,7 +240,7 @@ end
|
||||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Flux.Tracker: TrackedArray
|
||||
import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
module JIT
|
||||
|
||||
using MacroTools
|
||||
|
||||
include("shapes.jl")
|
||||
include("trace.jl")
|
||||
include("lib.jl")
|
||||
|
@ -3,6 +3,9 @@
|
||||
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
|
||||
Shape{T}(size(A,1))
|
||||
|
||||
shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
|
||||
Shape{T}(size(A,1),size(B,2))
|
||||
|
||||
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
||||
A_mul_B!(C, A, B)
|
||||
|
||||
@ -10,3 +13,10 @@ shape(::typeof(broadcast), f, xs...) =
|
||||
Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)
|
||||
|
||||
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
|
||||
shape(::typeof(softmax), x) = x
|
||||
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# This is hacky; we'll eventually reuse Cassette for better tracing.
|
||||
|
||||
using ..Flux.Tracker, DataFlow
|
||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using ..Tracker, DataFlow
|
||||
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||
inputnode, constant
|
||||
|
||||
@ -45,12 +45,10 @@ function cacheall(v, buf = () -> UInt8[])
|
||||
end
|
||||
end
|
||||
|
||||
function eval_func(v, n)
|
||||
v = vertex(Lambda(n, v))
|
||||
v |> syntax |> eval
|
||||
end
|
||||
code(v, n) = syntax(vertex(Lambda(n, v)))
|
||||
|
||||
struct Compiled{F,T<:Tuple}
|
||||
model
|
||||
func::F
|
||||
params::T
|
||||
end
|
||||
@ -59,8 +57,16 @@ end
|
||||
Tracker.track(Tracker.Call(c, args...),
|
||||
c.func(Tracker.data.(c.params), args...))
|
||||
|
||||
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||
|
||||
function compile(f, args...)
|
||||
v = trace(f, args...)
|
||||
v, ps = liftparams(cacheall(v))
|
||||
Compiled(eval_func(v, length(args)+1), (ps...,))
|
||||
Compiled(f, eval(code(v, length(args)+1)), (ps...,))
|
||||
end
|
||||
|
||||
function source(f, args...)
|
||||
v = trace(f, args...)
|
||||
v, ps = liftparams(v)
|
||||
code(v, length(args)+1) |> prettify
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user