val deprecations

This commit is contained in:
Mike J Innes 2018-07-12 20:59:07 +01:00
parent 474f578517
commit 89872c5a8b

View File

@ -165,11 +165,11 @@ Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_k
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)})
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = [begin
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
d
@ -350,13 +350,13 @@ end
function ∇broadcast(f, args::Vararg{Any,N}) where N
sizes = _size.(args)
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val(N))), args, ntuple(identity, Val(N)))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val(N))
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
end