perf hacks
This commit is contained in:
parent
c565317d9e
commit
a8ccc79f61
@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
|
|||||||
args::As
|
args::As
|
||||||
end
|
end
|
||||||
|
|
||||||
Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
|
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
|
||||||
Call() = Call(nothing, ())
|
Call() = Call(nothing, ())
|
||||||
|
|
||||||
# When deserialising, the object_id changes
|
# When deserialising, the object_id changes
|
||||||
@ -46,7 +46,14 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
|||||||
|
|
||||||
function _forward end
|
function _forward end
|
||||||
|
|
||||||
function track(f, xs...; kw...)
|
function track(f::F, xs...) where F
|
||||||
|
y, back = _forward(f, xs...)
|
||||||
|
ts = map(tracker, xs)
|
||||||
|
c = Call(back, ts)
|
||||||
|
track(c, y)
|
||||||
|
end
|
||||||
|
|
||||||
|
function track_kw(f::F, xs...; kw...) where F
|
||||||
y, back = _forward(f, xs...; kw...)
|
y, back = _forward(f, xs...; kw...)
|
||||||
track(Call(back, tracker.(xs)), y)
|
track(Call(back, tracker.(xs)), y)
|
||||||
end
|
end
|
||||||
|
@ -101,7 +101,7 @@ Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||||
@ -324,9 +324,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||||||
|
|
||||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||||
|
|
||||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||||
|
|
||||||
@grad conv(x, w; kw...) =
|
@grad conv(x, w; kw...) =
|
||||||
conv(data(x), data(w); kw...),
|
conv(data(x), data(w); kw...),
|
||||||
@ -334,14 +334,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
|||||||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||||
|
|
||||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||||
|
|
||||||
@grad function maxpool(x, k; kw...)
|
@grad function maxpool(x, k; kw...)
|
||||||
y = maxpool(data(x), k; kw...)
|
y = maxpool(data(x), k; kw...)
|
||||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||||
end
|
end
|
||||||
|
|
||||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||||
|
|
||||||
@grad function meanpool(x, k; kw...)
|
@grad function meanpool(x, k; kw...)
|
||||||
y = meanpool(data(x), k; kw...)
|
y = meanpool(data(x), k; kw...)
|
||||||
|
@ -152,3 +152,13 @@ function gradient(f, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
derivative(f, x) = gradient(f, x)[1]
|
derivative(f, x) = gradient(f, x)[1]
|
||||||
|
|
||||||
|
# Non-nesting versions
|
||||||
|
|
||||||
|
function gradient_(f, xs...)
|
||||||
|
xs = param.(xs)
|
||||||
|
l = f(xs...)
|
||||||
|
losscheck(l)
|
||||||
|
back!(l)
|
||||||
|
grad.(xs)
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user