fixes #367
This commit is contained in:
parent
86cf22675f
commit
7d6ec2365f
@ -48,7 +48,7 @@ back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back
|
|||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims].args
|
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -26,6 +26,10 @@ x = [1,2,3]
|
|||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||||
|
|
||||||
|
xs = param(rand(5,5))
|
||||||
|
ys = Flux.onehotbatch(1:5,1:5)
|
||||||
|
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||||
|
|
||||||
c = gpu(Conv((2,2),3=>4))
|
c = gpu(Conv((2,2),3=>4))
|
||||||
l = c(gpu(rand(10,10,3,2)))
|
l = c(gpu(rand(10,10,3,2)))
|
||||||
Flux.back!(sum(l))
|
Flux.back!(sum(l))
|
||||||
|
Loading…
Reference in New Issue
Block a user