closes #118
This commit is contained in:
parent
872d5b902c
commit
1beb30e19a
|
@ -174,9 +174,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
|||
|
||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||
# Works around a 0.6 type inference issue
|
||||
b = Broadcasted(broadcast(f, dargs...))
|
||||
b = Broadcasted(out)
|
||||
TrackedArray(Call(b, args...), b())
|
||||
end
|
||||
|
||||
|
|
|
@ -50,4 +50,6 @@ end
|
|||
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
||||
|
||||
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue