From a8ccc79f61e81d38d3235e53650fe9466693cbf9 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Mon, 30 Jul 2018 20:08:44 +0100
Subject: [PATCH] perf hacks

---
 src/tracker/Tracker.jl | 11 +++++++++--
 src/tracker/array.jl   | 12 ++++++------
 src/tracker/back.jl    | 10 ++++++++++
 3 files changed, 25 insertions(+), 8 deletions(-)

Reviewer notes (free text between the separator above and the first
file header below; not part of any hunk):

 * Line structure restored. This copy of the patch had been flattened so
   that headers and hunk lines ran together. Blank context lines were
   placed using the hunk line counts; the counts of all five hunks and
   the diffstat above are consistent with the layout below.
 * Leading indentation inside the hunks was lost in the flattened copy
   and is reconstructed assuming two-space indents. TODO: confirm
   against the target tree, or apply with
   `git apply --ignore-whitespace`.
 * The author address is absent from the "From:" header in this copy and
   has not been invented.
 * Tracker.jl: the `Call` outer constructor and `track` now bind the
   callable's type to a type parameter (`f::F ... where F`). Presumably
   this is meant to force specialization on `f`; confirm with the
   author. The new `track(f, xs...)` method takes no keyword arguments;
   the keyword-forwarding body is kept under the new name `track_kw`.
 * array.jl: the call sites that forward `kw...` (repeat, conv, maxpool,
   meanpool) are switched from `track` to `track_kw`.
 * back.jl: adds `gradient_`, which maps `param` over its arguments,
   calls `f`, runs `losscheck` and `back!` on the result, and returns
   `grad.(xs)`.
 * NOTE(review): `track_kw` still builds its tracker tuple with
   `tracker.(xs)` while the new `track` uses `map(tracker, xs)`. Decide
   whether both should use the same form.
 * NOTE(review): the context line `@grad function repeat(xs; ...)` has
   keyword defaults that reference `A`, which is not a parameter of that
   signature. It is untouched by this patch; verify separately.

diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl
index 65b8db11..21f3a43b 100644
--- a/src/tracker/Tracker.jl
+++ b/src/tracker/Tracker.jl
@@ -20,7 +20,7 @@ struct Call{F,As<:Tuple}
   args::As
 end
 
-Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
+Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
 Call() = Call(nothing, ())
 
 # When deserialising, the object_id changes
@@ -46,7 +46,14 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
 
 function _forward end
 
-function track(f, xs...; kw...)
+function track(f::F, xs...) where F
+  y, back = _forward(f, xs...)
+  ts = map(tracker, xs)
+  c = Call(back, ts)
+  track(c, y)
+end
+
+function track_kw(f::F, xs...; kw...) where F
   y, back = _forward(f, xs...; kw...)
   track(Call(back, tracker.(xs)), y)
 end
diff --git a/src/tracker/array.jl b/src/tracker/array.jl
index 6c7f93e3..90c7f1ec 100644
--- a/src/tracker/array.jl
+++ b/src/tracker/array.jl
@@ -101,7 +101,7 @@ Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
   end
 end
 
-Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
+Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
 
 @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
   repeat(data(xs), inner = inner, outer = outer), function (Δ)
@@ -324,9 +324,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
 
 @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
 
-conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
-conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
-conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
+conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
+conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
+conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
 
 @grad conv(x, w; kw...) =
   conv(data(x), data(w); kw...),
@@ -334,14 +334,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
       (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
        NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
 
-maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
+maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
 
 @grad function maxpool(x, k; kw...)
   y = maxpool(data(x), k; kw...)
   y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
 end
 
-meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
+meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
 
 @grad function meanpool(x, k; kw...)
   y = meanpool(data(x), k; kw...)
diff --git a/src/tracker/back.jl b/src/tracker/back.jl
index 08cf9d6a..3264b348 100644
--- a/src/tracker/back.jl
+++ b/src/tracker/back.jl
@@ -152,3 +152,13 @@ function gradient(f, args...)
 end
 
 derivative(f, x) = gradient(f, x)[1]
+
+# Non-nesting versions
+
+function gradient_(f, xs...)
+  xs = param.(xs)
+  l = f(xs...)
+  losscheck(l)
+  back!(l)
+  grad.(xs)
+end