2018-02-07 17:43:25 +00:00
"""
    TrackedArray{T,N,A}

An `AbstractArray{T,N}` that stores a `Tracked{A}` record together with the
wrapped `data::A` and, optionally, a `grad::A` buffer.

The two inner constructors differ only in whether `grad` is supplied; when it is
omitted the `grad` field is left undefined.
"""
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    tracker::Tracked{A}
    data::A
    grad::A

    function TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A}
        return new(t, data)
    end

    function TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A}
        return new(t, data, grad)
    end
end
# Return the `Tracked` record stored inside a `TrackedArray`.
function tracker(x::TrackedArray)
    return x.tracker
end

# Rank-specialised aliases: vectors are rank 1, matrices rank 2.
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
2018-02-07 20:39:36 +00:00
# Wrap the output of a recorded `Call` as a tracked array.
track(c::Call, x::AbstractArray) = TrackedArray(c, x)

# Outer constructors: derive `T` and `N` from the wrapped array type `A` and build
# the matching `Tracked{A}` record, without or with an initial gradient `Δ`.
function TrackedArray(c::Call, x::A) where A<:AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(Tracked{A}(c, x), x)
end

function TrackedArray(c::Call, x::A, Δ::A) where A<:AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(Tracked{A}(c, x, Δ), x, Δ)
end

# Construct from a plain array: records `Call(nothing)` and starts the gradient
# at `zeros(x)`.
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
2018-02-13 10:20:38 +00:00
# Tracked arrays of reals report `TrackedReal{T}` as their element type.
Base.eltype(::Type{<:TrackedArray{T}}) where T<:Real = TrackedReal{T}

# Abbreviated type printing: hide `T` and `N`; show only the wrapped array type.
function Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
    print(io, " TrackedArray{…, $A } ")
end
# Display hook. In `repr` mode the array prints as a `param(...)` call around its
# data. Otherwise the wrapped data is shown, with a "Tracked" prefix when a header
# is requested.
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
    if !repr
        header && print(io, " Tracked ")
        return Base.showarray(io, data(X), false, header = header)
    end
    print(io, " param( ")
    Base.showarray(io, data(X), true)
    print(io, " ) ")
end
# In-place writes into a tracked array are rejected outright.
function Base.setindex!(xs::TrackedArray, v, i...)
    error(" Can't differentiate `setindex!` ")
end

# Seedless backpropagation is only offered for scalars; arrays are refused.
function back!(::TrackedArray)
    error(" Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)` ")
end
2018-02-07 20:39:36 +00:00
2018-02-07 17:43:25 +00:00
# Fallthrough methods: shape queries are answered by the wrapped data.
for f in (:size, :ndims)
    @eval @inline Base.$f(x::TrackedArray, a...) = Base.$f(data(x), a...)
end

# `similar` allocates a plain array modelled on the wrapped data.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
    return similar(data(x), dims...)
end
function Base.similar(x::TrackedArray, T::Type)
    return similar(data(x), T)
end

# Equality compares underlying values, whichever side is tracked.
Base.:(==)(x::TrackedArray, y) = data(x) == y
Base.:(==)(y, x::TrackedArray) = y == data(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
2018-02-07 17:43:25 +00:00
# Array Stdlib
2018-02-07 20:39:36 +00:00
# Indexing is recorded so gradients can flow back to `xs`.
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)

# Pullback of `getindex`: scatter the output gradient `Δ` onto the positions of `xs`
# selected by `i...`. All other positions receive zero. Contributions are
# accumulated, so an index selected more than once (e.g. `xs[[1, 1]]`) receives the
# sum of its entries of `Δ` rather than only the last one written.
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
    Δ′ = zeros(xs.data)
    _getindex_accum!(Δ′, Δ, i...)
    @back(xs, Δ′)
end

# Scalar output: a single element was selected, so one addition suffices.
function _getindex_accum!(Δ′, Δ, i...)
    Δ′[i...] += Δ
    return Δ′
end

# Array output: add element by element through a view. With repeated indices each
# occurrence then contributes, instead of the last write winning as with the
# vectorised assignment `Δ′[i...] = Δ`.
function _getindex_accum!(Δ′, Δ::AbstractArray, i...)
    dest = view(Δ′, i...)
    for (j, k) in zip(eachindex(dest), eachindex(Δ))
        dest[j] += Δ[k]
    end
    return Δ′
end
2018-02-07 20:39:36 +00:00
# Unary minus is recorded; its pullback negates the incoming gradient.
Base.:-(xs::TrackedArray) = track(-, xs)

function back(::typeof(-), Δ, xs::TrackedArray)
    return back(xs, -Δ)
end
2017-08-23 16:50:43 +00:00
2018-02-07 20:39:36 +00:00
# (Conjugate) transposition is recorded. The pullback applies the same transpose to
# the gradient and trims it back to the input's rank.
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)

back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, transpose(Δ)))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, ctranspose(Δ)))
2017-09-03 21:10:23 +00:00
2018-02-07 20:39:36 +00:00
# Tiling of a tracked vector/matrix is recorded. The second method exists
# alongside the `Integer` one — presumably to resolve dispatch ambiguity with Base's
# own `repmat` methods (TODO confirm). It uses the platform `Int` so it also
# applies on 32-bit builds.
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int...) = track(repmat, x, a...)

# Pullback of `repmat(xs, m, n)`: every element of the tiled gradient `Δ` is added
# onto the element of `xs` it was copied from.
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n = 1)
    # Start from zero: the loop below accumulates with `+=`.
    Δ′ = zeros(xs.data)
    rows = size(xs.data, 1)
    # `size(_, 2)` is 1 for vectors, so the vector case needs no special handling.
    cols = size(xs.data, 2)
    for (i, v) in enumerate(Δ)
        # Column-major linear index → (column, row) inside the tiled `Δ`.
        col, row = divrem(i - 1, rows * m)
        Δ′[row % rows + 1, col % cols + 1] += v
    end
    back(xs, Δ′)
end
2018-05-02 06:37:30 +00:00
# Tracked concatenation. Julia has no standardised promotion mechanism for
# concatenation yet (https://github.com/JuliaLang/julia/pull/20815), so this is a
# hack built from overlapping method signatures. It covers:
#   * rank 1–2 inputs with a `TrackedArray` anywhere among the arguments;
#   * higher-rank inputs when the `TrackedArray` is the first argument;
#   * higher-rank inputs when the `TrackedArray` is the second argument.
# The `Union{TrackedArray,Vector,RowVector,Matrix}` variants of the last two cases
# only resolve ambiguities introduced by the methods just before them. The whole
# scheme relies on Base having other methods that capture
# `(::Union{Vector,RowVector,Matrix}...)`.
for f in (:vcat, :hcat)
    @eval begin
        Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)

        Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
        Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) =
            track($f, a, b...)

        Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
        Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
                c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...)
    end
end
2017-12-08 15:10:09 +00:00
# Pullback of `vcat`: give each argument its own consecutive block of rows of `Δ`,
# keeping every trailing dimension whole.
function back(::typeof(vcat), Δ, xs...)
    offset = 0
    for x in xs
        nrows = size(x, 1)
        rest = Base.tail(map(_ -> Colon(), size(x)))
        @back(x, Δ[offset+1:offset+nrows, rest...])
        offset += nrows
    end
end
2018-05-02 13:56:08 +00:00
# Pullback of `hcat`: a vector argument owns a single column of `Δ`. A higher-rank
# argument owns `size(x, 2)` consecutive columns, with its trailing dimensions kept
# whole.
function back(::typeof(hcat), Δ, xs...)
    offset = 0
    for x in xs
        ncols = size(x, 2)
        if ndims(x) != 1
            rest = Base.tail(Base.tail(map(_ -> Colon(), size(x))))
            @back(x, Δ[:, offset+1:offset+ncols, rest...])
        else
            @back(x, Δ[:, offset+1])
        end
        offset += ncols
    end
end
2018-05-02 12:57:32 +00:00
# Tracked `cat` along `dims`, recorded when the first or second argument is tracked.
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) =
    track(cat, dims, a, b, c...)

# Pullback of `cat(dims, Xs...)`: walk the arguments in order and carve each one's
# block out of `Δ`. Along concatenated dimensions the block starts after a running
# offset; along all other dimensions it spans the full extent. Each block is then
# reshaped to the argument's size.
function back(::typeof(cat), Δ, dims, Xs...)
    N = ndims(Δ)
    offset = ntuple(_ -> 0, Val{N})
    for x in Xs
        # Extent of `x` along each concatenated dimension; 0 marks the others.
        extent = ntuple(i -> i in dims ? (i in 1:ndims(x) ? size(x, i) : 1) : 0, Val{N})
        block = ntuple(i -> extent[i] > 0 ? (offset[i]+1:offset[i]+extent[i]) : Colon(),
                       Val{N})
        @back(x, reshape(Δ[block...], size(x)))
        offset = offset .+ extent
    end
end
2018-04-02 20:09:57 +00:00
# Reshape normalisation:
#   1. splatted dims become a tuple;
#   2. `Colon` is resolved against the array length;
#   3. the concrete integer tuple is recorded.
# Dispatch uses the platform `Int` rather than `Int64`, so these methods also catch
# the default integer type on 32-bit builds.
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
    reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int}}) = track(reshape, xs, dims)

# Pullback of `reshape`: restore the input's shape.
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
    back(xs, reshape(Δ, size(xs)))

# Dimension permutation is recorded and undone with the inverse permutation.
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) =
    back(xs, permutedims(Δ, invperm(dims)))
2018-02-08 19:27:57 +00:00
2018-02-16 14:15:40 +00:00
# Kronecker product built only from `reshape` and broadcasting. `mat1` is laid
# along dimensions 2 and 4 and `mat2` along dimensions 1 and 3. Their broadcast
# product is then folded back into an (m1*m2)×(n1*n2) matrix.
function _kron(mat1::AbstractMatrix, mat2::AbstractMatrix)
    m1, n1 = size(mat1)
    m2, n2 = size(mat2)
    outer = reshape(mat1, (1, m1, 1, n1)) .* reshape(mat2, (m2, 1, n2, 1))
    return reshape(outer, (m1 * m2, n1 * n2))
end
2018-02-16 14:15:40 +00:00
# `kron` with at least one tracked matrix is computed by `_kron`, i.e. through
# `reshape` and broadcasting.
for (A, B) in ((:TrackedMatrix, :TrackedMatrix), (:TrackedMatrix, :AbstractMatrix),
               (:AbstractMatrix, :TrackedMatrix))
    @eval Base.kron(a::$A, b::$B) = _kron(a, b)
end
2017-08-22 11:24:08 +00:00
# Reductions
2018-02-07 20:39:36 +00:00
# Sums are recorded. The function-first form broadcasts `f` first, then sums.
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type}, xs::TrackedArray) = sum(f.(xs))

# Pullback of `sum`: broadcast `Δ` (scalar or reduced array) over the input shape.
function back(::typeof(sum), Δ, xs::TrackedArray, dim...)
    Δ′ = similar(xs.data)
    Δ′ .= Δ
    return back(xs, Δ′)
end
2017-08-22 11:24:08 +00:00
2018-03-06 10:01:19 +00:00
# Products are recorded. The function-first form broadcasts `f` first, then
# multiplies.
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function,Type}, xs::TrackedArray) = prod(f.(xs))

# Pullback of `prod` along `dim...`, using ∂prod/∂xᵢ = prod / xᵢ.
# NOTE(review): because this divides by the input, a zero element receives NaN
# (0/0) instead of the product of the other elements in its slice.
function back(::typeof(prod), Δ, xs::TrackedArray, dim...)
    x = xs.data
    Δ′ = similar(x)
    Δ′ .= (prod(x, dim...) ./ x) .* Δ
    return back(xs, Δ′)
end
2018-03-07 13:01:07 +00:00
# Pullback of a full `prod`: ∂prod/∂xᵢ is the product of every *other* element.
# It is computed from exclusive prefix and suffix cumulative products. This runs in
# O(n), never divides by `xᵢ` (so zeros are handled exactly), and is well defined for
# empty and single-element arrays.
function back(::typeof(prod), Δ, xs::TrackedArray)
    x = vec(xs.data)
    Δ′ = similar(xs.data)
    isempty(x) && return back(xs, Δ′)
    unit = one(eltype(x))
    before = vcat(unit, cumprod(x[1:end-1]))                  # prod(x[1:i-1])
    after = vcat(reverse(cumprod(reverse(x[2:end]))), unit)   # prod(x[i+1:end])
    Δ′ .= reshape(before .* after, size(xs.data)) .* Δ
    return back(xs, Δ′)
end
2018-03-06 10:01:19 +00:00
2017-09-07 01:21:35 +00:00
# `findfirst` searches the raw data directly, without recording a call.
function Base.findfirst(xs::TrackedArray, args...)
    return findfirst(xs.data, args...)
end
2017-09-02 03:33:05 +00:00
2018-02-07 20:39:36 +00:00
# Mean, maximum and minimum are recorded, both whole-array and over a `region`.
for f in (:mean, :maximum, :minimum)
    @eval begin
        Base.$f(xs::TrackedArray) = track($f, xs)
        Base.$f(xs::TrackedArray, region) = track($f, xs, region)
    end
end
2018-02-07 20:39:36 +00:00
# Dot products involving at least one tracked vector are recorded.
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)

# Pullback of `dot`: ∂(x⋅y)/∂x = y and ∂(x⋅y)/∂y = x, each scaled by `Δ`.
function back(::typeof(dot), Δ, xs, ys)
    x, y = data(xs), data(ys)
    @back(xs, Δ .* y)
    @back(ys, Δ .* x)
end
2017-11-21 16:04:04 +00:00
# Hacks to get `std` working: the sample standard deviation is written in terms of
# the tracked broadcast/`sum` operations so it is differentiable. The keyword is
# named `mean` (shadowing the function), hence the explicit `Base.mean` defaults.
function Base.std(x::TrackedArray; mean = Base.mean(x))
    sqdev = (x .- mean) .^ 2
    return sqrt.(sum(sqdev) ./ (length(x) - 1))
end

function Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim))
    sqdev = (x .- mean) .^ 2
    return sqrt.(sum(sqdev, dim) ./ (size(x, dim) - 1))
end
2018-03-05 17:24:46 +00:00
# p-norm of a tracked array, built from tracked operations.
#
# `eps(0f0)` is added to every term so the derivative stays finite at zero
# (d(sqrt(x))/dx == Inf at 0).
#
# For p = ±Inf the power formula degenerates: the exponent 1/p is 0, so it would
# always return 1. Those cases therefore use the maximum/minimum absolute value,
# matching Base's `vecnorm`.
function Base.vecnorm(x::TrackedArray, p::Real = 2)
    if isinf(p)
        a = abs.(x)
        return p > 0 ? maximum(a) : minimum(a)
    end
    return sum(abs.(x) .^ p .+ eps(0f0)) ^ (1 / p)
end
2018-02-09 19:00:26 +00:00
2017-10-31 10:41:44 +00:00
# Pullback of `mean`: every input element gets `Δ` divided by the number of
# elements averaged (the whole array, or the extent of `region`).
function back(::typeof(mean), Δ, xs::TrackedArray)
    n = length(xs.data)
    return back(xs, similar(xs.data) .= Δ ./ n)
end

function back(::typeof(mean), Δ, xs::TrackedArray, region)
    n = prod(size(xs.data, region...))
    return back(xs, similar(xs.data) .= Δ ./ n)
end
2018-04-27 21:14:01 +00:00
# Pullbacks of `maximum`/`minimum`: route `Δ` to the position(s) reported by
# `findmax`/`findmin` (whole array or per `region`); every other position gets zero.
for (reduction, finder) in ((:maximum, :findmax), (:minimum, :findmin))
    @eval begin
        function back(::typeof($reduction), Δ, xs::TrackedArray)
            Δ′ = zeros(xs.data)
            _, i = $finder(xs.data)
            Δ′[i] = Δ
            @back(xs, Δ′)
        end

        function back(::typeof($reduction), Δ, xs::TrackedArray, region)
            Δ′ = zeros(xs.data)
            _, is = $finder(xs.data, region)
            Δ′[is] = Δ
            @back(xs, Δ′)
        end
    end
end
2017-08-22 11:24:08 +00:00
# BLAS
2018-02-07 20:39:36 +00:00
# `diagm` of a tracked vector is recorded. Its pullback keeps only the diagonal of
# the gradient.
Base.diagm(x::TrackedVector) = track(diagm, x)

function back(::typeof(diagm), Δ, x)
    return @back(x, diag(Δ))
end
2017-12-12 17:07:39 +00:00
# Matrix/vector products with a tracked operand are recorded. Each operator gets
# the same nine (left, right) signatures: matrix×matrix, matrix×vector and
# vector×vector, with the tracked operand on the left, the right, or both sides.
for f in (:*, :Ac_mul_B, :A_mul_Bc)
    @eval import Base.$f
    for (A, B) in ((:TrackedMatrix, :TrackedMatrix), (:TrackedMatrix, :AbstractMatrix),
                   (:AbstractMatrix, :TrackedMatrix),
                   (:TrackedMatrix, :TrackedVector), (:TrackedMatrix, :AbstractVector),
                   (:AbstractMatrix, :TrackedVector),
                   (:TrackedVector, :TrackedVector), (:TrackedVector, :AbstractVector),
                   (:AbstractVector, :TrackedVector))
        @eval $f(a::$A, b::$B) = track($f, a, b)
    end
end
2017-08-20 12:48:43 +00:00
2017-09-07 03:09:32 +00:00
# Pullbacks of the matrix products (the adjoint variants assume real operands):
#   a * b       → ∂a = Δ bᵀ,       ∂b = aᵀ Δ
#   Ac_mul_B    → ∂a = (Δ bᵀ)',    ∂b = a Δ
#   A_mul_Bc    → ∂a = Δ b,        ∂b = (aᵀ Δ)'
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
    A, B = data(a), data(b)
    @back(a, A_mul_Bt(Δ, B))
    @back(b, At_mul_B(A, Δ))
end

function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
    A, B = data(a), data(b)
    @back(a, ctranspose(A_mul_Bt(Δ, B)))
    @back(b, A * Δ)
end

function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real})
    A, B = data(a), data(b)
    @back(a, Δ * B)
    @back(b, ctranspose(At_mul_B(A, Δ)))
end
2017-11-07 19:34:27 +00:00
# Fast path for matrix-vector products. When `W` is a leaf, the outer product Δ xᵀ
# is accumulated in place into `W.grad`. Otherwise it is propagated through `back`.
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
    xv = data(x)
    if !isleaf(W)
        back(W, A_mul_Bt(Δ, xv))
    else
        W.grad .+= Δ .* transpose(xv)
    end
    @back(x, At_mul_B(data(W), Δ))
end
2017-08-23 01:03:17 +00:00
# NNlib
2017-12-14 18:48:38 +00:00
using NNlib
2018-02-26 22:43:07 +00:00
import NNlib : softmax , ∇softmax , logsoftmax , ∇logsoftmax , conv , maxpool , meanpool
2017-08-23 01:03:17 +00:00
2018-02-07 20:39:36 +00:00
# Softmax variants are recorded. NNlib provides the matching gradient functions.
for (f, ∇f) in ((:softmax, :∇softmax), (:logsoftmax, :∇logsoftmax))
    @eval begin
        $f(xs::TrackedArray) = track($f, xs)
        back(::typeof($f), Δ, xs) = @back(xs, $∇f(Δ, data(xs)))
    end
end
2017-12-18 18:05:38 +00:00
# TODO: kwargs could be stored more efficiently in named tuples
2018-02-26 22:43:07 +00:00
# `conv`'s keyword arguments are passed to `_conv` as positional arguments, so they
# are recorded along with the call.
_conv(x, w, stride, pad) = conv(x, w, stride = stride, pad = pad)

for (X, W) in ((:TrackedArray, :TrackedArray), (:AbstractArray, :TrackedArray),
               (:TrackedArray, :AbstractArray))
    @eval function conv(x::$X{<:Real,N}, w::$W{<:Real,N}; stride = 1, pad = 0) where N
        track(_conv, x, w, stride, pad)
    end
end

# Pullback of convolution: NNlib computes the data and filter gradients.
function back(::typeof(_conv), Δ, x, w, stride, pad)
    xd, wd = data(x), data(w)
    @back(x, NNlib.∇conv_data(Δ, xd, wd; stride = stride, pad = pad))
    @back(w, NNlib.∇conv_filter(Δ, xd, wd; stride = stride, pad = pad))
end
2018-03-19 19:42:04 +00:00
# Max and mean pooling.
#
# Keywords are passed to the `_pool` helpers positionally so they are recorded with
# the call. By default there is no padding and the stride equals the window `k`.
#
# The pullbacks are defined on `back_`, which also receives the forward output `y`;
# NNlib computes the actual gradient.
for (pool, _pool, ∇pool) in ((:maxpool, :_maxpool, :∇maxpool),
                             (:meanpool, :_meanpool, :∇meanpool))
    @eval begin
        $_pool(x, k, pad, stride) = $pool(x, k; pad = pad, stride = stride)

        $pool(x::TrackedArray, k; pad = map(_ -> 0, k), stride = k) =
            track($_pool, x, k, pad, stride)

        back_(::typeof($_pool), y, Δ, x, k, pad, stride) =
            back(x, NNlib.$∇pool(Δ, y, data(x), k, pad = pad, stride = stride))
    end
end
2017-12-15 02:29:14 +00:00
2017-08-19 15:02:19 +00:00
# Broadcasting
using ForwardDiff : Dual , partials
2018-02-05 17:22:09 +00:00
# Record of a dual-number broadcast: the function that was broadcast and the raw
# `broadcast` output it produced.
struct Broadcasted{F,T}
    f::F     # function that was broadcast
    data::T  # raw output of `broadcast` over the dualified arguments
end
# Calling a `Broadcasted` ignores its arguments and returns the `.value` of every
# stored element.
function (b::Broadcasted)(xs...)
    return map(d -> d.value, b.data)
end

# Lift tracked arguments to `Dual` numbers carrying the partials `ps`; anything else
# passes through unchanged.
dualify(xs, n) = xs

function dualify(xs::TrackedArray, ps)
    return map(v -> Dual(v, ps), data(xs))
end

function dualify(xs::TrackedReal, ps)
    return Dual(data(xs), ps)
end
2017-08-19 15:02:19 +00:00
# Broadcast `f` over `args` using forward-mode duals.
#
# Argument `i` is dualified with the one-hot partials tuple eᵢ. If the broadcast
# output's element type is not `Dual`, the output is returned untouched. Otherwise
# the primal values are tracked, with a `Call` that records the `Broadcasted` and
# the original arguments.
function tracked_broadcast(f, args::Vararg{Any,N}) where N
    seed(i) = ntuple(j -> i == j, Val{N})
    dargs = map((x, i) -> dualify(x, seed(i)), args, ntuple(identity, Val{N}))
    out = broadcast(f, dargs...)
    if !(eltype(out) <: Dual)
        return out
    end
    b = Broadcasted(f, out)
    return track(Call(b, args...), b())
end
2017-08-27 08:49:42 +00:00
# Reshape `Δ` to its first `ndims(x)` dimensions. This drops trailing dimensions,
# which must be singleton for the reshape to succeed.
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))

# Reduce a broadcast gradient `Δ` to the shape of the argument `x`.
# If the sizes already match, `Δ` is returned as is. Otherwise `Δ` is summed over
# every dimension where `x` has size 1, then trimmed to `x`'s rank.
function unbroadcast(x::AbstractArray, Δ)
    size(x) == size(Δ) && return Δ
    expanded = filter(d -> size(x, d) == 1, 1:ndims(Δ))
    return trim(x, sum(Δ, expanded))
end

# A scalar argument's gradient is the total of `Δ`.
unbroadcast(x::Number, Δ) = sum(Δ)
2017-08-28 00:40:59 +00:00
# Partial derivative number `i` of the dual `x`, scaled by the incoming gradient `Δ`.
function getpartial(Δ, x, i)
    @inbounds p = partials(x)[i]
    return Δ * p
end
2017-09-07 03:09:32 +00:00
# Pullback of a dual-number broadcast. For each argument `i`, `Δ` is multiplied
# elementwise by the stored partial `i` of every output. The result is reduced to
# the argument's shape and propagated.
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
    Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
    for (x, Δx) in zip(args, Δargs)
        @back(x, unbroadcast(x, Δx))
    end
end
2017-08-19 15:02:19 +00:00
2018-02-21 23:21:20 +00:00
# Broadcast hooks.
#
# Tracked reals and tracked arrays both report the `TrackedArray` container type.
# Any container promotion involving `TrackedArray` yields `TrackedArray`.
# `broadcast_c` for that container type hands the work to `tracked_broadcast`.
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray

for (A, B) in ((:TrackedArray, :TrackedArray), (:Array, :TrackedArray),
               (:TrackedArray, :Array))
    @eval Base.Broadcast.promote_containertype(::Type{$A}, ::Type{$B}) = TrackedArray
end
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray

# `Ref` arguments contribute no indices; everything else contributes `indices(A)`.
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)

Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) =
    tracked_broadcast(f, A, Bs...)