gpu test fixes
This commit is contained in:
parent
c9663c1e71
commit
e2bf46b7fd
|
@ -194,7 +194,7 @@ end
|
|||
# Flux Interface
|
||||
|
||||
(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
|
||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
|
||||
|
||||
@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||
|
|
|
@ -6,6 +6,8 @@ using Zygote
|
|||
|
||||
@testset "CuArrays" begin
|
||||
|
||||
CuArrays.allowscalar(false)
|
||||
|
||||
x = param(randn(5, 5))
|
||||
cx = gpu(x)
|
||||
@test cx isa CuArray
|
||||
|
@ -14,7 +16,7 @@ cx = gpu(x)
|
|||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx isa CuArray
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
@test (cx .+ 1) isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
|
@ -32,14 +34,14 @@ ys = Flux.onehotbatch(1:5,1:5)
|
|||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
fwd, back = Zygote.forward(sum, l)
|
||||
back(one(Float64))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
|
||||
c = gpu(CrossCor((2,2),3=>4))
|
||||
x = gpu(rand(10, 10, 3, 2))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
fwd, back = Zygote.forward(sum, l)
|
||||
back(one(Float64))
|
||||
@test gradient(x -> sum(c(x)), x)[1] isa CuArray
|
||||
|
||||
end
|
||||
|
||||
|
|
|
@ -234,7 +234,6 @@ end
|
|||
@test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1.
|
||||
|
||||
x′ = m(x)
|
||||
println(x′[1])
|
||||
@test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
|
|
Loading…
Reference in New Issue