This commit is contained in:
Mike J Innes 2018-01-15 17:00:47 +00:00
parent 872d5b902c
commit 1beb30e19a
2 changed files with 5 additions and 1 deletions

View File

@ -174,9 +174,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
# Works around a 0.6 type inference issue
b = Broadcasted(broadcast(f, dargs...))
b = Broadcasted(out)
TrackedArray(Call(b, args...), b())
end

View File

@ -50,4 +50,6 @@ end
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
@test (param([1,2,3]) .< 2) == [true, false, false]
end #testset