closes #118
This commit is contained in:
parent
872d5b902c
commit
1beb30e19a
@ -174,9 +174,11 @@ dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
|||||||
|
|
||||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
|
out = broadcast(f, dargs...)
|
||||||
|
eltype(out) <: Dual || return out
|
||||||
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||||
# Works around a 0.6 type inference issue
|
# Works around a 0.6 type inference issue
|
||||||
b = Broadcasted(broadcast(f, dargs...))
|
b = Broadcasted(out)
|
||||||
TrackedArray(Call(b, args...), b())
|
TrackedArray(Call(b, args...), b())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -50,4 +50,6 @@ end
|
|||||||
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
@test gradtest(x -> maxpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
@test gradtest(x -> avgpool2d(x, 2), rand(10, 10, 3, 2))
|
||||||
|
|
||||||
|
@test (param([1,2,3]) .< 2) == [true, false, false]
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user