2018-08-25 06:51:40 +00:00
|
|
|
import Base: *
|
2018-06-20 13:11:31 +00:00
|
|
|
|
2018-07-12 20:08:53 +00:00
|
|
|
import LinearAlgebra
|
2018-09-19 12:08:30 +00:00
|
|
|
import LinearAlgebra: inv, \, /
|
|
|
|
|
2018-07-17 14:57:39 +00:00
|
|
|
using Statistics
|
2018-07-18 13:39:20 +00:00
|
|
|
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
2018-06-20 14:18:07 +00:00
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# An `AbstractArray{T,N}` that carries a tape node alongside its value.
#   tracker – the `Tracked{A}` node for this array
#   data    – the underlying, untracked value
#   grad    – gradient buffer; left undefined by the two-argument constructor
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    tracker::Tracked{A}
    data::A
    grad::A

    # Construct without a gradient buffer (`grad` stays uninitialised).
    function TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A}
        return new(t, data)
    end

    # Construct with an explicit gradient buffer.
    function TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A}
        return new(t, data, grad)
    end
end
|
|
|
|
|
2018-07-09 18:44:14 +00:00
|
|
|
# The raw (untracked) value wrapped by `x`.
function data(x::TrackedArray)
    return x.data
end

# The tape node attached to `x`.
function tracker(x::TrackedArray)
    return x.tracker
end
|
|
|
|
|
|
|
|
# Rank-specific aliases for tracked arrays (parameters: element type `T`,
# wrapped array type `A`).
const TrackedVector = TrackedArray{T,1,A} where {T,A}
const TrackedMatrix = TrackedArray{T,2,A} where {T,A}
const TrackedVecOrMat = Union{TrackedVector{T,A},TrackedMatrix{T,A}} where {T,A}
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Record array result `x` of call `c` on the tape.
function track(c::Call, x::AbstractArray)
    return TrackedArray(c, x)
end

# Outer constructors: the element type and rank are taken from `x`'s type.
function TrackedArray(c::Call, x::A) where A <: AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(Tracked{A}(c), x)
end

function TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(Tracked{A}(c, Δ), x, Δ)
end

# A leaf array: an empty `Call` and a zero-filled gradient buffer.
function TrackedArray(x::AbstractArray)
    return TrackedArray(Call(), x, zero(x))
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
2018-02-13 10:20:38 +00:00
|
|
|
# Indexing a tracked array of reals yields tracked reals, so advertise that as
# the element type.
function Base.eltype(::Type{<:TrackedArray{T}}) where T <: Real
    return TrackedReal{T}
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Print tracked array *types* compactly: elide T and N and show only the
# wrapped array type.
function Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
    return print(io, "TrackedArray{…,$A}")
end

# Summary line: "Tracked " followed by the wrapped value's own summary.
function Base.summary(io::IO, x::TrackedArray)
    print(io, "Tracked ")
    return summary(io, data(x))
end

# Element display delegates to the wrapped value.
function Base.print_array(io::IO, x::TrackedArray)
    return Base.print_array(io, data(x))
end
|
|
|
|
|
2018-10-08 19:49:17 +00:00
|
|
|
# `copy` hands back the very same tracked object.
function Base.copy(x::TrackedArray)
    return x
end

# Mutating a tracked array in place is not supported by the tape.
function Base.setindex!(xs::TrackedArray, v, i...)
    error("Can't differentiate `setindex!`")
end

# Backprop needs a scalar seed; arrays must be reduced or given an explicit Δ.
function back!(::TrackedArray)
    error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Fallthrough methods
|
|
|
|
|
2018-08-24 13:30:39 +00:00
|
|
|
# Shape and collection queries forward straight to the wrapped value.
for f in (:(Base.size), :(Base.ndims), :(Base.collect))
    @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end

# Multi-dimension `size(x, i, j, ...)` query.
function Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...)
    return size(data(x), i, j, is...)
end

# `similar` allocates an ordinary (untracked) array like the wrapped value.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
    return similar(data(x), dims...)
end

function Base.similar(x::TrackedArray, T::Type)
    return similar(data(x), T)
end
|
|
|
|
|
2018-08-25 06:51:40 +00:00
|
|
|
# Equality / approximate equality compare the underlying values, for any mix of
# tracked and untracked operands.
for op in (:(==), :≈)
    @eval begin
        Base.$op(x::TrackedArray, y::AbstractArray) = Base.$op(data(x), y)
        Base.$op(x::AbstractArray, y::TrackedArray) = Base.$op(x, data(y))
        Base.$op(x::TrackedArray, y::TrackedArray) = Base.$op(data(x), data(y))
    end
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
|
|
|
# Array Stdlib
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Indexing a tracked array records a `getindex` call on the tape.
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)

# Whether index `ix` can select each position at most once. Scalars, colons
# and logical masks never repeat a position; index arrays (ranges included)
# are checked with `allunique`.
_unique_index(ix) = true
_unique_index(ix::AbstractArray{Bool}) = true
_unique_index(ix::AbstractArray) = allunique(ix)

# Gradient of `xs[i...]`: scatter the sensitivity `Δ` into a zero array shaped
# like `xs`. If an index selects the same position more than once
# (e.g. `xs[[1, 1]]`), every occurrence must contribute. A single `setindex!`
# would keep only the last write, so in that case contributions are summed
# element by element.
@grad function getindex(xs::AbstractArray, i...)
    data(xs)[i...], function (Δ)
        Δ′ = zero(xs)
        if all(_unique_index, i)
            # No repeated positions: one assignment is exact.
            Δ′[i...] = data(Δ)
        else
            dv = view(Δ′, i...)
            # `eachindex(dv)` and iteration over `data(Δ)` both walk the
            # selected positions in column-major order.
            for (j, δ) in zip(eachindex(dv), data(Δ))
                dv[j] += δ
            end
        end
        (nobacksies(:getindex, Δ′), map(_->nothing, i)...)
    end
end
|
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Unary minus: negate the value forward, negate the sensitivity backward.
function Base.:-(xs::TrackedArray)
    return track(-, xs)
end

@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
2017-08-23 16:50:43 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Transpose and adjoint are tracked; each pullback applies the same operation
# to Δ and reshapes the result back to the input's shape.
for f in (:transpose, :adjoint)
    @eval Base.$f(xs::TrackedArray) = track($f, xs)
end

@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad adjoint(xs) = adjoint(data(xs)), Δ -> (reshape(adjoint(Δ), size(xs)),)
|
2017-09-03 21:10:23 +00:00
|
|
|
|
2018-08-20 13:11:56 +00:00
|
|
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)

# Gradient of `repeat`: every output element is a copy of one input element,
# so its sensitivity is accumulated into that source element.
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
    pullback = function (Δ)
        Δ′ = zero(xs)
        S = size(xs)
        for (dest, val) in pairs(IndexCartesian(), data(Δ))
            # Undo `inner` by integer division, then undo `outer` by wrapping
            # around the original size along each input dimension.
            src = ntuple(d -> mod1(div(dest[d] - 1, inner[d]) + 1, S[d]), length(S))
            Δ′[src...] += val
        end
        return (nobacksies(:repeat, Δ′),)
    end
    return repeat(data(xs), inner = inner, outer = outer), pullback
end
|
|
|
|
|
2018-05-02 06:37:30 +00:00
|
|
|
# Tracked `vcat`/`hcat`. Julia has no standardised promotion mechanism for
# concatenation yet (https://github.com/JuliaLang/julia/pull/20815), so this is
# a targeted set of methods chosen to avoid dispatch ambiguities.
for f in (:vcat, :hcat)
    CatArgs = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
    @eval begin
        # Rank 1–2 arguments with a TrackedArray anywhere among them. This relies
        # on Base having more specific methods that capture the untracked case
        # `(::Union{Vector,Matrix,Adjoint,Transpose}...)`.
        Base.$f(a::$CatArgs...) = track($f, a...)

        # Any rank, with the TrackedArray in first position.
        Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
        # Resolves the ambiguity between the two methods above.
        Base.$f(a::TrackedArray, b::$CatArgs...) = track($f, a, b...)

        # Any rank, with the TrackedArray in second position.
        Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
        # Resolves the ambiguity introduced by the previous method.
        Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
                c::$CatArgs...) = track($f, a, b, c...)
    end
end
|
|
|
|
|
2018-07-09 12:39:10 +00:00
|
|
|
# Gradient of `vcat`: slice Δ back into consecutive row blocks, one per input.
@grad function vcat(xs...)
    pullback = function (Δ)
        row = 0
        grads = [begin
            rest = Base.tail(map(_ -> :, size(x)))   # colons for dims 2:end
            g = Δ[row+1:row+size(x, 1), rest...]
            row += size(x, 1)
            g
        end for x in xs]
        return (grads...,)
    end
    return vcat(data.(xs)...), pullback
end
|
|
|
|
|
2018-07-09 12:39:10 +00:00
|
|
|
# Gradient of `hcat`: slice Δ back into consecutive column blocks, one per
# input. Vector inputs occupy a single column.
@grad function hcat(xs...)
    pullback = function (Δ)
        col = 0
        grads = [begin
            g = if ndims(x) == 1
                Δ[:, col+1]
            else
                rest = Base.tail(Base.tail(map(_ -> :, size(x))))   # colons for dims 3:end
                Δ[:, col+1:col+size(x, 2), rest...]
            end
            col += size(x, 2)
            g
        end for x in xs]
        return (grads...,)
    end
    return hcat(data.(xs)...), pullback
end
|
|
|
|
|
2018-08-03 14:14:10 +00:00
|
|
|
# Track `cat` whenever a TrackedArray is the first or second argument.
function Base.cat(a::TrackedArray; dims)
    return track(cat, a; dims = dims)
end

function Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims)
    return track(cat, a, b, c...; dims = dims)
end

function Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims)
    return track(cat, a, b, c...; dims = dims)
end

function Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims)
    return track(cat, a, b, c...; dims = dims)
end
|
2018-05-02 06:37:30 +00:00
|
|
|
|
2018-07-12 19:42:32 +00:00
|
|
|
# Gradient of `cat`: walk the inputs in order, carving each one's region out of
# Δ. Along concatenated dims the region advances by the input's extent; every
# other dim is taken whole.
@grad function cat(Xs...; dims)
    pullback = function (Δ)
        nd = ndims(Δ)
        offset = ntuple(i -> 0, Val(nd))
        grads = [begin
            own_dims = 1:ndims(x)
            # Extent along each output dim: 0 means "not concatenated here".
            extent = ntuple(i -> i in dims ? (i in own_dims ? size(x, i) : 1) : 0, Val(nd))
            region = ntuple(i -> extent[i] > 0 ? (offset[i]+1:offset[i]+extent[i]) : Colon(), Val(nd))
            g = reshape(Δ[region...], size(x))
            offset = offset .+ extent
            g
        end for x in Xs]
        return (grads...,)
    end
    return cat(data.(Xs)..., dims = dims), pullback
end
|
|
|
|
|
2018-04-02 20:09:57 +00:00
|
|
|
# Tracked reshape. Dimensions use the platform integer `Int` (the element type
# of `Base.Dims`) rather than a hard-coded `Int64`, so these methods are also
# selected on 32-bit systems. On 64-bit systems `Int === Int64`, so nothing
# changes there.
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int}...) = reshape(xs, dims)

# Resolve any `:` placeholder into a concrete size first.
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
    reshape(xs, Base._reshape_uncolon(xs, dims))

Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int}}) = track(reshape, xs, dims)

# The pullback reshapes Δ back to the input's shape; `dims` gets no gradient.
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
2017-12-15 16:18:16 +00:00
|
|
|
|
2018-02-28 02:19:58 +00:00
|
|
|
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)

# The inverse permutation applied to Δ undoes the forward permutation.
@grad function permutedims(xs, dims)
    pullback(Δ) = (permutedims(Δ, invperm(dims)), nothing)
    return permutedims(data(xs), dims), pullback
end
|
2018-02-08 19:27:57 +00:00
|
|
|
|
2018-02-16 14:15:40 +00:00
|
|
|
# Kronecker product written only with `reshape` and broadcasting, so tracked
# operands are differentiated by the existing broadcast/reshape rules.
# The 4-d product has entry [a,b,c,d] = mat1[b,d] * mat2[a,c]; folding it
# column-major gives the (m1*m2) × (n1*n2) Kronecker layout.
function _kron(mat1::AbstractMatrix, mat2::AbstractMatrix)
    (m1, n1), (m2, n2) = size(mat1), size(mat2)
    blocks = reshape(mat1, (1, m1, 1, n1)) .* reshape(mat2, (m2, 1, n2, 1))
    return reshape(blocks, (m1 * m2, n1 * n2))
end
|
|
|
|
|
2018-02-16 14:15:40 +00:00
|
|
|
# Route `kron` through the differentiable `_kron` whenever at least one
# operand is tracked.
for (X, Y) in ((:TrackedMatrix, :TrackedMatrix),
               (:TrackedMatrix, :AbstractMatrix),
               (:AbstractMatrix, :TrackedMatrix))
    @eval Base.kron(a::$X, b::$Y) = _kron(a, b)
end
|
|
|
|
|
2018-09-19 12:08:30 +00:00
|
|
|
|
|
|
|
inv(A::TrackedArray) = Tracker.track(inv, A)

# For B = A⁻¹: dB = -A⁻¹ dA A⁻¹, hence ∇A = -A⁻ᵀ Δ A⁻ᵀ. The inverse is
# recomputed from the tracked `A` inside the pullback.
@grad function inv(A)
    pullback = function (Δ)
        Ainvᵀ = inv(A)'
        return (-Ainvᵀ * Δ * Ainvᵀ,)
    end
    return inv(Tracker.data(A)), pullback
end
|
|
|
|
|
|
|
|
# (/) right division, tracked when either operand is tracked.
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)

# C = A B⁻¹  ⇒  ∇A = Δ B⁻ᵀ,  ∇B = -B⁻ᵀ Aᵀ Δ B⁻ᵀ.
@grad function (A / B)
    pullback = function (Δ)
        Binvᵀ = inv(B)'
        return (Δ * Binvᵀ, -Binvᵀ * A' * Δ * Binvᵀ)
    end
    return Tracker.data(A) / Tracker.data(B), pullback
end
|
|
|
|
|
|
|
|
# (\) left division, tracked when either operand is tracked. (Left division by
# a vector still needs more work to resolve dispatch ambiguities.)
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)

# C = A⁻¹ B  ⇒  ∇A = -A⁻ᵀ Δ Bᵀ A⁻ᵀ,  ∇B = A⁻ᵀ Δ.
@grad function (A \ B)
    pullback = function (Δ)
        Ainvᵀ = inv(A)'
        return (-Ainvᵀ * Δ * B' * Ainvᵀ, Ainvᵀ * Δ)
    end
    return Tracker.data(A) \ Tracker.data(B), pullback
end
|
|
|
|
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# Reductions
|
|
|
|
|
2018-07-19 07:44:15 +00:00
|
|
|
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)

# `sum(f, xs)` is expressed through a tracked broadcast.
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))

# Every input contributes with weight 1, so Δ is broadcast back over the
# input's shape.
@grad function sum(xs; dims = :)
    pullback(Δ) = (zero(xs) .+ Δ, )
    return sum(data(xs), dims = dims), pullback
end
|
2017-08-22 11:24:08 +00:00
|
|
|
|
2018-03-06 10:01:19 +00:00
|
|
|
# Tracked products. The positional `dim` gives the reduction dimension(s).
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))

# Whole-array product: ∂prod/∂xᵢ = prod(x) / xᵢ.
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)

# `_prod_others(x, dims)`: for every element of `x`, the product of all *other*
# elements in the slice that `prod(x, dims = dims)` collapses it into.
# It is built from cyclic shifts rather than division, so zeros in `x` do not
# produce NaN/Inf.
_prod_others(x::AbstractArray, ::Tuple{}) = one.(x)

function _prod_others(x::AbstractArray, dim::Integer)
    n = size(x, dim)
    # A singleton (or trailing) dim has no other elements in the slice.
    n == 1 && return one.(x)
    # Shifting by k = 1, …, n-1 along `dim` lines each position up with every
    # other element of its slice exactly once.
    shifted = (circshift(x, ntuple(d -> d == dim ? k : 0, ndims(x))) for k in 1:n-1)
    return broadcast(*, shifted...)
end

# Several dims: the others along the first dim, times the others (of the
# already-reduced slices) along the remaining dims.
_prod_others(x::AbstractArray, dims::Tuple) =
    _prod_others(x, first(dims)) .* _prod_others(prod(x, dims = first(dims)), Base.tail(dims))

# Ranges / vectors of dims.
_prod_others(x::AbstractArray, dims) = _prod_others(x, Tuple(dims))

# Per-slice product. The previous rule multiplied every other element of the
# *whole* array, ignoring `dim`, and was mislabelled `:sum`.
@grad prod(xs, dim) = prod(data(xs), dims = dim),
    Δ -> (nobacksies(:prod, _prod_others(data(xs), dim) .* Δ), nothing)
|
2018-03-06 10:01:19 +00:00
|
|
|
|
2017-09-07 01:21:35 +00:00
|
|
|
# Searching needs no tape: look in the raw value.
Base.findfirst(xs::TrackedArray, args...) = findfirst(data(xs), args...)
|
2017-09-02 03:33:05 +00:00
|
|
|
|
2018-07-19 08:58:43 +00:00
|
|
|
# Tracked reductions that accept a `dims` keyword (whole array by default).
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)

for f in (:maximum, :minimum)
    @eval Base.$f(xs::TrackedArray; dims = :) = track($f, xs, dims = dims)
end
|
2018-04-27 21:14:01 +00:00
|
|
|
|
2018-06-20 14:18:07 +00:00
|
|
|
import LinearAlgebra: dot
|
|
|
|
|
|
|
|
# Track `dot` when either vector is tracked.
for (X, Y) in ((:TrackedVector, :TrackedVector),
               (:AbstractVector, :TrackedVector),
               (:TrackedVector, :AbstractVector))
    @eval dot(xs::$X, ys::$Y) = track(dot, xs, ys)
end

# ∂(x·y)/∂x = y and ∂(x·y)/∂y = x, each scaled by the scalar Δ.
@grad function dot(xs, ys)
    pullback(Δ) = (Δ .* ys, Δ .* xs)
    return dot(data(xs), data(ys)), pullback
end
|
2017-12-12 17:23:15 +00:00
|
|
|
|
2017-11-21 16:04:04 +00:00
|
|
|
# `std` for tracked arrays, assembled from tracked primitives (sum, broadcast)
# so no dedicated gradient rule is needed. Uses the n-1 (corrected) divisor.
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)

function _std(x::TrackedArray, mean, dims)
    n = mapreduce(i -> size(x, i), *, dims)   # number of elements per reduced slice
    return sqrt.(sum((x .- mean).^2, dims = dims) ./ (n - 1))
end

function _std(x::TrackedArray, mean, ::Colon)
    return sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
end
|
2017-11-21 16:04:04 +00:00
|
|
|
|
2018-07-18 18:20:00 +00:00
|
|
|
# p-norm of a tracked array. `eps(0f0)` is added to every term so the
# derivative of the outer root stays finite when `x` is all zeros
# (avoids d(sqrt(x))/dx == Inf at 0).
function LinearAlgebra.norm(x::TrackedArray, p::Real = 2)
    return sum(abs.(x).^p .+ eps(0f0))^(1/p)
end
|
2018-02-09 19:00:26 +00:00
|
|
|
|
2018-07-19 08:58:43 +00:00
|
|
|
# Gradient of `mean`: Δ is spread evenly over the elements of each reduced
# slice.
@grad function mean(xs; dims = :)
    pullback(Δ) = (_backmean(xs, Δ, dims),)
    return mean(data(xs), dims = dims), pullback
end

# Whole-array mean: every element gets Δ / length.
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)

# Mean over `dims`: divide by the number of elements in each reduced slice.
function _backmean(xs, Δ, dims)
    n = mapreduce(i -> size(data(xs), i), *, dims)
    return zero(xs) .+ Δ ./ n
end
|
2017-10-30 08:21:02 +00:00
|
|
|
|
2018-07-19 08:58:43 +00:00
|
|
|
# Gradient of `maximum`: the sensitivity flows only to the position returned by
# `findmax` (one per reduced slice when `dims` is given); every other entry
# gets zero. The default was previously the self-referential `dims = dims`,
# which raises UndefVarError when `dims` is omitted; it now defaults to the
# whole array.
@grad function maximum(xs; dims = :)
    maximum(data(xs), dims = dims), function (Δ)
        Δ′ = zero(xs)
        _, i = findmax(data(xs), dims = dims)
        Δ′[i] = data(Δ)
        return (nobacksies(:maximum, Δ′),)
    end
end
|
2018-07-18 18:20:00 +00:00
|
|
|
|
2018-07-19 08:58:43 +00:00
|
|
|
# Gradient of `minimum`: the sensitivity flows only to the position returned by
# `findmin` (one per reduced slice when `dims` is given); every other entry
# gets zero. The default was previously the self-referential `dims = dims`,
# which raises UndefVarError when `dims` is omitted; it now defaults to the
# whole array.
@grad function minimum(xs; dims = :)
    minimum(data(xs), dims = dims), function (Δ)
        Δ′ = zero(xs)
        _, i = findmin(data(xs), dims = dims)
        Δ′[i] = data(Δ)
        return (nobacksies(:minimum, Δ′),)
    end
end
|
|
|
|
|
2017-08-22 11:24:08 +00:00
|
|
|
# BLAS
|
|
|
|
|
2018-06-20 14:18:07 +00:00
|
|
|
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)

# Forward pass uses the `0 => v` pair form. Every Julia ≥ 0.7 supports it,
# unlike the single-vector `diagm(v)`, which is absent on some 1.x releases.
# Only the diagonal of the output depends on `x`, so the gradient is diag(Δ).
@grad diagm(x) = diagm(0 => data(x)), Δ -> (diag(Δ),)
|
2018-02-05 18:29:35 +00:00
|
|
|
|
2018-06-20 13:11:31 +00:00
|
|
|
# Matrix–matrix, matrix–vector and vector–vector products in which at least
# one operand is tracked.
for (X, Y) in ((:TrackedMatrix, :AbstractMatrix), (:AbstractMatrix, :TrackedMatrix),
               (:TrackedMatrix, :TrackedMatrix),
               (:TrackedMatrix, :AbstractVector), (:AbstractMatrix, :TrackedVector),
               (:TrackedMatrix, :TrackedVector),
               (:TrackedVector, :AbstractVector), (:AbstractVector, :TrackedVector),
               (:TrackedVector, :TrackedVector))
    @eval x::$X * y::$Y = track(*, x, y)
end

# C = A B  ⇒  ∇A = Δ Bᵀ,  ∇B = Aᵀ Δ.
@grad a::AbstractMatrix * b::AbstractVecOrMat =
    data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
2017-11-07 19:34:27 +00:00
|
|
|
|
2017-08-23 01:03:17 +00:00
|
|
|
# NNlib
|
|
|
|
|
2017-12-14 18:48:38 +00:00
|
|
|
using NNlib
|
2018-05-30 10:23:57 +00:00
|
|
|
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, maxpool, meanpool
|
2017-08-23 01:03:17 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Tracked (log)softmax; the pullbacks use NNlib's analytic gradients on
# untracked values.
softmax(xs::TrackedArray) = track(softmax, xs)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)

@grad function softmax(xs)
    pullback(Δ) = (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
    return softmax(data(xs)), pullback
end

@grad function logsoftmax(xs)
    pullback(Δ) = (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
    return logsoftmax(data(xs)), pullback
end
|
2018-05-21 19:20:43 +00:00
|
|
|
|
2018-08-07 19:50:37 +00:00
|
|
|
# Track `depthwiseconv` when the input and/or the filter is tracked.
for (X, W) in ((:TrackedArray, :TrackedArray),
               (:AbstractArray, :TrackedArray),
               (:TrackedArray, :AbstractArray))
    @eval depthwiseconv(x::$X, w::$W; kw...) = track(depthwiseconv, x, w; kw...)
end

# Data and filter gradients come from NNlib, evaluated on untracked values.
@grad function depthwiseconv(x, w; kw...)
    pullback = function (Δ)
        raw = data.((Δ, x, w))
        return nobacksies(:depthwiseconv,
                          (NNlib.∇depthwiseconv_data(raw...; kw...),
                           NNlib.∇depthwiseconv_filter(raw...; kw...)))
    end
    return depthwiseconv(data(x), data(w); kw...), pullback
end
|
2018-05-30 10:23:57 +00:00
|
|
|
|
2018-08-03 14:14:10 +00:00
|
|
|
# Track `conv` when the input and/or the filter is tracked.
for (X, W) in ((:TrackedArray, :TrackedArray),
               (:AbstractArray, :TrackedArray),
               (:TrackedArray, :AbstractArray))
    @eval conv(x::$X, w::$W; kw...) = track(conv, x, w; kw...)
end

# Data and filter gradients come from NNlib, evaluated on untracked values.
@grad function conv(x, w; kw...)
    pullback = function (Δ)
        raw = data.((Δ, x, w))
        return nobacksies(:conv,
                          (NNlib.∇conv_data(raw...; kw...),
                           NNlib.∇conv_filter(raw...; kw...)))
    end
    return conv(data(x), data(w); kw...), pullback
end
|
2017-12-15 02:29:14 +00:00
|
|
|
|
2018-08-03 14:14:10 +00:00
|
|
|
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

# NNlib's `∇maxpool` needs the pooled output `y`, so it is kept for the pullback.
# The window size `k` gets no gradient.
@grad function maxpool(x, k; kw...)
    y = maxpool(data(x), k; kw...)
    pullback(Δ) = (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
    return y, pullback
end
|
2018-02-26 22:43:07 +00:00
|
|
|
|
2018-08-03 14:14:10 +00:00
|
|
|
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)

# NNlib's `∇meanpool` needs the pooled output `y`, so it is kept for the
# pullback. The window size `k` gets no gradient. `nobacksies` is now tagged
# with this op's own name; it previously said `:maxpool` (copy-paste slip).
@grad function meanpool(x, k; kw...)
    y = meanpool(data(x), k; kw...)
    y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
|
2017-12-15 02:29:14 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Broadcasting
|
|
|
|
|
2018-07-09 12:39:10 +00:00
|
|
|
using ForwardDiff: Dual, partials, value
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2018-08-24 13:07:08 +00:00
|
|
|
# Reshape Δ to the rank of `x`, keeping Δ's sizes along x's dimensions.
function trim(x, Δ)
    keep = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, keep)
end

# Reduce a broadcast sensitivity Δ back to the shape of the argument `x`:
# sum over every dimension along which `x` was expanded (size 1 in `x`).
function unbroadcast(x::AbstractArray, Δ)
    if size(x) == size(Δ)
        return Δ
    elseif length(x) == length(Δ)
        return trim(x, Δ)
    else
        # Dims where x has size 1 are summed; the rest map to a dummy dim
        # beyond ndims(Δ), which `sum` leaves untouched.
        summed = ntuple(d -> size(x, d) == 1 ? d : ndims(Δ) + 1, Val(ndims(Δ)))
        return trim(x, sum(Δ, dims = summed))
    end
end

# A scalar argument receives the total sensitivity.
unbroadcast(x::Number, Δ) = sum(Δ)

# `Ref`-wrapped arguments are treated as constants.
unbroadcast(x::Base.RefValue, _) = nothing
|
2017-08-27 08:49:42 +00:00
|
|
|
|
2018-08-24 13:07:08 +00:00
|
|
|
# Lift real arguments to ForwardDiff duals with partial `p`; leave anything
# else unchanged.
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)

# Δ times the derivative of `f(args...)` with respect to argument `i`. Only
# argument `i` is seeded (partial `true`); the rest get `false`.
function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N}
    seeded = ntuple(j -> dual(args[j], i == j), Val(N))
    out = f(seeded...)
    return Δ * out.partials[1]
end
|
|
|
|
|
2018-08-24 13:07:08 +00:00
|
|
|
# Broadcast `f` over the raw values of `args`. For a real (non-Bool) result,
# record a tape node whose pullback obtains each argument's sensitivity by
# forward-mode differentiation (`partial`), then reduces it to that argument's
# shape (`unbroadcast`).
@inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N}
    y = broadcast(f, data.(args)...)
    # Non-real or Bool results are returned as-is, without a tape node.
    (eltype(y) <: Real && eltype(y) != Bool) || return y
    function back(Δ)
        Δargs = ntuple(k -> partial.(f, Δ, k, args...), Val(N))
        return map(unbroadcast, args, Δargs)
    end
    # Attach `y` to the tape, with the trackers of `args` as its parents.
    return track(Call(back, tracker.(args)), y)
end
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2018-08-20 13:41:46 +00:00
|
|
|
using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
|
2018-07-12 18:28:30 +00:00
|
|
|
|
|
|
|
# Broadcast style marking any broadcast expression that involves a tracked value.
struct TrackedStyle <: BroadcastStyle end

# Tracked arrays and tracked reals opt into TrackedStyle ...
Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()

# ... and TrackedStyle takes precedence when combined with any other style.
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
|
|
|
|
|
2018-08-20 13:41:46 +00:00
|
|
|
# Rebuild a (possibly nested) broadcast expression over untracked values so
# the arguments select their own array style. This is needed primarily for
# CuArrays' broadcasting fixes.
broadcast_rebuild(xs) = data(xs)

function broadcast_rebuild(bc::Broadcasted)
    untracked = map(broadcast_rebuild, bc.args)
    return broadcasted(bc.f, untracked...)
end

preprocess(x) = x

# Materialise a tracked broadcast: take the fused kernel from the untracked
# rebuild, and the (tracked) leaf arguments from the flattened original.
function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
    tracked_args = Broadcast.flatten(bc).args
    fused_f = Broadcast.flatten(broadcast_rebuild(bc)).f
    return ∇broadcast(fused_f, tracked_args...)
end
|
2018-08-20 12:08:04 +00:00
|
|
|
|
|
|
|
using Requires
|
|
|
|
|
|
|
|
# https://github.com/FluxML/Flux.jl/issues/353
# Redefine Base.Broadcast's `flatten` and one `make_makeargs` method at load
# time (skipped while precompiling). Every value captured by the generated
# closures is bound once via `let`, so no captured variable is ever reassigned.
@init Requires.isprecompiling() || @eval Base.Broadcast begin
    function flatten(bc::Broadcasted{Style}) where {Style}
        isflat(bc) && return bc
        flatargs = cat_nested(bc)
        let makeargs = make_makeargs(bc), f = bc.f
            flatf = @inline function(xs::Vararg{Any,N}) where N
                f(makeargs(xs...)...)
            end
            return Broadcasted{Style}(flatf, flatargs, bc.axes)
        end
    end

    @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
        bc = t[1]
        let rest = make_makeargs(makeargs, tail(t)), f = bc.f
            let inner = make_makeargs(rest, bc.args)
                headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
                return @inline function(xs::Vararg{Any,N}) where N
                    expanded = inner(xs...)
                    head, others = headargs(expanded...), tailargs(expanded...)
                    (f(head...), others...)
                end
            end
        end
    end
end
|