some more derivatives
This commit is contained in:
parent
380d32dac9
commit
6c1a38e157
@ -43,10 +43,17 @@ function back!(x::TrackedArray, Δ)
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.similar].args
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "TrackedArray(")
|
||||
@ -58,4 +65,6 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
|
||||
end
|
||||
end
|
||||
|
||||
include("lib.jl")
|
||||
|
||||
end
|
||||
|
@ -1,10 +1,52 @@
|
||||
import Base: *
|
||||
|
||||
a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b))
|
||||
a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b))
|
||||
Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
|
||||
|
||||
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
|
||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs)
|
||||
Δ′[i...] = Δ
|
||||
back!(xs, Δ′)
|
||||
end
|
||||
|
||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||
|
||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
||||
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||
|
||||
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
|
||||
back!(a, A_mul_Bt(Δ, data(b)))
|
||||
back!(b, At_mul_B(data(a), Δ))
|
||||
end
|
||||
|
||||
# Broadcasting
|
||||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
||||
struct Broadcasted{T}
|
||||
data::T
|
||||
end
|
||||
|
||||
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
|
||||
|
||||
dualify(xs, n) = xs
|
||||
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
|
||||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = ntuple(i -> dualify(args[i], ntuple(j -> i==j, Val{N})), Val{N})
|
||||
TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||
end
|
||||
|
||||
function back!(b::Broadcasted, Δ, args...)
|
||||
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
||||
back!.(args, Δargs)
|
||||
return
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
||||
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
||||
|
||||
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
|
||||
|
Loading…
Reference in New Issue
Block a user