2017-08-19 09:14:50 +00:00
|
|
|
import Base: *
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# Coerce a forward-pass result into array form, using `xs` as the template.
# A result that is already an array is passed through untouched.
toarray(xs::AbstractArray, ys::AbstractArray) = ys

# A non-array (e.g. scalar) result `y` is stored in a zero-dimensional array
# created with `similar(xs, typeof(y), ())`, so it matches the container kind
# of `xs`. Returns the filled zero-dimensional array.
function toarray(xs::AbstractArray, y)
  boxed = similar(xs, typeof(y), ())
  boxed .= y
  return boxed
end
|
2017-08-22 11:24:08 +00:00
|
|
|
|
2017-09-03 21:10:23 +00:00
|
|
|
# Inverse of `toarray` for scalars: a zero-dimensional array is unwrapped to
# its single element; every other value (including arrays of rank >= 1) is
# returned as-is.
unarray(xs) = xs

function unarray(xs::AbstractArray{<:Any,0})
  return xs[]
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# Indexing a tracked array records a `getindex` call node. The indexed value
# is computed eagerly from the raw data. If that value is a scalar, `toarray`
# stores it as a zero-dimensional array.
function Base.getindex(xs::TrackedArray, i...)
  node = Call(getindex, xs, i...)
  value = toarray(xs.data, xs.data[i...])
  return TrackedArray(node, value)
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for `getindex`: scatter the incoming gradient `Δ` into a zero
# array shaped like `xs.data` at the positions selected by `i...`, then
# propagate that array to `xs`.
#
# Gradients are *accumulated* into the selected positions rather than
# assigned. An index that appears more than once (e.g. `xs[[1, 1, 2]]`) must
# receive the sum of all its contributions; a plain `Δ′[i...] = Δ` would keep
# only the last one.
#
# Throws `DimensionMismatch` if the gradient's length does not match the
# number of indexed elements.
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
  Δ′ = zeros(xs.data)
  # `view` selects the same elements (and yields the same shape) as `getindex`.
  dest = view(Δ′, i...)
  # `unarray` turns a zero-dimensional gradient (scalar indexing) into a plain
  # number. Numbers iterate as a single element, so the loop below handles
  # scalar and array indexing alike.
  g = unarray(Δ)
  length(dest) == length(g) ||
    throw(DimensionMismatch("gradient has $(length(g)) elements but $(length(dest)) were indexed"))
  # Both sides iterate in column-major order, so element k of `g` lands on
  # element k of the selection. Writes through a repeated index hit the same
  # parent element and therefore add up.
  for (j, δ) in zip(eachindex(dest), g)
    dest[j] += δ
  end
  @back(xs, Δ′)
end
|
|
|
|
|
|
|
|
# Unary minus on a tracked array records a `-` call node.
function Base.:-(xs::TrackedArray)
  return TrackedArray(Call(-, xs))
end

# Backward pass for unary minus: the incoming gradient is negated and passed
# straight to `xs`.
function back(::typeof(-), Δ, xs::TrackedArray)
  negated = -Δ
  return back(xs, negated)
end
|
2017-08-23 16:50:43 +00:00
|
|
|
|
2017-09-01 15:42:18 +00:00
|
|
|
# Plain and conjugate transposition of tracked arrays; each records a call node.
function Base.transpose(xs::TrackedArray)
  return TrackedArray(Call(transpose, xs))
end

function Base.ctranspose(xs::TrackedArray)
  return TrackedArray(Call(ctranspose, xs))
end

# Backward passes: apply the same transposition to the gradient, then `trim`
# it to `ndims(xs)` dimensions. For example, a vector input yields a transposed
# gradient with an extra singleton dimension, which `trim` removes.
function back(::typeof(transpose), Δ, xs)
  @back(xs, trim(xs, Δ.'))
end

function back(::typeof(ctranspose), Δ, xs)
  @back(xs, trim(xs, Δ'))
end
|
2017-09-03 21:10:23 +00:00
|
|
|
|
2017-09-05 06:12:53 +00:00
|
|
|
# `repmat` on a tracked vector or matrix records a call node with the repeat
# counts as extra arguments. No backward rule for `repmat` appears in this
# section.
#
# NOTE(review): the `Int64...` method looks redundant next to the
# `Integer...` one. It presumably exists to resolve a dispatch ambiguity with
# Base's own `repmat(::AbstractVecOrMat, ::Int, ...)` methods; confirm before
# removing it.
function Base.repmat(x::TrackedVecOrMat, a::Integer...)
  return TrackedArray(Call(repmat, x, a...))
end

function Base.repmat(x::TrackedVecOrMat, a::Int64...)
  return TrackedArray(Call(repmat, x, a...))
end
|
|
|
|
|
2017-09-03 06:12:44 +00:00
|
|
|
# Two-argument vertical concatenation where at least one side is tracked
# records a `vcat` call node. For each (tracked type, plain type) pair below,
# this generates the (tracked, tracked), (tracked, plain) and (plain, tracked)
# methods. Pairs are defined in the order vector, vector-or-matrix, matrix.
#
# NOTE(review): the three levels of specificity presumably exist so these
# methods win dispatch over Base's own `vcat` signatures without ambiguity;
# confirm before collapsing them.
for (Tracked, Plain) in ((:TrackedVector, :AbstractVector),
                         (:TrackedVecOrMat, :AbstractVecOrMat),
                         (:TrackedMatrix, :AbstractMatrix))
  @eval begin
    Base.vcat(a::$Tracked, b::$Tracked) = TrackedArray(Call(vcat, a, b))
    Base.vcat(a::$Tracked, b::$Plain) = TrackedArray(Call(vcat, a, b))
    Base.vcat(a::$Plain, b::$Tracked) = TrackedArray(Call(vcat, a, b))
  end
end
|
2017-09-05 06:11:28 +00:00
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for two-argument `vcat`: split the gradient along the first
# dimension. The first `size(xs, 1)` rows go to `xs` and the remaining rows go
# to `ys`; all trailing dimensions are taken whole.
function back(::typeof(vcat), Δ, xs, ys)
  # One `:` for every dimension of Δ after the first.
  trailing = Base.tail(map(_ -> :, size(Δ)))
  split = size(xs, 1)
  @back(xs, Δ[1:split, trailing...])
  @back(ys, Δ[split+1:end, trailing...])
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# Reductions
|
|
|
|
|
2017-08-22 14:12:12 +00:00
|
|
|
# Sum along dimension(s) `dim`: records a call node only.
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))

# Full reduction: the total is computed eagerly and stored as a
# zero-dimensional array (via `toarray`) next to the call node.
function Base.sum(xs::TrackedArray)
  node = Call(sum, xs)
  return TrackedArray(node, toarray(xs.data, sum(xs.data)))
end

# Summing a `TrackedScalar` returns it unchanged; any `dim` arguments are ignored.
Base.sum(xs::TrackedScalar, dim...) = xs

# Backward pass for both `sum` forms. Every input element enters the sum with
# weight one, so the (possibly reduced) gradient `Δ` is broadcast back over
# the full shape of `xs.data`.
function back(::typeof(sum), Δ, xs::TrackedArray, dim...)
  spread = similar(xs.data)
  spread .= Δ
  return back(xs, spread)
end
|
2017-08-22 11:24:08 +00:00
|
|
|
|
2017-09-07 01:21:35 +00:00
|
|
|
# `maximum` and `findfirst` are not tracked: each forwards all arguments to the
# same function applied to the raw data and returns a plain (untracked) result.
for f in (:maximum, :findfirst)
  @eval Base.$f(xs::TrackedArray, args...) = $f(xs.data, args...)
end
|
2017-09-02 03:33:05 +00:00
|
|
|
|
2017-10-31 10:41:44 +00:00
|
|
|
# Full mean: computed eagerly and stored as a zero-dimensional array (via
# `toarray`) next to the call node.
function Base.mean(xs::TrackedArray)
  node = Call(mean, xs)
  return TrackedArray(node, toarray(xs.data, mean(xs.data)))
end

# Mean over `region`: records a call node only.
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))

# Backward pass for the full mean: every element gets Δ / length(xs.data).
function back(::typeof(mean), Δ, xs::TrackedArray)
  count = length(xs.data)
  return back(xs, similar(xs.data) .= Δ ./ count)
end

# Backward pass for the mean over `region`: Δ is broadcast back over the full
# shape and divided by the number of elements averaged together, which is the
# product of the sizes of the reduced dimensions.
function back(::typeof(mean), Δ, xs::TrackedArray, region)
  count = prod(size(xs.data, region...))
  return back(xs, similar(xs.data) .= Δ ./ count)
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# BLAS
|
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Matrix-matrix and matrix-vector products where at least one operand is
# tracked record a `*` call node. Methods are generated for each (left, right)
# type pair, in the same order as the explicit definitions they replace.
for (L, R) in ((:TrackedMatrix, :TrackedMatrix),
               (:TrackedMatrix, :AbstractMatrix),
               (:AbstractMatrix, :TrackedMatrix),
               (:TrackedMatrix, :TrackedVector),
               (:TrackedMatrix, :AbstractVector),
               (:AbstractMatrix, :TrackedVector))
  @eval Base.:*(a::$L, b::$R) = TrackedArray(Call(*, a, b))
end
|
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for `a * b` with matrix `a` and vector-or-matrix `b`:
#   gradient for a = Δ * transpose(b)   (A_mul_Bt)
#   gradient for b = transpose(a) * Δ   (At_mul_B)
# Both products use the raw (untracked) values obtained via `data`.
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
  @back(a, A_mul_Bt(Δ, data(b)))
  @back(b, At_mul_B(data(a), Δ))
end
|
|
|
|
|
2017-08-23 01:03:17 +00:00
|
|
|
# NNlib
|
|
|
|
|
|
|
|
import NNlib: softmax, ∇softmax
|
|
|
|
|
|
|
|
# Tracked `softmax` (extends NNlib's `softmax`) records a call node.
function softmax(xs::TrackedArray)
  return TrackedArray(Call(softmax, xs))
end

# Backward pass: delegate to NNlib's `∇softmax`, which receives the incoming
# gradient and the raw input values.
function back(::typeof(softmax), Δ, xs)
  @back(xs, ∇softmax(Δ, data(xs)))
end
|
2017-08-23 01:03:17 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Broadcasting
|
|
|
|
|
|
|
|
using ForwardDiff: Dual, partials
|
|
|
|
|
|
|
|
# Holds the result of a broadcast evaluated over ForwardDiff `Dual` numbers
# (see `tracked_broadcast`). It is used as the function of the recorded call
# node, and `back(::Broadcasted, ...)` dispatches on it.
struct Broadcasted{T}
  data::T
end

# Calling a `Broadcasted` ignores its arguments and returns the `.value` field
# of every element of `data`, i.e. the primal (non-derivative) results.
function (b::Broadcasted)(xs...)
  return map(d -> d.value, b.data)
end
|
|
|
|
|
|
|
|
# Prepare a broadcast argument for forward-mode differentiation.
# Arguments that are not tracked arrays are passed through unchanged.
dualify(xs, ps) = xs

# For a tracked array, every raw element becomes a ForwardDiff `Dual` carrying
# the partials seed `ps`.
function dualify(xs::TrackedArray, ps)
  return map(v -> Dual(v, ps), data(xs))
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
|
|
|
# Broadcast `f` over `args` while recording a node for reverse-mode backprop.
#
# Argument k is seeded with an N-tuple of Bools that is `true` only in slot k,
# so partial k of each broadcast output element corresponds to argument k.
# Only tracked arrays are actually converted to `Dual`s (see `dualify`). The
# dual-valued broadcast result is stored in a `Broadcasted`, which becomes the
# node's function; calling it with no arguments yields the primal values.
function tracked_broadcast(f, args::Vararg{Any,N}) where N
  positions = ntuple(identity, Val{N})
  dargs = map((arg, k) -> dualify(arg, ntuple(j -> k == j, Val{N})), args, positions)
  # Binding the `Broadcasted` to a local, instead of writing
  # `TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))`,
  # works around a Julia 0.6 type-inference issue.
  result = Broadcasted(broadcast(f, dargs...))
  TrackedArray(Call(result, args...), result())
end
|
|
|
|
|
2017-08-27 08:49:42 +00:00
|
|
|
# Reshape `Δ` to have exactly `ndims(x)` dimensions, keeping Δ's leading sizes
# (`size(Δ, d)` is 1 for d beyond ndims(Δ)). Any trailing dimensions that are
# dropped must be singletons, otherwise `reshape` throws.
function trim(x, Δ)
  shape = ntuple(d -> size(Δ, d), Val{ndims(x)})
  return reshape(Δ, shape)
end
|
|
|
|
|
2017-08-22 23:25:19 +00:00
|
|
|
# Reduce a broadcast gradient `Δ` back to the shape of the argument `x`.
# If the shapes already agree, Δ is returned unchanged. Otherwise Δ is summed
# over every dimension where `x` has size 1 (the dimensions along which `x`
# was expanded), and the result is trimmed to `ndims(x)` dimensions.
function unbroadcast(x, Δ)
  size(x) == size(Δ) && return Δ
  expanded = filter(n -> size(x, n) == 1, 1:ndims(Δ))
  return trim(x, sum(Δ, expanded))
end
|
2017-08-22 23:25:19 +00:00
|
|
|
|
2017-08-28 00:40:59 +00:00
|
|
|
# Scale the `i`-th partial of dual number `x` by the incoming gradient `Δ`.
# Indexing is done with `@inbounds`; callers pass an `i` within the seeded
# partials range.
function getpartial(Δ, x, i)
  @inbounds ∂ = partials(x)[i]
  Δ * ∂
end
|
|
|
|
|
2017-09-07 03:09:32 +00:00
|
|
|
# Backward pass for a tracked broadcast. For each of the N arguments, multiply
# the incoming gradient element-wise by that argument's partial in the stored
# dual results. Then reduce the product to the argument's shape with
# `unbroadcast` and propagate it.
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
  grads = ntuple(k -> getpartial.(Δ, b.data, k), Val{N})
  foreach((arg, g) -> @back(arg, unbroadcast(arg, g)), args, grads)
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
|
|
|
# Julia 0.6 broadcast hooks: treat TrackedArray as its own container type, let
# it win promotion against Array and any other container type, and route the
# resulting broadcast through `tracked_broadcast`.
function Base.Broadcast._containertype(::Type{<:TrackedArray})
  return TrackedArray
end

# Promotion: TrackedArray wins against itself, Array, and any other container.
function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray})
  return TrackedArray
end
function Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray})
  return TrackedArray
end
function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array})
  return TrackedArray
end
function Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct)
  return TrackedArray
end
function Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray})
  return TrackedArray
end

# Indices used for shape computation: `Ref` arguments count as scalars.
function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref)
  return ()
end
function Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A)
  return indices(A)
end

# Entry point: any broadcast whose promoted container is TrackedArray.
function Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...)
  return tracked_broadcast(f, A, Bs...)
end
|