perf hacks

This commit is contained in:
Mike J Innes 2018-07-30 20:08:44 +01:00
parent c565317d9e
commit a8ccc79f61
3 changed files with 25 additions and 8 deletions

View File

@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
args::As args::As
end end
Call(f, args) = Call{typeof(f),typeof(args)}(f, args) Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
Call() = Call(nothing, ()) Call() = Call(nothing, ())
# When deserialising, the object_id changes # When deserialising, the object_id changes
@ -46,7 +46,14 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end function _forward end
function track(f, xs...; kw...) function track(f::F, xs...) where F
y, back = _forward(f, xs...)
ts = map(tracker, xs)
c = Call(back, ts)
track(c, y)
end
function track_kw(f::F, xs...; kw...) where F
y, back = _forward(f, xs...; kw...) y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y) track(Call(back, tracker.(xs)), y)
end end

View File

@ -101,7 +101,7 @@ Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
end end
end end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ) repeat(data(xs), inner = inner, outer = outer), function (Δ)
@ -324,9 +324,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
@grad conv(x, w; kw...) = @grad conv(x, w; kw...) =
conv(data(x), data(w); kw...), conv(data(x), data(w); kw...),
@ -334,14 +334,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...), (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...) @grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...) y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing) y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end end
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...) @grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...) y = meanpool(data(x), k; kw...)

View File

@ -152,3 +152,13 @@ function gradient(f, args...)
end end
derivative(f, x) = gradient(f, x)[1] derivative(f, x) = gradient(f, x)[1]
# Non-nesting versions
function gradient_(f, xs...)
xs = param.(xs)
l = f(xs...)
losscheck(l)
back!(l)
grad.(xs)
end