fixes #367
This commit is contained in:
parent
86cf22675f
commit
7d6ec2365f
|
@ -48,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
|
|||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
|
|
|
@ -26,6 +26,10 @@ x = [1,2,3]
|
|||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
|
||||
xs = param(rand(5,5))
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
|
|
Loading…
Reference in New Issue