2018-02-07 17:43:25 +00:00
|
|
|
# Array wrapper pairing a `Tracked{A}` record with the wrapped `data` and a
# gradient buffer `grad` of the same array type `A`.
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    tracker::Tracked{A}
    data::A
    grad::A

    # Two-argument form: `grad` is left unassigned.
    function TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A}
        return new(t, data)
    end

    # Three-argument form: all fields are supplied by the caller.
    function TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A}
        return new(t, data, grad)
    end
end
|
|
|
|
|
|
|
|
# Accessor for the `tracker` field of a `TrackedArray`.
function tracker(x::TrackedArray)
    return x.tracker
end
|
|
|
|
|
|
|
|
# Convenience aliases for one- and two-dimensional tracked arrays, and their union.
const TrackedVector{T,A} = TrackedArray{T,1,A}
const TrackedMatrix{T,A} = TrackedArray{T,2,A}
const TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Wrap an array result `x` produced by call `c` in a `TrackedArray`.
function track(c::Call, x::AbstractArray)
    return TrackedArray(c, x)
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Outer constructor: derive the type parameters from `A`, build the `Tracked{A}`
# record from `(c, x)`, and leave the gradient field unassigned.
function TrackedArray(c::Call, x::A) where A <: AbstractArray
    record = Tracked{A}(c, x)
    return TrackedArray{eltype(A),ndims(A),A}(record, x)
end
|
|
|
|
|
|
|
|
# Outer constructor taking an explicit gradient buffer `Δ`; the same `Δ` is passed
# to both the `Tracked{A}` record and the `grad` field.
function TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray
    record = Tracked{A}(c, x, Δ)
    return TrackedArray{eltype(A),ndims(A),A}(record, x, Δ)
end
|
|
|
|
|
|
|
|
# Wrap a plain array: the call is `Call(nothing)` and the gradient buffer is
# `zeros(x)`.
function TrackedArray(x::AbstractArray)
    origin = Call(nothing)
    Δ = zeros(x)
    return TrackedArray(origin, x, Δ)
end
|
|
|
|
|
2018-02-13 10:20:38 +00:00
|
|
|
# Tracked arrays of real numbers report `TrackedReal{T}` as their element type.
function Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real
    return TrackedReal{T}
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Abbreviated printing of the concrete type: only the storage type `A` is shown.
function Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
    return print(io, "TrackedArray{…,$A}")
end
|
|
|
|
|
|
|
|
# Display a tracked array by delegating to `Base.showarray` on the underlying data.
# With `repr` false, an optional "Tracked " prefix precedes the data; with `repr`
# true, the data is wrapped in `param(...)`.
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
    if !repr
        if header
            print(io, "Tracked ")
        end
        return Base.showarray(io, data(X), false, header = header)
    end
    print(io, "param(")
    Base.showarray(io, data(X), true)
    return print(io, ")")
end
|
|
|
|
|
|
|
|
# In-place element assignment on tracked arrays is rejected with an error.
function Base.setindex!(xs::TrackedArray, v, i...)
    return error("Can't differentiate `setindex!`")
end
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Calling `back!` on a tracked array without a gradient argument always errors.
function back!(::TrackedArray)
    return error("Use back!(x, Δ)")
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Fallthrough methods
|
|
|
|
|
|
|
|
# Forward `size` and `ndims` (with any extra arguments) to the wrapped data.
for f in (:(Base.size), :(Base.ndims))
    @eval @inline function $f(x::TrackedArray, a...)
        return $f(data(x), a...)
    end
end
|
|
|
|
|
|
|
|
# `similar` with explicit dimensions is computed on the underlying data, so the
# result is whatever `similar(data(x), dims...)` returns.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
    return similar(data(x), dims...)
end
|
|
|
|
|
|
|
|
# `similar` with a new element type, likewise computed on the underlying data.
function Base.similar(x::TrackedArray, T::Type)
    return similar(data(x), T)
end
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Equality compares the underlying data; tracking information is ignored.
function Base.:(==)(x::TrackedArray, y)
    return data(x) == y
end

function Base.:(==)(y, x::TrackedArray)
    return y == data(x)
end

function Base.:(==)(x::TrackedArray, y::TrackedArray)
    return data(x) == data(y)
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
|
|
|
# Array Stdlib
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Indexing is routed through `track` with the array and the indices.
function Base.getindex(xs::TrackedArray, i...)
    return track(getindex, xs, i...)
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for `getindex`: place `Δ` at positions `i...` of a zero array
# shaped like `xs.data`, then hand that array to `@back`.
# NOTE(review): this is a plain assignment, so if `i` selects the same element
# more than once, later values overwrite earlier ones rather than summing —
# confirm whether repeated indices need to be supported.
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
    scattered = zeros(xs.data)
    setindex!(scattered, Δ, i...)
    return @back(xs, scattered)
end
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Unary minus is routed through `track`.
function Base.:-(xs::TrackedArray)
    return track(-, xs)
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for unary minus: pass the negated gradient on to `xs`.
function back(::typeof(-), Δ, xs::TrackedArray)
    return back(xs, -Δ)
end
|
2017-08-23 16:50:43 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Transpose and conjugate transpose are both routed through `track`.
function Base.transpose(xs::TrackedArray)
    return track(transpose, xs)
end

function Base.ctranspose(xs::TrackedArray)
    return track(ctranspose, xs)
end
|
2017-09-01 15:42:18 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward passes for (conjugate) transpose: apply the same operation to `Δ`,
# then `trim` the result to the number of dimensions of `xs`.
function back(::typeof(transpose), Δ, xs)
    return @back(xs, trim(xs, transpose(Δ)))
end

function back(::typeof(ctranspose), Δ, xs)
    return @back(xs, trim(xs, ctranspose(Δ)))
end
|
2017-09-03 21:10:23 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `repmat` on tracked vectors/matrices is routed through `track`.
#
# The second method repeats the first for the platform-native `Int`. It was
# previously written with a hard-coded `Int64`, which equals `Int` on 64-bit
# builds but does not match native integer literals on 32-bit builds.
# NOTE(review): presumably this narrower method exists to win dispatch against
# Base's `Int`-typed `repmat` methods — confirm against the targeted Base version.
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int...) = track(repmat, x, a...)
|
2017-09-05 06:12:53 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `vcat` methods for tracked vectors, vec-or-mats and matrices. For each
# tracked/plain type pair, define: tracked-tracked, tracked followed by any number
# of tracked, tracked-plain, and plain-tracked. All of them route through `track`.
for (Tr, Plain) in ((:TrackedVector, :AbstractVector),
                    (:TrackedVecOrMat, :AbstractVecOrMat),
                    (:TrackedMatrix, :AbstractMatrix))
    @eval begin
        Base.vcat(a::$Tr, b::$Tr) = track(vcat, a, b)
        Base.vcat(a::$Tr, b::$Tr...) = track(vcat, a, b...)
        Base.vcat(a::$Tr, b::$Plain) = track(vcat, a, b)
        Base.vcat(a::$Plain, b::$Tr) = track(vcat, a, b)
    end
end
|
2017-09-05 06:11:28 +00:00
|
|
|
|
2017-12-08 15:10:09 +00:00
|
|
|
# Backward pass for `vcat`: walk the inputs in order and give each one the block
# of rows of `Δ` matching its own first-dimension size, taking all remaining
# dimensions in full.
function back(::typeof(vcat), Δ, xs...)
    # One `:` for every dimension of Δ after the first.
    trailing = Base.tail(map(_ -> :, size(Δ)))
    offset = 0
    for piece in xs
        rows = size(piece, 1)
        @back(piece, Δ[offset+1:offset+rows, trailing...])
        offset += rows
    end
end
|
|
|
|
|
2017-12-15 16:18:16 +00:00
|
|
|
# `reshape` on tracked arrays is routed through `track`, for dimensions given
# either as separate arguments (integers and/or `:`) or as a tuple of integers.
#
# Dimensions are typed with the platform-native `Int`. They were previously
# hard-coded as `Int64`, which equals `Int` on 64-bit builds but does not match
# native integer dimensions on 32-bit builds.
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int}...) =
    track(reshape, xs, dims...)

Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int,N}} where N) =
    track(reshape, xs, dims)
|
|
|
|
|
2017-12-15 16:18:16 +00:00
|
|
|
# Backward pass for `reshape`: reshape `Δ` back to the size of `xs`. The target
# dimensions of the forward call are ignored.
function back(::typeof(reshape), Δ, xs::TrackedArray, _...)
    restored = reshape(Δ, size(xs))
    return back(xs, restored)
end
|
|
|
|
|
2018-02-28 02:19:58 +00:00
|
|
|
# `permutedims` is routed through `track`.
function Base.permutedims(xs::TrackedArray, dims)
    return track(permutedims, xs, dims)
end

# Backward pass: undo the permutation on `Δ` using the inverse permutation.
function back(::typeof(permutedims), Δ, xs::TrackedArray, dims)
    undone = permutedims(Δ, invperm(dims))
    return back(xs, undone)
end
|
2018-02-08 19:27:57 +00:00
|
|
|
|
2018-02-16 14:15:40 +00:00
|
|
|
# Kronecker product built from `reshape` and a broadcast multiply.
# `mat1` is viewed as (1, m1, 1, n1) and `mat2` as (m2, 1, n2, 1); their
# broadcast product has shape (m2, m1, n2, n1) and is flattened to
# (m1*m2, n1*n2).
function _kron(mat1::AbstractMatrix, mat2::AbstractMatrix)
    m1, n1 = size(mat1)
    m2, n2 = size(mat2)
    outer = reshape(mat1, (1, m1, 1, n1)) .* reshape(mat2, (m2, 1, n2, 1))
    return reshape(outer, (m1 * m2, n1 * n2))
end
|
|
|
|
|
2018-02-16 14:15:40 +00:00
|
|
|
# `kron` with at least one tracked matrix is computed by `_kron`.
function Base.kron(a::TrackedMatrix, b::TrackedMatrix)
    return _kron(a, b)
end

function Base.kron(a::TrackedMatrix, b::AbstractMatrix)
    return _kron(a, b)
end

function Base.kron(a::AbstractMatrix, b::TrackedMatrix)
    return _kron(a, b)
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# Reductions
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `sum` over given dimensions, or over everything, is routed through `track`.
function Base.sum(xs::TrackedArray, dim)
    return track(sum, xs, dim)
end

function Base.sum(xs::TrackedArray)
    return track(sum, xs)
end

# `sum(f, xs)` is expressed as a broadcast of `f` followed by a plain `sum`.
function Base.sum(f::Union{Function,Type}, xs::TrackedArray)
    return sum(f.(xs))
end
|
2017-08-22 11:24:08 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for `sum`: broadcast `Δ` into a fresh array shaped like the
# input data and pass it on to `xs`.
function back(::typeof(sum), Δ, xs::TrackedArray, dim...)
    spread = similar(xs.data)
    spread .= Δ
    return back(xs, spread)
end
|
2017-08-22 11:24:08 +00:00
|
|
|
|
2018-03-06 10:01:19 +00:00
|
|
|
# `prod` over given dimensions, or over everything, is routed through `track`.
function Base.prod(xs::TrackedArray, dim)
    return track(prod, xs, dim)
end

function Base.prod(xs::TrackedArray)
    return track(prod, xs)
end

# `prod(f, xs)` is expressed as a broadcast of `f` followed by a plain `prod`.
function Base.prod(f::Union{Function, Type}, xs::TrackedArray)
    return prod(f.(xs))
end
|
|
|
|
|
|
|
|
# Backward passes for `prod`: each element's gradient is the product divided by
# that element, scaled by `Δ`.
# NOTE(review): the division by `xs.data` means zero entries produce non-finite
# values — confirm whether inputs containing zeros need to be handled.
function back(::typeof(prod), Δ, xs::TrackedArray, dim...)
    grad = similar(xs.data)
    grad .= (prod(xs.data, dim) ./ xs.data) .* Δ
    return back(xs, grad)
end

function back(::typeof(prod), Δ, xs::TrackedArray)
    grad = similar(xs.data)
    grad .= (prod(xs.data) ./ xs.data) .* Δ
    return back(xs, grad)
end
|
|
|
|
|
2017-09-07 01:21:35 +00:00
|
|
|
# `maximum` and `findfirst` operate directly on the underlying data; they do not
# go through `track`.
function Base.maximum(xs::TrackedArray, args...)
    return maximum(xs.data, args...)
end

function Base.findfirst(xs::TrackedArray, args...)
    return findfirst(xs.data, args...)
end
|
2017-09-02 03:33:05 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `mean`, overall or over a region, is routed through `track`.
function Base.mean(xs::TrackedArray)
    return track(mean, xs)
end

function Base.mean(xs::TrackedArray, region)
    return track(mean, xs, region)
end
|
2017-10-30 08:21:02 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `dot` with at least one tracked vector is routed through `track`.
function LinAlg.dot(xs::TrackedVector, ys::TrackedVector)
    return track(dot, xs, ys)
end

function LinAlg.dot(xs::AbstractVector, ys::TrackedVector)
    return track(dot, xs, ys)
end

function LinAlg.dot(xs::TrackedVector, ys::AbstractVector)
    return track(dot, xs, ys)
end
|
2017-12-12 17:23:15 +00:00
|
|
|
|
|
|
|
# Backward pass for `dot`: each argument receives `Δ` times the other
# argument's data.
function back(::typeof(dot), Δ, xs, ys)
    @back(xs, Δ .* data(ys))
    return @back(ys, Δ .* data(xs))
end
|
|
|
|
|
2017-11-21 16:04:04 +00:00
|
|
|
# Hacks to get std working
|
|
|
|
# `std` written in terms of broadcasts and `sum`: the square root of the summed
# squared deviations divided by (count - 1). The mean may be supplied via the
# `mean` keyword.
function Base.std(x::TrackedArray; mean = Base.mean(x))
    sumsq = sum((x .- mean).^2)
    return sqrt.(sumsq ./ (length(x)-1))
end

# Same computation along dimension `dim`.
function Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim))
    sumsq = sum((x .- mean).^2, dim)
    return sqrt.(sumsq ./ (size(x, dim)-1))
end
|
|
|
|
|
2018-03-05 17:24:46 +00:00
|
|
|
# p-norm of all elements, written in terms of broadcasts and `sum`.
# `eps(0f0)` is added to every term to avoid d(sqrt(x))/dx == Inf at 0.
function Base.vecnorm(x::TrackedArray, p::Real = 2)
    total = sum(abs.(x).^p .+ eps(0f0))
    return total^(1/p)
end
|
2018-02-09 19:00:26 +00:00
|
|
|
|
2017-10-31 10:41:44 +00:00
|
|
|
# Backward pass for `mean` over everything: `Δ` divided by the element count,
# broadcast into an array shaped like the input data.
function back(::typeof(mean), Δ, xs::TrackedArray)
    grad = similar(xs.data)
    grad .= Δ ./ length(xs.data)
    return back(xs, grad)
end

# Backward pass for `mean` over `region`: the divisor is the product of the
# input sizes along the dimensions in `region`.
function back(::typeof(mean), Δ, xs::TrackedArray, region)
    grad = similar(xs.data)
    grad .= Δ ./ prod(size(xs.data, region...))
    return back(xs, grad)
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# BLAS
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `diagm` on a tracked vector is routed through `track`.
function Base.diagm(x::TrackedVector)
    return track(diagm, x)
end

# Backward pass: only the diagonal of `Δ` flows back to `x`.
function back(::typeof(diagm), Δ, x)
    return @back(x, diag(Δ))
end
|
|
|
|
|
2017-12-12 17:07:39 +00:00
|
|
|
# For `*`, `Ac_mul_B` and `A_mul_Bc`: import the function from Base and define
# methods routing through `track` for every listed pairing of matrix/vector
# arguments in which at least one argument is tracked.
for f in (:*, :Ac_mul_B, :A_mul_Bc)
    @eval import Base.$f
    for (L, R) in ((:TrackedMatrix, :TrackedMatrix),
                   (:TrackedMatrix, :AbstractMatrix),
                   (:AbstractMatrix, :TrackedMatrix),
                   (:TrackedMatrix, :TrackedVector),
                   (:TrackedMatrix, :AbstractVector),
                   (:AbstractMatrix, :TrackedVector),
                   (:TrackedVector, :TrackedVector),
                   (:TrackedVector, :AbstractVector),
                   (:AbstractVector, :TrackedVector))
        @eval $f(a::$L, b::$R) = track($f, a, b)
    end
end
|
2017-08-20 12:48:43 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for `a * b`:
#   a receives Δ * bᵀ   (via `A_mul_Bt`)
#   b receives aᵀ * Δ   (via `At_mul_B`)
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
    @back(a, A_mul_Bt(Δ, data(b)))
    return @back(b, At_mul_B(data(a), Δ))
end
|
|
|
|
|
2017-11-08 22:00:19 +00:00
|
|
|
# Backward pass for `Ac_mul_B(a, b)` with real element types:
#   a receives (Δ * bᵀ)'
#   b receives a * Δ
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
    @back(a, ctranspose(A_mul_Bt(Δ, data(b))))
    return @back(b, data(a)*Δ)
end
|
|
|
|
|
|
|
|
# Backward pass for `A_mul_Bc(a, b)` with real element types:
#   a receives Δ * b
#   b receives (aᵀ * Δ)'
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
    @back(a, Δ * data(b))
    return @back(b, ctranspose(At_mul_B(data(a), Δ)))
end
|
|
|
|
|
2017-11-07 19:34:27 +00:00
|
|
|
# Fast path for matrix-vector
|
|
|
|
# Matrix-vector specialisation of the `*` backward pass. When `W` is a leaf, the
# outer product of `Δ` and `x` is accumulated straight into `W.grad`; otherwise
# it is passed on through `back`. `x` receives Wᵀ * Δ.
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
    if !isleaf(W)
        back(W, A_mul_Bt(Δ, data(x)))
    else
        W.grad .+= Δ .* data(x).'
    end
    return @back(x, At_mul_B(data(W), Δ))
end
|
|
|
|
|
2017-08-23 01:03:17 +00:00
|
|
|
# NNlib
|
|
|
|
|
2017-12-14 18:48:38 +00:00
|
|
|
using NNlib
|
2018-02-26 22:43:07 +00:00
|
|
|
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
|
2017-08-23 01:03:17 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `softmax` on tracked arrays is routed through `track`.
function softmax(xs::TrackedArray)
    return track(softmax, xs)
end

# Backward pass: delegate to `∇softmax` on the underlying data.
function back(::typeof(softmax), Δ, xs)
    return @back(xs, ∇softmax(Δ, data(xs)))
end
|
2017-08-23 01:03:17 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# `logsoftmax` on tracked arrays is routed through `track`.
function logsoftmax(xs::TrackedArray)
    return track(logsoftmax, xs)
end

# Backward pass: delegate to `∇logsoftmax` on the underlying data.
function back(::typeof(logsoftmax), Δ, xs)
    return @back(xs, ∇logsoftmax(Δ, data(xs)))
end
|
|
|
|
|
2017-12-18 18:05:38 +00:00
|
|
|
# TODO: can store kwargs efficiently in namedtuples
|
2018-02-26 22:43:07 +00:00
|
|
|
# Positional-argument form of `conv`: `stride` and `pad` are forwarded as
# keyword arguments.
function _conv(x, w, stride, pad)
    return conv(x, w, stride = stride, pad = pad)
end
|
|
|
|
|
|
|
|
# `conv` with a tracked input and/or tracked filter of the same dimensionality.
# The keyword arguments are turned into positional ones and the call is routed
# through `track` using `_conv`.
function conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N
    return track(_conv, x, w, stride, pad)
end

function conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0) where N
    return track(_conv, x, w, stride, pad)
end

function conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0) where N
    return track(_conv, x, w, stride, pad)
end
|
|
|
|
|
|
|
|
# Backward pass for `_conv`: the input gradient comes from `NNlib.∇conv_data`
# and the filter gradient from `NNlib.∇conv_filter`, both called with the
# forward pass's `stride` and `pad`.
function back(::typeof(_conv), Δ, x, w, stride, pad)
    @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad))
    return @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
end
|
|
|
|
|
2018-02-26 22:43:07 +00:00
|
|
|
# Positional-argument form of `maxpool`.
function _maxpool(x, k, pad)
    return maxpool(x, k; pad = pad)
end

# `maxpool` on a tracked array; `pad` defaults to one zero per entry of `k`.
function maxpool(x::TrackedArray, k; pad = map(_->0,k))
    return track(_maxpool, x, k, pad)
end

# Backward pass: also receives the forward output `y`, which `NNlib.∇maxpool`
# takes as an argument.
function back_(::typeof(_maxpool), y, Δ, x, k, pad)
    return back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
end
|
|
|
|
|
|
|
|
# Positional-argument form of `meanpool`.
function _meanpool(x, k, pad)
    return meanpool(x, k; pad = pad)
end

# `meanpool` on a tracked array; `pad` defaults to one zero per entry of `k`.
function meanpool(x::TrackedArray, k; pad = map(_->0,k))
    return track(_meanpool, x, k, pad)
end

# Backward pass: also receives the forward output `y`, which `NNlib.∇meanpool`
# takes as an argument.
function back_(::typeof(_meanpool), y, Δ, x, k, pad)
    return back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
end
|
2017-12-15 02:29:14 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Broadcasting
|
|
|
|
|
|
|
|
using ForwardDiff: Dual, partials
|
|
|
|
|
2018-02-05 17:22:09 +00:00
|
|
|
# Record of a broadcast: the function that was applied (`f`) and the broadcast
# output (`data`).
struct Broadcasted{Fn,Out}
    f::Fn
    data::Out
end
|
|
|
|
|
|
|
|
# Calling a `Broadcasted` ignores its arguments and returns the `value` field of
# every element of the stored output.
function (b::Broadcasted)(xs...)
    return map(x -> x.value, b.data)
end
|
|
|
|
|
|
|
|
# Untracked arguments pass through unchanged.
function dualify(xs, n)
    return xs
end

# Tracked arrays: every element of the data becomes a `Dual` with partials `ps`.
function dualify(xs::TrackedArray, ps)
    return map(x -> Dual(x, ps), data(xs))
end

# Tracked reals: the wrapped value becomes a single `Dual` with partials `ps`.
function dualify(xs::TrackedReal, ps)
    return Dual(data(xs), ps)
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
|
|
|
# Broadcast `f` over `args` with tracked arguments replaced by dual numbers.
# Argument `i` is seeded with an N-tuple that is `true` only at position `i`.
# If the result's elements are not `Dual`s it is returned as is; otherwise the
# result is stored in a `Broadcasted` and its plain values are tracked.
function tracked_broadcast(f, args::Vararg{Any,N}) where N
    positions = ntuple(identity, Val{N})
    seeded = map(args, positions) do arg, pos
        dualify(arg, ntuple(j -> pos == j, Val{N}))
    end
    result = broadcast(f, seeded...)
    if !(eltype(result) <: Dual)
        return result
    end
    bc = Broadcasted(f, result)
    return track(Call(bc, args...), bc())
end
|
|
|
|
|
2017-08-27 08:49:42 +00:00
|
|
|
# Reshape `Δ` to the number of dimensions of `x`, keeping `Δ`'s size along each
# of those dimensions.
function trim(x, Δ)
    dims = ntuple(i -> size(Δ, i), Val{ndims(x)})
    return reshape(Δ, dims)
end
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Reduce a broadcast gradient `Δ` to the shape of the array argument `x`.
# If the sizes already agree, `Δ` is returned untouched. Otherwise `Δ` is summed
# over every dimension in which `x` has size 1, then trimmed to `x`'s
# dimensionality.
function unbroadcast(x::AbstractArray, Δ)
    if size(x) == size(Δ)
        return Δ
    end
    singleton = filter(n -> size(x, n) == 1, 1:ndims(Δ))
    return trim(x, sum(Δ, singleton))
end
|
2017-08-22 23:25:19 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# A scalar argument receives the sum of the whole gradient.
function unbroadcast(x::Number, Δ)
    return sum(Δ)
end
|
|
|
|
|
2017-08-28 00:40:59 +00:00
|
|
|
# Multiply `Δ` by the `i`-th partial of the dual number `x`.
# The partial lookup is done without bounds checking.
function getpartial(Δ, x, i)
    @inbounds p = partials(x)[i]
    return Δ * p
end
|
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for a tracked broadcast. For each of the N arguments, multiply
# `Δ` elementwise by that argument's partial stored in `b.data`, reduce the
# result to the argument's shape with `unbroadcast`, and pass it on via `@back`.
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
    grads = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
    foreach(args, grads) do arg, g
        @back(arg, unbroadcast(arg, g))
    end
    return nothing
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2018-02-21 23:21:20 +00:00
|
|
|
# Hooks into Base's broadcast machinery.

# Tracked reals and tracked arrays both use `TrackedArray` as container type.
function Base.Broadcast._containertype(::Type{<:TrackedReal})
    return TrackedArray
end

function Base.Broadcast._containertype(::Type{<:TrackedArray})
    return TrackedArray
end

# `TrackedArray` wins container-type promotion against itself, `Array`, and any
# other container type, in either argument position.
function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray})
    return TrackedArray
end

function Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray})
    return TrackedArray
end

function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array})
    return TrackedArray
end

function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct)
    return TrackedArray
end

function Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray})
    return TrackedArray
end

# A `Ref` has no broadcast indices; everything else uses `indices(A)`.
function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref)
    return ()
end

function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A)
    return indices(A)
end

# Broadcasts whose container type is `TrackedArray` run through
# `tracked_broadcast`.
function Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...)
    return tracked_broadcast(f, A, Bs...)
end
|