some more derivatives
This commit is contained in:
parent
380d32dac9
commit
6c1a38e157
@ -43,10 +43,17 @@ function back!(x::TrackedArray, Δ)
|
|||||||
back!(x.f, Δ)
|
back!(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims, Base.similar].args
|
# Fallthrough methods
|
||||||
|
|
||||||
|
for f in :[Base.size, Base.ndims].args
|
||||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||||
|
similar(data(x), dims...)
|
||||||
|
|
||||||
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||||
if repr
|
if repr
|
||||||
print(io, "TrackedArray(")
|
print(io, "TrackedArray(")
|
||||||
@ -58,4 +65,6 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
include("lib.jl")
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,10 +1,52 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b))
|
Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
|
||||||
a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b))
|
|
||||||
|
|
||||||
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
|
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||||
|
Δ′ = zeros(xs)
|
||||||
|
Δ′[i...] = Δ
|
||||||
|
back!(xs, Δ′)
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||||
|
|
||||||
|
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
|
||||||
|
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix)
|
||||||
back!(a, A_mul_Bt(Δ, data(b)))
|
back!(a, A_mul_Bt(Δ, data(b)))
|
||||||
back!(b, At_mul_B(data(a), Δ))
|
back!(b, At_mul_B(data(a), Δ))
|
||||||
|
end
|
||||||
|
|
||||||
|
# Broadcasting
|
||||||
|
|
||||||
|
using ForwardDiff: Dual, partials
|
||||||
|
|
||||||
|
struct Broadcasted{T}
|
||||||
|
data::T
|
||||||
|
end
|
||||||
|
|
||||||
|
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
|
||||||
|
|
||||||
|
dualify(xs, n) = xs
|
||||||
|
dualify(xs::TrackedArray, ps) = Dual.(data(xs), Ref(ps))
|
||||||
|
|
||||||
|
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||||
|
dargs = ntuple(i -> dualify(args[i], ntuple(j -> i==j, Val{N})), Val{N})
|
||||||
|
TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||||
|
end
|
||||||
|
|
||||||
|
function back!(b::Broadcasted, Δ, args...)
|
||||||
|
Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args))
|
||||||
|
back!.(args, Δargs)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||||
|
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
||||||
|
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
||||||
|
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
|
||||||
|
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
|
||||||
|
|
||||||
|
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
|
||||||
|
Loading…
Reference in New Issue
Block a user