From e5b3d270167ae6c6f89ee7c3895483cd9e3549fb Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 3 Aug 2018 15:14:10 +0100 Subject: [PATCH] track_kw should be unnecessary --- src/tracker/Tracker.jl | 9 +-------- src/tracker/array.jl | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2d805af9..2c4951a9 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f) function _forward end -function track(f::F, xs...) where F - y, back = _forward(f, xs...) - ts = map(tracker, xs) - c = Call(back, ts) - track(c, y) -end - -function track_kw(f::F, xs...; kw...) where F +function track(f::F, xs...; kw...) where F y, back = _forward(f, xs...; kw...) track(Call(back, tracker.(xs)), y) end diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 7111d780..13dfe393 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -85,7 +85,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) -Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...) +Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) repeat(data(xs), inner = inner, outer = outer), function (Δ) @@ -159,10 +159,10 @@ end end end -Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims) -Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) -Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) -Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) +Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) +Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) +Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) +Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) @grad function cat(Xs...; dims) cat(data.(Xs)..., dims = dims), function (Δ) @@ -312,9 +312,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) -conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...) -conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...) -conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...) +conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) +conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) +conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = conv(data(x), data(w); kw...), @@ -322,14 +322,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...) (NNlib.∇conv_data(data.((Δ, x, w))...; kw...), NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) -maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...) +maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) @grad function maxpool(x, k; kw...) y = maxpool(data(x), k; kw...) y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing) end -meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...) +meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) @grad function meanpool(x, k; kw...) y = meanpool(data(x), k; kw...)