track_kw should be unnecessary
This commit is contained in:
parent
f5c9361617
commit
e5b3d27016
@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f::F, xs...) where F
|
||||
y, back = _forward(f, xs...)
|
||||
ts = map(tracker, xs)
|
||||
c = Call(back, ts)
|
||||
track(c, y)
|
||||
end
|
||||
|
||||
function track_kw(f::F, xs...; kw...) where F
|
||||
function track(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
@ -85,7 +85,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
@ -159,10 +159,10 @@ end
|
||||
end
|
||||
end
|
||||
|
||||
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
@ -312,9 +312,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
@ -322,14 +322,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
|
Loading…
Reference in New Issue
Block a user