track_kw should be unnecessary

This commit is contained in:
Mike J Innes 2018-08-03 15:14:10 +01:00
parent f5c9361617
commit e5b3d27016
2 changed files with 11 additions and 18 deletions

View File

@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f::F, xs...) where F
y, back = _forward(f, xs...)
ts = map(tracker, xs)
c = Call(back, ts)
track(c, y)
end
function track_kw(f::F, xs...; kw...) where F
function track(f::F, xs...; kw...) where F
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end

View File

@ -85,7 +85,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
@ -159,10 +159,10 @@ end
end
end
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
@ -312,9 +312,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) =
conv(data(x), data(w); kw...),
@ -322,14 +322,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...)
y = meanpool(data(x), k; kw...)