fix gradient definitions
This commit is contained in:
parent
41b9412439
commit
5e319c7395
@ -340,33 +340,33 @@ function accum_transpose!(dst::CuArray, src::CuArray)
|
||||
return dst
|
||||
end
|
||||
|
||||
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
y, ho = y_
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||
@back(x, dx)
|
||||
@back(h, unbroadcast(h, dh))
|
||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# We don't have to make this assumption, it's just slightly more complex.
|
||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
end
|
||||
# function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
||||
# y, ho = y_
|
||||
# dy, dho = Δ
|
||||
# h_ = hBatch(x, data(h))
|
||||
# dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||
# @back(x, dx)
|
||||
# @back(h, unbroadcast(h, dh))
|
||||
# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# # We don't have to make this assumption, it's just slightly more complex.
|
||||
# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
# end
|
||||
|
||||
# Backward pass for an `RNNCall` wrapping a `CuLSTM`.
#
# `y_` is destructured as `(y, ho, co)` and `Δ` as `(dy, dho, dco)`; `x`, `h`
# and `c` are the inputs of the forward call. Input gradients are pushed to
# tracked inputs through `@back`; weight and bias gradients are accumulated in
# place into the `.grad` fields of `m.rnn.Wi`, `m.rnn.Wh` and `m.rnn.b`.
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
    y, ho, co = y_
    dy, dho, dco = Δ
    h_ = hBatch(x, data(h))
    c_ = hBatch(x, data(c))
    dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
    @back(x, dx)
    @back(h, unbroadcast(h, dh))
    # The cell-state gradient must be reduced against `c` itself (not `h`),
    # so that it matches the shape of the value it is propagated to.
    @back(c, unbroadcast(c, dc))
    (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
    # Writing straight into `.grad` below is only valid for leaf parameters.
    @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
    istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
    istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
    istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end
|
||||
# function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
||||
# y, ho, co = y_
|
||||
# dy, dho, dco = Δ
|
||||
# h_ = hBatch(x, data(h))
|
||||
# c_ = hBatch(x, data(c))
|
||||
# dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
|
||||
# @back(x, dx)
|
||||
# @back(h, unbroadcast(h, dh))
|
||||
# @back(c, unbroadcast(c, dc))
|
||||
# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||
# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||
# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||
# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||
# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
|
||||
# end
|
||||
|
@ -1,6 +1,7 @@
|
||||
module Tracker
|
||||
|
||||
using MacroTools
|
||||
using MacroTools: @q
|
||||
|
||||
import Base: ==
|
||||
|
||||
@ -51,8 +52,8 @@ track(f::Call) = track(f, f())
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f, xs...)
|
||||
y, back = _forward(f, data.(xs)...)
|
||||
function track(f, xs...; kw...)
|
||||
y, back = _forward(f, data.(xs)...; kw...)
|
||||
track(Call(back, xs), y)
|
||||
end
|
||||
|
||||
@ -60,8 +61,8 @@ macro grad(ex)
|
||||
@capture(shortdef(ex), (name_(args__) = body_) |
|
||||
(name_(args__) where {T__} = body_)) || error("Need a function definition")
|
||||
T == nothing && (T = [])
|
||||
unshift!(args, :(::typeof($name)))
|
||||
:(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
insert!(args, 1+isexpr(args[1], :parameters) , :(::typeof($name)))
|
||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
end
|
||||
|
||||
function update!(x, Δ)
|
||||
@ -83,7 +84,7 @@ Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
|
||||
the sign of the gradient applied to `x`.
|
||||
"""
|
||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))
|
||||
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
||||
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
@ -62,45 +62,47 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs.data)
|
||||
Δ′[i...] = Δ
|
||||
@back(xs, Δ′)
|
||||
@grad function getindex(xs, i...)
|
||||
data(xs)[i...], function (Δ)
|
||||
Δ′ = zeros(xs)
|
||||
Δ′[i...] = Δ
|
||||
(Δ′, map(_->nothing, i)...)
|
||||
end
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
|
||||
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
|
||||
@grad -(xs) = -xs, Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||
|
||||
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
|
||||
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
||||
@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),)
|
||||
@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),)
|
||||
|
||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||
|
||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||
Δ′ = similar(xs.data)
|
||||
S = size(xs.data)
|
||||
# Gradient rule for `repmat(xs, m, n)`.
#
# Returns the forward value together with a pullback. The pullback adds every
# element of Δ back onto the element of `xs` it was tiled from and returns
# `nothing` for the integer repetition counts.
@grad function repmat(xs, m, n = 1)
    repmat(xs, m, n), function (Δ)
        # The accumulator must start at zero: it is updated with `+=` below,
        # so `similar` (uninitialised memory) would give garbage gradients.
        Δ′ = zero(xs)
        # `size(xs, d)` rather than `size(xs)[d]` so that vectors (whose
        # second dimension is implicitly 1) are handled as well as matrices.
        rows, cols = size(xs, 1), size(xs, 2)
        for (i, v) in enumerate(Δ)
            # Linear index -> (row, column) within the tiled output,
            # then wrapped back onto the source block.
            d1 = divrem(i - 1, rows * m)
            x = d1[2] % rows + 1
            y = d1[1] % cols + 1
            Δ′[x, y] += v
        end
        return (Δ′, nothing, nothing)
    end
end
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
||||
|
||||
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
|
||||
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
|
||||
|
||||
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
|
||||
Δ′ = similar(xs.data)
|
||||
Δ′ .= 0
|
||||
S = size(xs.data)
|
||||
# Gradient rule for `repeat(xs; inner, outer)`.
#
# Returns the forward value together with a pullback that sums every element
# of Δ back into the element of `xs` it was copied from.
# The keyword defaults are computed from `xs` (the actual argument name); the
# previous `ndims(A)` referred to an undefined variable.
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
    repeat(xs, inner = inner, outer = outer), function (Δ)
        Δ′ = zero(xs)
        S = size(xs)

        # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
        for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
            # Integer division by `inner[dim]` collapses the inner repetitions;
            # `mod1` against the source size then wraps the index back into `xs`.
            src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
            Δ′[src_idx...] += val
        end
        (Δ′,)
    end
end
|
||||
|
||||
|
||||
@ -138,42 +141,51 @@ for f in [:vcat, :hcat]
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||
start += size(xsi, 1)
|
||||
@grad function vcat(xs...)
|
||||
vcat(xs...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
d = Δ[start+1:start+size(xsi,1), i...]
|
||||
start += size(xsi, 1)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(hcat), Δ, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) == 1
|
||||
@back(xsi, Δ[:, start+1])
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
@grad function hcat(xs...)
|
||||
hcat(xs...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
d = if ndims(xsi) == 1
|
||||
Δ[:, start+1]
|
||||
else
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||
Δ[:, start+1:start+size(xsi,2), i...]
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
d
|
||||
end for xsi in xs]
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
|
||||
function back(::typeof(cat), Δ, dims, Xs...)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
for xs in Xs
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
|
||||
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
|
||||
|
||||
start = start .+ till_xs
|
||||
@grad function cat(dims, Xs...)
|
||||
cat(dims, Xs...), function (Δ)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
d = reshape(Δ[xs_in_Δ...],size(xs))
|
||||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (nothing, Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
@ -181,11 +193,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
@grad reshape(xs, dims) = reshape(xs, dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
|
||||
@grad permutedims(xs, dims) = permutedims(xs, dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
@ -207,14 +218,16 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||
@grad sum(xs, dim...) = sum(xs, dim...),
|
||||
Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...)
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
|
||||
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
|
||||
@grad prod(xs) = prod(xs), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,)
|
||||
# Gradient rule for `prod(xs, dim)`.
#
# The derivative of a slice product with respect to one of its elements is the
# product of the *other elements of the same slice*, i.e. `prod(xs, dim) ./ xs`
# broadcast along `dim`. The previous circshift-based expression multiplied
# all other elements of the whole array and ignored `dim`.
# NOTE(review): as with the whole-array rule, the division produces NaN/Inf
# where `xs` contains zeros.
@grad prod(xs, dim) = prod(xs, dim),
    Δ -> (similar(xs) .= (prod(xs, dim) ./ xs) .* Δ, nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
@ -230,10 +243,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
||||
function back(::typeof(dot), Δ, xs, ys)
|
||||
@back(xs, Δ.*data(ys))
|
||||
@back(ys, Δ.*data(xs))
|
||||
end
|
||||
@grad dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
@ -244,39 +254,30 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
Base.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
|
||||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
@grad mean(xs) = mean(xs), Δ -> (similar(xs) .= Δ ./ length(xs),)
|
||||
@grad mean(xs, region) = mean(xs, region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing)
|
||||
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmax(xs.data)
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(xs, r...), function (Δ)
|
||||
Δ′ = zeros(xs)
|
||||
_, i = findmax(xs, r...)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
return (Δ′,map(_->nothing,r)...)
|
||||
end
|
||||
end
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmax(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmin(xs.data)
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(xs, r...), function (Δ)
|
||||
Δ′ = zeros(xs)
|
||||
_, i = findmin(xs, r...)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmin(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
return (Δ′,map(_->nothing,r)...)
|
||||
end
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
|
||||
@grad diagm(x) = diagm(x), Δ -> (diag(Δ),)
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
@eval begin
|
||||
@ -295,30 +296,11 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
||||
@back(a, A_mul_Bt(Δ, data(b)))
|
||||
@back(b, At_mul_B(data(a), Δ))
|
||||
end
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
a*b, Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
||||
|
||||
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||
@back(a, A_mul_Bt(Δ, data(b))')
|
||||
@back(b, data(a)*Δ)
|
||||
end
|
||||
|
||||
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
|
||||
@back(a, Δ * data(b))
|
||||
@back(b, At_mul_B(data(a), Δ)')
|
||||
end
|
||||
|
||||
# Fast path for matrix-vector
|
||||
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
||||
if isleaf(W)
|
||||
W.grad .+= Δ .* data(x).'
|
||||
else
|
||||
back(W, A_mul_Bt(Δ, data(x)))
|
||||
end
|
||||
@back(x, At_mul_B(data(W), Δ))
|
||||
end
|
||||
@grad Ac_mul_B(a, b) = Ac_mul_B(a, b), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bc(a, b) = A_mul_Bc(a, b), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
# NNlib
|
||||
|
||||
@ -327,65 +309,42 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea
|
||||
|
||||
softmax(xs::TrackedArray) = track(softmax, xs)
|
||||
|
||||
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||
@grad softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),)
|
||||
|
||||
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
||||
|
||||
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
|
||||
@grad logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),)
|
||||
|
||||
# TODO: can store kwargs efficiently in namedtuples
|
||||
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N =
|
||||
track(_conv, x, w, stride, pad, dilation)
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(x, w; kw...),
|
||||
Δ -> (NNlib.∇conv_data(Δ, x, w; kw...),
|
||||
NNlib.∇conv_filter(Δ, x, w; kw...))
|
||||
|
||||
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation)
|
||||
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(x, k; kw...)
|
||||
y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing)
|
||||
end
|
||||
|
||||
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_maxpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
|
||||
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
|
||||
|
||||
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
|
||||
track(_meanpool, x, k, pad, stride)
|
||||
|
||||
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
|
||||
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(x, k; kw...)
|
||||
y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing)
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
||||
struct Broadcasted{F,T}
|
||||
f::F
|
||||
data::T
|
||||
end
|
||||
|
||||
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
b = Broadcasted(f, out)
|
||||
track(Call(b, args...), b())
|
||||
end
|
||||
dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
|
||||
dualify(xs::Real, ps) = Dual(xs, ps)
|
||||
|
||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
|
||||
|
||||
@ -400,9 +359,17 @@ function getpartial(Δ, x, i)
|
||||
return Δ * p
|
||||
end
|
||||
|
||||
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
|
||||
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
|
||||
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
|
||||
function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, args), y)
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
||||
@ -415,4 +382,4 @@ Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
||||
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
|
||||
|
@ -23,7 +23,7 @@ end
|
||||
|
||||
function back_(c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) == length(c.args)) ||
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs)
|
||||
end
|
||||
@ -48,13 +48,6 @@ end
|
||||
back(x, Δ) = back(tracker(x), Δ)
|
||||
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
|
||||
macro back(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
istracked(x) && back(x, $(esc(Δ)))
|
||||
end
|
||||
end
|
||||
|
||||
# Interface methods
|
||||
|
||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
|
@ -100,13 +100,13 @@ back(::typeof(getindex), Δ, t, i) =
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, xs), data.(xs))
|
||||
track(Call(collect, (xs,)), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
foreach(scan, c.args[1])
|
||||
end
|
||||
|
||||
function back(::typeof(collect), Δ, xs)
|
||||
foreach((x, Δ) -> @back(x, Δ), xs, Δ)
|
||||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ)
|
||||
end
|
||||
|
@ -111,6 +111,7 @@ end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
# TODO unreliable
|
||||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||
|
||||
@ -232,4 +233,12 @@ Tracker.back!(b)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
end
|
||||
|
||||
# Gradient Hooks
|
||||
@testset "Hooks" begin
|
||||
x = param(2)
|
||||
y = Tracker.hook(-, x)
|
||||
back!(y)
|
||||
@test grad(x) == -1
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
Loading…
Reference in New Issue
Block a user