This commit is contained in:
Mike Innes 2018-08-24 14:30:39 +01:00
parent 86cf22675f
commit 7d6ec2365f
2 changed files with 5 additions and 1 deletions

View File

@ -48,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
for f in :[Base.size, Base.ndims, Base.collect].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end

View File

@ -26,6 +26,10 @@ x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) Flux.crossentropy(cx,cx)
xs = param(rand(5,5))
ys = Flux.onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) collect(xs .+ ys)
c = gpu(Conv((2,2),3=>4))
l = c(gpu(rand(10,10,3,2)))
Flux.back!(sum(l))